Delik committed
Commit 959fd21 (verified) · Parent(s): d0d6aa5

Update app.py

Files changed (1): app.py (+20 −39)
app.py CHANGED
@@ -28,12 +28,9 @@ import os
 import time
 import numpy as np
 
-
-
 # Disable Gradio analytics to avoid network-related issues
 gr.analytics_enabled = False
 
-
 def check_package_installed(package_name):
 package_spec = importlib.util.find_spec(package_name)
 if package_spec is None:
@@ -77,11 +74,11 @@ def main(args):
 audio_name = os.path.splitext(os.path.basename(args.test_audio_path))[0]
 predicted_video_256_path = os.path.join(args.result_path, f'{test_image_name}-{audio_name}.mp4')
 predicted_video_512_path = os.path.join(args.result_path, f'{test_image_name}-{audio_name}_SR.mp4')
-
+
 #======Loading Stage 1 model=========
 lia = LIA_Model(motion_dim=args.motion_dim, fusion_type='weighted_sum')
 lia.load_lightning_model(args.stage1_checkpoint_path)
-lia.to(args.device)
+lia.to('cuda')
 #============================
 
 conf = ffhq256_autoenc()
@@ -122,7 +119,7 @@ def main(args):
 print(f'{args.test_audio_path} does not exist!')
 exit(0)
 
-img_source = img_preprocessing(args.test_image_path, args.image_size).to(args.device)
+img_source = img_preprocessing(args.test_image_path, args.image_size).to('cuda')
 one_shot_lia_start, one_shot_lia_direction, feats = lia.get_start_direction_code(img_source, img_source, img_source, img_source)
 
 #======Loading Stage 2 model=========
@@ -130,7 +127,7 @@ def main(args):
 state = torch.load(args.stage2_checkpoint_path, map_location='cpu')
 model.load_state_dict(state, strict=True)
 model.ema_model.eval()
-model.ema_model.to(args.device)
+model.ema_model.to('cuda')
 #=================================
 
 #======Audio Input=========
@@ -144,7 +141,7 @@ def main(args):
 frame_start, frame_end = 0, int(audio_driven_obj.shape[0]/4)
 audio_start, audio_end = int(frame_start * 4), int(frame_end * 4) # The video frame is fixed to 25 hz and the audio is fixed to 100 hz
 
-audio_driven = torch.Tensor(audio_driven_obj[audio_start:audio_end,:]).unsqueeze(0).float().to(args.device)
+audio_driven = torch.Tensor(audio_driven_obj[audio_start:audio_end,:]).unsqueeze(0).float().to('cuda')
 
 elif conf.infer_type.startswith('hubert'):
 # Hubert features
@@ -163,7 +160,7 @@ def main(args):
 
 # load hubert model
 from transformers import Wav2Vec2FeatureExtractor, HubertModel
-audio_model = HubertModel.from_pretrained(hubert_model_path).to(args.device)
+audio_model = HubertModel.from_pretrained(hubert_model_path).to('cuda')
 feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hubert_model_path)
 audio_model.feature_extractor._freeze_parameters()
 audio_model.eval()
@@ -171,7 +168,7 @@ def main(args):
 # hubert model forward pass
 audio, sr = librosa.load(args.test_audio_path, sr=16000)
 input_values = feature_extractor(audio, sampling_rate=16000, padding=True, do_normalize=True, return_tensors="pt").input_values
-input_values = input_values.to(args.device)
+input_values = input_values.to('cuda')
 ws_feats = []
 with torch.no_grad():
 outputs = audio_model(input_values, output_hidden_states=True)
@@ -192,11 +189,11 @@ def main(args):
 frame_start, frame_end = 0, int(audio_driven_obj.shape[1]/2)
 audio_start, audio_end = int(frame_start * 2), int(frame_end * 2) # The video frame is fixed to 25 hz and the audio is fixed to 50 hz
 
-audio_driven = torch.Tensor(audio_driven_obj[:,audio_start:audio_end,:]).unsqueeze(0).float().to(args.device)
+audio_driven = torch.Tensor(audio_driven_obj[:,audio_start:audio_end,:]).unsqueeze(0).float().to('cuda')
 #============================
 
 # Diffusion Noise
-noisyT = torch.randn((1,frame_end, args.motion_dim)).to(args.device)
+noisyT = torch.randn((1,frame_end, args.motion_dim)).to('cuda')
 
 #======Inputs for Attribute Control=========
 if os.path.exists(args.pose_driven_path):
@@ -215,17 +212,17 @@ def main(args):
 padding = np.tile(pose_obj[-1, :], (frame_end - pose_obj.shape[0], 1))
 pose_obj = np.vstack((pose_obj, padding))
 
-pose_signal = torch.Tensor(pose_obj).unsqueeze(0).to(args.device) / 90 # 90 is for normalization here
+pose_signal = torch.Tensor(pose_obj).unsqueeze(0).to('cuda') / 90 # 90 is for normalization here
 else:
-yaw_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.pose_yaw
-pitch_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.pose_pitch
-roll_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.pose_roll
+yaw_signal = torch.zeros(1, frame_end, 1).to('cuda') + args.pose_yaw
+pitch_signal = torch.zeros(1, frame_end, 1).to('cuda') + args.pose_pitch
+roll_signal = torch.zeros(1, frame_end, 1).to('cuda') + args.pose_roll
 pose_signal = torch.cat((yaw_signal, pitch_signal, roll_signal), dim=-1)
 
 pose_signal = torch.clamp(pose_signal, -1, 1)
 
-face_location_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.face_location
-face_scae_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.face_scale
+face_location_signal = torch.zeros(1, frame_end, 1).to('cuda') + args.face_location
+face_scae_signal = torch.zeros(1, frame_end, 1).to('cuda') + args.face_scale
 #===========================================
 
 start_time = time.time()
@@ -242,7 +239,7 @@ def main(args):
 start_time = time.time()
 #======Rendering images frame-by-frame=========
 for pred_index in tqdm(range(generated_directions.shape[1])):
-ori_img_recon = lia.render(one_shot_lia_start, torch.Tensor(generated_directions[:,pred_index,:]).to(args.device), feats)
+ori_img_recon = lia.render(one_shot_lia_start, torch.Tensor(generated_directions[:,pred_index,:]).to('cuda'), feats)
 ori_img_recon = ori_img_recon.clamp(-1, 1)
 wav_pred = (ori_img_recon.detach() + 1) / 2
 saved_image(wav_pred, os.path.join(frames_result_saved_path, "%06d.png"%(pred_index)))
@@ -276,8 +273,9 @@ def main(args):
 else:
 return predicted_video_256_path, predicted_video_256_path
 
+@spaces.GPU
 def generate_video(uploaded_img, uploaded_audio, infer_type,
-pose_yaw, pose_pitch, pose_roll, face_location, face_scale, step_T, device, face_sr, seed):
+pose_yaw, pose_pitch, pose_roll, face_location, face_scale, step_T, face_sr, seed):
 if uploaded_img is None or uploaded_audio is None:
 return None, gr.Markdown("Error: Input image or audio file is empty. Please check and upload both files.")
 
@@ -289,14 +287,6 @@ generate_video(uploaded_img, uploaded_audio, infer_type,
 "hubert_full_control": "ckpt/stage2_full_control_hubert.ckpt",
 }
 
-# if face_crop:
-# uploaded_img_path = Path(uploaded_img)
-# cropped_img_path = uploaded_img_path.with_name(uploaded_img_path.stem + "_crop" + uploaded_img_path.suffix)
-# crop_image(uploaded_img, cropped_img_path)
-# uploaded_img = str(cropped_img_path)
-
-# import pdb;pdb.set_trace()
-
 stage2_checkpoint_path = model_mapping.get(infer_type, "default_checkpoint.ckpt")
 try:
 args = argparse.Namespace(
@@ -317,19 +307,14 @@ def generate_video(uploaded_img, uploaded_audio, infer_type,
 face_scale=face_scale,
 step_T=step_T,
 image_size=256,
-device=device,
+device='cuda',
 motion_dim=20,
 decoder_layers=2,
 face_sr=face_sr
 )
 
-# Save the uploaded audio to the expected path
-# shutil.copy(uploaded_audio, args.test_audio_path)
-
-# Run the main function
 output_256_video_path, output_512_video_path = main(args)
 
-# Check if the output video file exists
 if not os.path.exists(output_256_video_path):
 return None, gr.Markdown("Error: Video generation failed. Please check your inputs and try again.")
 if output_256_video_path == output_512_video_path:
@@ -347,7 +332,6 @@ default_values = {
 "face_scale": 0.5,
 "step_T": 50,
 "seed": 0,
-"device": "cuda"
 }
 
 with gr.Blocks() as demo:
@@ -373,8 +357,6 @@ with gr.Blocks() as demo:
 value='hubert_audio_only'
 )
 face_sr = gr.Checkbox(label="Enable Face Super-Resolution (512*512)", value=False)
-# face_crop = gr.Checkbox(label="Face Crop (Dlib)", value=False)
-# face_crop = False # TODO
 seed = gr.Number(label="Seed", value=default_values["seed"])
 pose_yaw = gr.Slider(label="pose_yaw", minimum=-1, maximum=1, value=default_values["pose_yaw"])
 pose_pitch = gr.Slider(label="pose_pitch", minimum=-1, maximum=1, value=default_values["pose_pitch"])
@@ -382,14 +364,13 @@
 face_location = gr.Slider(label="face_location", minimum=0, maximum=1, value=default_values["face_location"])
 face_scale = gr.Slider(label="face_scale", minimum=0, maximum=1, value=default_values["face_scale"])
 step_T = gr.Slider(label="step_T", minimum=1, maximum=100, step=1, value=default_values["step_T"])
-device = gr.Radio(label="Device", choices=["cuda", "cpu"], value=default_values["device"])
 
 
 generate_button.click(
 generate_video,
 inputs=[
 uploaded_img, uploaded_audio, infer_type,
-pose_yaw, pose_pitch, pose_roll, face_location, face_scale, step_T, device, face_sr, seed
+pose_yaw, pose_pitch, pose_roll, face_location, face_scale, step_T, face_sr, seed
 ],
 outputs=[output_video_256, output_video_512, output_message]
 )
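
Note: this commit drops the user-facing "device" option, hard-codes 'cuda' everywhere, and decorates generate_video with @spaces.GPU, the decorator Hugging Face ZeroGPU Spaces use to attach a GPU only while the decorated call is running. The snippet below is a minimal, self-contained sketch of that pattern, not code from this app: the tiny model and run_inference handler are hypothetical placeholders, and it assumes the spaces package is installed and a CUDA device is available inside the decorated call (the same assumption the commit itself makes).

import gradio as gr
import spaces   # pip install spaces; import it early so it can manage GPU attachment
import torch

# Placeholder model standing in for the real Stage 1 / Stage 2 checkpoints.
model = torch.nn.Linear(4, 4)

@spaces.GPU   # a GPU is attached only for the duration of this call on ZeroGPU Spaces
def run_inference(x: float):
    # Move weights and inputs to 'cuda' inside the decorated handler,
    # mirroring the .to('cuda') calls introduced by this commit.
    model.to('cuda')
    inp = torch.full((1, 4), x, device='cuda')
    with torch.no_grad():
        out = model(inp)
    return float(out.sum().cpu())

demo = gr.Interface(fn=run_inference, inputs=gr.Number(value=1.0), outputs="number")

if __name__ == "__main__":
    demo.launch()

Outside a ZeroGPU Space the decorator is effectively a no-op, so the hard-coded 'cuda' calls in this diff would fail on a CPU-only machine; a local fallback such as torch.cuda.is_available() would be needed there.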