wren93 committed
Commit 414f9e8 · 1 parent: defcb9b

update demo
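This change refactors the demo for ZeroGPU Spaces: the AnimateController class is flattened into module-level state and functions, the ConsistI2V pipeline is now loaded in torch.float16, and the sampling call moves out of animate() into a new run_pipeline() function, which now carries the @spaces.GPU decorator, so only the GPU-bound denoising step rather than the whole request handler runs while a GPU is attached. CPU-side preprocessing (cropping, resizing, normalization, seeding) happens before the GPU is requested, and the unused image_resolution bookkeeping and a leftover debug print are dropped.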

Files changed (1)
  1. app.py +180 -158
app.py CHANGED
@@ -33,171 +33,194 @@ css = """
 }
 """
 
-class AnimateController:
-    def __init__(self):
-
-        # config dirs
-        self.basedir = os.getcwd()
-        self.savedir = os.path.join(self.basedir, "samples/Gradio", datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
-        self.savedir_sample = os.path.join(self.savedir, "sample")
-        os.makedirs(self.savedir, exist_ok=True)
-
-        self.image_resolution = (256, 256)
-        # config models
-        self.pipeline = ConditionalAnimationPipeline.from_pretrained("TIGER-Lab/ConsistI2V")
-        self.pipeline.to("cuda")
-
-    def update_textbox_and_save_image(self, input_image, height_slider, width_slider, center_crop):
-        pil_image = Image.fromarray(input_image.astype(np.uint8)).convert("RGB")
-        img_path = os.path.join(self.savedir, "input_image.png")
-        pil_image.save(img_path)
-        self.image_resolution = pil_image.size
-
-        original_width, original_height = pil_image.size
-        if center_crop:
-            crop_aspect_ratio = width_slider / height_slider
-            aspect_ratio = original_width / original_height
-            if aspect_ratio > crop_aspect_ratio:
-                new_width = int(crop_aspect_ratio * original_height)
-                left = (original_width - new_width) / 2
-                top = 0
-                right = left + new_width
-                bottom = original_height
-                pil_image = pil_image.crop((left, top, right, bottom))
-            elif aspect_ratio < crop_aspect_ratio:
-                new_height = int(original_width / crop_aspect_ratio)
-                top = (original_height - new_height) / 2
-                left = 0
-                right = original_width
-                bottom = top + new_height
-                pil_image = pil_image.crop((left, top, right, bottom))
-
-        pil_image = pil_image.resize((width_slider, height_slider))
-        return gr.Textbox(value=img_path), gr.Image(value=np.array(pil_image))
-
-    @spaces.GPU
-    def animate(
-        self,
-        prompt_textbox,
+
+basedir = os.getcwd()
+savedir = os.path.join(basedir, "samples/Gradio", datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
+savedir_sample = os.path.join(savedir, "sample")
+os.makedirs(savedir, exist_ok=True)
+
+# config models
+pipeline = ConditionalAnimationPipeline.from_pretrained("TIGER-Lab/ConsistI2V", torch_dtype=torch.float16,)
+pipeline.to("cuda")
+# pipeline.to("cuda")
+
+def update_textbox_and_save_image(input_image, height_slider, width_slider, center_crop):
+    pil_image = Image.fromarray(input_image.astype(np.uint8)).convert("RGB")
+    img_path = os.path.join(savedir, "input_image.png")
+    pil_image.save(img_path)
+
+    original_width, original_height = pil_image.size
+    if center_crop:
+        crop_aspect_ratio = width_slider / height_slider
+        aspect_ratio = original_width / original_height
+        if aspect_ratio > crop_aspect_ratio:
+            new_width = int(crop_aspect_ratio * original_height)
+            left = (original_width - new_width) / 2
+            top = 0
+            right = left + new_width
+            bottom = original_height
+            pil_image = pil_image.crop((left, top, right, bottom))
+        elif aspect_ratio < crop_aspect_ratio:
+            new_height = int(original_width / crop_aspect_ratio)
+            top = (original_height - new_height) / 2
+            left = 0
+            right = original_width
+            bottom = top + new_height
+            pil_image = pil_image.crop((left, top, right, bottom))
+
+    pil_image = pil_image.resize((width_slider, height_slider))
+    return gr.Textbox(value=img_path), gr.Image(value=np.array(pil_image))
+
+
+def animate(
+    prompt_textbox,
+    negative_prompt_textbox,
+    input_image_path,
+    sampler_dropdown,
+    sample_step_slider,
+    width_slider,
+    height_slider,
+    txt_cfg_scale_slider,
+    img_cfg_scale_slider,
+    center_crop,
+    frame_stride,
+    use_frameinit,
+    frame_init_noise_level,
+    seed_textbox
+):
+    if pipeline is None:
+        raise gr.Error(f"Please select a pretrained pipeline path.")
+    if input_image_path == "":
+        raise gr.Error(f"Please upload an input image.")
+    if (not center_crop) and (width_slider % 8 != 0 or height_slider % 8 != 0):
+        raise gr.Error(f"`height` and `width` have to be divisible by 8 but are {height_slider} and {width_slider}.")
+    if center_crop and (width_slider % 8 != 0 or height_slider % 8 != 0):
+        raise gr.Error(f"`height` and `width` (after cropping) have to be divisible by 8 but are {height_slider} and {width_slider}.")
+
+    if is_xformers_available() and int(torch.__version__.split(".")[0]) < 2: pipeline.unet.enable_xformers_memory_efficient_attention()
+
+    if seed_textbox != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
+    else: torch.seed()
+    seed = torch.initial_seed()
+
+    if input_image_path.startswith("http://") or input_image_path.startswith("https://"):
+        first_frame = Image.open(requests.get(input_image_path, stream=True).raw).convert('RGB')
+    else:
+        first_frame = Image.open(input_image_path).convert('RGB')
+
+    original_width, original_height = first_frame.size
+
+    if not center_crop:
+        img_transform = T.Compose([
+            T.ToTensor(),
+            T.Resize((height_slider, width_slider), antialias=None),
+            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+        ])
+    else:
+        aspect_ratio = original_width / original_height
+        crop_aspect_ratio = width_slider / height_slider
+        if aspect_ratio > crop_aspect_ratio:
+            center_crop_width = int(crop_aspect_ratio * original_height)
+            center_crop_height = original_height
+        elif aspect_ratio < crop_aspect_ratio:
+            center_crop_width = original_width
+            center_crop_height = int(original_width / crop_aspect_ratio)
+        else:
+            center_crop_width = original_width
+            center_crop_height = original_height
+        img_transform = T.Compose([
+            T.ToTensor(),
+            T.CenterCrop((center_crop_height, center_crop_width)),
+            T.Resize((height_slider, width_slider), antialias=None),
+            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+        ])
+
+    first_frame = img_transform(first_frame).unsqueeze(0)
+
+    if use_frameinit:
+        pipeline.init_filter(
+            width = width_slider,
+            height = height_slider,
+            video_length = 16,
+            filter_params = OmegaConf.create({'method': 'gaussian', 'd_s': 0.25, 'd_t': 0.25,})
+        )
+
+    sample = run_pipeline(
+        pipeline,
+        prompt_textbox,
         negative_prompt_textbox,
-        input_image_path,
-        sampler_dropdown,
-        sample_step_slider,
-        width_slider,
-        height_slider,
+        first_frame,
+        sample_step_slider,
+        width_slider,
+        height_slider,
         txt_cfg_scale_slider,
         img_cfg_scale_slider,
-        center_crop,
         frame_stride,
         use_frameinit,
         frame_init_noise_level,
-        seed_textbox
-    ):
-        if self.pipeline is None:
-            raise gr.Error(f"Please select a pretrained pipeline path.")
-        if input_image_path == "":
-            raise gr.Error(f"Please upload an input image.")
-        if (not center_crop) and (width_slider % 8 != 0 or height_slider % 8 != 0):
-            raise gr.Error(f"`height` and `width` have to be divisible by 8 but are {height_slider} and {width_slider}.")
-        if center_crop and (width_slider % 8 != 0 or height_slider % 8 != 0):
-            raise gr.Error(f"`height` and `width` (after cropping) have to be divisible by 8 but are {height_slider} and {width_slider}.")
-
-        if is_xformers_available() and int(torch.__version__.split(".")[0]) < 2: self.pipeline.unet.enable_xformers_memory_efficient_attention()
-
-        if seed_textbox != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
-        else: torch.seed()
-        seed = torch.initial_seed()
-
-        if input_image_path.startswith("http://") or input_image_path.startswith("https://"):
-            first_frame = Image.open(requests.get(input_image_path, stream=True).raw).convert('RGB')
-        else:
-            first_frame = Image.open(input_image_path).convert('RGB')
-
-        original_width, original_height = first_frame.size
-
-        if not center_crop:
-            img_transform = T.Compose([
-                T.ToTensor(),
-                T.Resize((height_slider, width_slider), antialias=None),
-                T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
-            ])
-        else:
-            aspect_ratio = original_width / original_height
-            crop_aspect_ratio = width_slider / height_slider
-            if aspect_ratio > crop_aspect_ratio:
-                center_crop_width = int(crop_aspect_ratio * original_height)
-                center_crop_height = original_height
-            elif aspect_ratio < crop_aspect_ratio:
-                center_crop_width = original_width
-                center_crop_height = int(original_width / crop_aspect_ratio)
-            else:
-                center_crop_width = original_width
-                center_crop_height = original_height
-            img_transform = T.Compose([
-                T.ToTensor(),
-                T.CenterCrop((center_crop_height, center_crop_width)),
-                T.Resize((height_slider, width_slider), antialias=None),
-                T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
-            ])
-
-        first_frame = img_transform(first_frame).unsqueeze(0)
-        first_frame = first_frame.to("cuda")
-        print("first_frame", first_frame.device)
-
-        if use_frameinit:
-            self.pipeline.init_filter(
-                width = width_slider,
-                height = height_slider,
-                video_length = 16,
-                filter_params = OmegaConf.create({'method': 'gaussian', 'd_s': 0.25, 'd_t': 0.25,})
-            )
+    )
 
-        sample = self.pipeline(
-            prompt_textbox,
-            negative_prompt = negative_prompt_textbox,
-            first_frames = first_frame,
-            num_inference_steps = sample_step_slider,
-            guidance_scale_txt = txt_cfg_scale_slider,
-            guidance_scale_img = img_cfg_scale_slider,
-            width = width_slider,
-            height = height_slider,
-            video_length = 16,
-            noise_sampling_method = "pyoco_mixed",
-            noise_alpha = 1.0,
-            frame_stride = frame_stride,
-            use_frameinit = use_frameinit,
-            frameinit_noise_level = frame_init_noise_level,
-            camera_motion = None,
-        ).videos
-
-        global sample_idx
-        sample_idx += 1
-        save_sample_path = os.path.join(self.savedir_sample, f"{sample_idx}.mp4")
-        save_videos_grid(sample, save_sample_path, format="mp4")
-
-        sample_config = {
-            "prompt": prompt_textbox,
-            "n_prompt": negative_prompt_textbox,
-            "first_frame_path": input_image_path,
-            "sampler": sampler_dropdown,
-            "num_inference_steps": sample_step_slider,
-            "guidance_scale_text": txt_cfg_scale_slider,
-            "guidance_scale_image": img_cfg_scale_slider,
-            "width": width_slider,
-            "height": height_slider,
-            "video_length": 8,
-            "seed": seed
-        }
-        json_str = json.dumps(sample_config, indent=4)
-        with open(os.path.join(self.savedir, "logs.json"), "a") as f:
-            f.write(json_str)
-            f.write("\n\n")
-
-        return gr.Video(value=save_sample_path)
+    global sample_idx
+    sample_idx += 1
+    save_sample_path = os.path.join(savedir_sample, f"{sample_idx}.mp4")
+    save_videos_grid(sample, save_sample_path, format="mp4")
+
+    sample_config = {
+        "prompt": prompt_textbox,
+        "n_prompt": negative_prompt_textbox,
+        "first_frame_path": input_image_path,
+        "sampler": sampler_dropdown,
+        "num_inference_steps": sample_step_slider,
+        "guidance_scale_text": txt_cfg_scale_slider,
+        "guidance_scale_image": img_cfg_scale_slider,
+        "width": width_slider,
+        "height": height_slider,
+        "video_length": 8,
+        "seed": seed
+    }
+    json_str = json.dumps(sample_config, indent=4)
+    with open(os.path.join(savedir, "logs.json"), "a") as f:
+        f.write(json_str)
+        f.write("\n\n")
+
+    return gr.Video(value=save_sample_path)
+
 
-controller = AnimateController()
+@spaces.GPU
+def run_pipeline(
+    pipeline,
+    prompt_textbox,
+    negative_prompt_textbox,
+    first_frame,
+    sample_step_slider,
+    width_slider,
+    height_slider,
+    txt_cfg_scale_slider,
+    img_cfg_scale_slider,
+    frame_stride,
+    use_frameinit,
+    frame_init_noise_level,
+
+):
+    first_frame = first_frame.to("cuda")
+    sample = pipeline(
+        prompt_textbox,
+        negative_prompt = negative_prompt_textbox,
+        first_frames = first_frame,
+        num_inference_steps = sample_step_slider,
+        guidance_scale_txt = txt_cfg_scale_slider,
+        guidance_scale_img = img_cfg_scale_slider,
+        width = width_slider,
+        height = height_slider,
+        video_length = 16,
+        noise_sampling_method = "pyoco_mixed",
+        noise_alpha = 1.0,
+        frame_stride = frame_stride,
+        use_frameinit = use_frameinit,
+        frameinit_noise_level = frame_init_noise_level,
+        camera_motion = None,
+    ).videos
+    return sample
 
 
 def ui():
@@ -257,7 +280,7 @@ def ui():
 
         with gr.Row():
            input_image = gr.Image(label="Input Image", interactive=True)
-            input_image.upload(fn=controller.update_textbox_and_save_image, inputs=[input_image, height_slider, width_slider, center_crop], outputs=[input_image_path, input_image])
+            input_image.upload(fn=update_textbox_and_save_image, inputs=[input_image, height_slider, width_slider, center_crop], outputs=[input_image_path, input_image])
            result_video = gr.Video(label="Generated Animation", interactive=False, autoplay=True)
 
         def update_and_resize_image(input_image_path, height_slider, width_slider, center_crop):
@@ -265,7 +288,6 @@ def ui():
                pil_image = Image.open(requests.get(input_image_path, stream=True).raw).convert('RGB')
            else:
                pil_image = Image.open(input_image_path).convert('RGB')
-            controller.image_resolution = pil_image.size
            original_width, original_height = pil_image.size
 
            if center_crop:
@@ -293,7 +315,7 @@ def ui():
        input_image_path.submit(fn=update_and_resize_image, inputs=[input_image_path, height_slider, width_slider, center_crop], outputs=[input_image])
 
        generate_button.click(
-            fn=controller.animate,
+            fn=animate,
            inputs=[
                prompt_textbox,
                negative_prompt_textbox,
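For readers unfamiliar with the decorator being moved here: @spaces.GPU is the Hugging Face ZeroGPU hook, and this commit narrows its scope from the whole animate() handler to just run_pipeline(). A minimal sketch of the pattern, with an illustrative stand-in model and interface that are not part of this repo:

import gradio as gr
import spaces  # Hugging Face ZeroGPU helper package
import torch
from diffusers import DiffusionPipeline

# Startup runs without a GPU: load weights once, in half precision.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
)
# Under ZeroGPU the .to("cuda") move is deferred and replayed once a GPU is granted.
pipe.to("cuda")

@spaces.GPU  # a GPU is attached only while this function executes
def generate(prompt):
    return pipe(prompt, num_inference_steps=25).images[0]

demo = gr.Interface(fn=generate, inputs="text", outputs="image")
demo.launch()

Keeping preprocessing outside the decorated function, as this commit does, means image cropping and tensor setup do not count against the per-call GPU allocation; only the tensors moved to "cuda" inside run_pipeline() touch the device.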