rangm commited on
Commit
85d5083
·
verified ·
1 Parent(s): 8e96caf

Update src/pipelines/pipeline_echo_mimic.py

Browse files
src/pipelines/pipeline_echo_mimic.py CHANGED
@@ -34,6 +34,7 @@ from transformers import CLIPImageProcessor
34
  from src.models.mutual_self_attention import ReferenceAttentionControl
35
  from src.pipelines.context import get_context_scheduler
36
  from src.pipelines.utils import get_tensor_interpolation_method
 
37
 
38
  @dataclass
39
  class Audio2VideoPipelineOutput(BaseOutput):
@@ -417,9 +418,9 @@ class Audio2VideoPipeline(DiffusionPipeline):
417
  generator
418
  )
419
  # print(video_length, latents.shape)
420
- face_locator_tensor = self.face_locator(face_mask_tensor)
421
- uc_face_locator_tensor = torch.zeros_like(face_locator_tensor)
422
- face_locator_tensor = torch.cat([uc_face_locator_tensor, face_locator_tensor], dim=0)
423
  # Prepare extra step kwargs.
424
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
425
 
@@ -474,7 +475,7 @@ class Audio2VideoPipeline(DiffusionPipeline):
474
  encoder_hidden_states=None,
475
  return_dict=False,
476
  )
477
- reference_control_reader.update(reference_control_writer, do_classifier_free_guidance=True)
478
 
479
 
480
  num_context_batches = math.ceil(len(context_queue) / context_batch_size)
@@ -498,8 +499,8 @@ class Audio2VideoPipeline(DiffusionPipeline):
498
  .to(device)
499
  .repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)
500
  )
501
- audio_latents = torch.cat([audio_fea_final[:, c] for c in new_context]).to(device)
502
- audio_latents = torch.cat([torch.zeros_like(audio_latents), audio_latents], 0)
503
 
504
  latent_model_input = self.scheduler.scale_model_input(
505
  latent_model_input, t
@@ -508,11 +509,15 @@ class Audio2VideoPipeline(DiffusionPipeline):
508
  latent_model_input,
509
  t,
510
  encoder_hidden_states=None,
511
- audio_cond_fea=audio_latents,
512
- face_musk_fea=face_locator_tensor,
513
  return_dict=False,
514
  )[0]
515
 
 
 
 
 
516
  for j, c in enumerate(new_context):
517
  noise_pred[:, :, c] = noise_pred[:, :, c] + pred
518
  counter[:, :, c] = counter[:, :, c] + 1
@@ -523,6 +528,8 @@ class Audio2VideoPipeline(DiffusionPipeline):
523
  noise_pred = noise_pred_uncond + guidance_scale * (
524
  noise_pred_text - noise_pred_uncond
525
  )
 
 
526
 
527
  latents = self.scheduler.step(
528
  noise_pred, t, latents, **extra_step_kwargs
@@ -583,4 +590,4 @@ class Audio2VideoPipeline(DiffusionPipeline):
583
  smoothed_tensor = torch.cat(
584
  [tensor[:, :, 0:1, :, :], internal_frames, tensor[:, :, -1:, :, :]], dim=2)
585
 
586
- return smoothed_tensor
 
34
  from src.models.mutual_self_attention import ReferenceAttentionControl
35
  from src.pipelines.context import get_context_scheduler
36
  from src.pipelines.utils import get_tensor_interpolation_method
37
+ from src.utils.step_func import origin_by_velocity_and_sample, psuedo_velocity_wrt_noisy_and_timestep, get_alpha
38
 
39
  @dataclass
40
  class Audio2VideoPipelineOutput(BaseOutput):
 
418
  generator
419
  )
420
  # print(video_length, latents.shape)
421
+ c_face_locator_tensor = self.face_locator(face_mask_tensor)
422
+ uc_face_locator_tensor = torch.zeros_like(c_face_locator_tensor)
423
+ face_locator_tensor = torch.cat([uc_face_locator_tensor, c_face_locator_tensor], dim=0)
424
  # Prepare extra step kwargs.
425
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
426
 
 
475
  encoder_hidden_states=None,
476
  return_dict=False,
477
  )
478
+ reference_control_reader.update(reference_control_writer, do_classifier_free_guidance=do_classifier_free_guidance)
479
 
480
 
481
  num_context_batches = math.ceil(len(context_queue) / context_batch_size)
 
499
  .to(device)
500
  .repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)
501
  )
502
+ c_audio_latents = torch.cat([audio_fea_final[:, c] for c in new_context]).to(device)
503
+ audio_latents = torch.cat([torch.zeros_like(c_audio_latents), c_audio_latents], 0)
504
 
505
  latent_model_input = self.scheduler.scale_model_input(
506
  latent_model_input, t
 
509
  latent_model_input,
510
  t,
511
  encoder_hidden_states=None,
512
+ audio_cond_fea=audio_latents if do_classifier_free_guidance else c_audio_latents,
513
+ face_musk_fea=face_locator_tensor if do_classifier_free_guidance else c_face_locator_tensor,
514
  return_dict=False,
515
  )[0]
516
 
517
+ alphas_cumprod = self.scheduler.alphas_cumprod.to(latent_model_input.device)
518
+ x_pred = origin_by_velocity_and_sample(pred, latent_model_input, alphas_cumprod, t)
519
+ pred = psuedo_velocity_wrt_noisy_and_timestep(latent_model_input, x_pred, alphas_cumprod, t, torch.ones_like(t) * (-1))
520
+
521
  for j, c in enumerate(new_context):
522
  noise_pred[:, :, c] = noise_pred[:, :, c] + pred
523
  counter[:, :, c] = counter[:, :, c] + 1
 
528
  noise_pred = noise_pred_uncond + guidance_scale * (
529
  noise_pred_text - noise_pred_uncond
530
  )
531
+ else:
532
+ noise_pred = noise_pred / counter
533
 
534
  latents = self.scheduler.step(
535
  noise_pred, t, latents, **extra_step_kwargs
 
590
  smoothed_tensor = torch.cat(
591
  [tensor[:, :, 0:1, :, :], internal_frames, tensor[:, :, -1:, :, :]], dim=2)
592
 
593
+ return smoothed_tensor