ruizhaocv committed on
Commit
7d421db
·
1 Parent(s): 886cbb8

Upload 17 files

LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
MotionDirector_inference.py ADDED
@@ -0,0 +1,284 @@
1
+ import argparse
2
+ import os
3
+ import platform
4
+ import re
5
+ import warnings
6
+ from typing import Optional
7
+
8
+ import torch
9
+ from diffusers import DDIMScheduler, TextToVideoSDPipeline
10
+ from einops import rearrange
11
+ from torch import Tensor
12
+ from torch.nn.functional import interpolate
13
+ from tqdm import trange
14
+ import random
15
+
16
+ from MotionDirector_train import export_to_video, handle_memory_attention, load_primary_models, unet_and_text_g_c, freeze_models
17
+ from utils.lora_handler import LoraHandler
18
+ from utils.ddim_utils import ddim_inversion
19
+ import imageio
20
+
21
+
22
+ def initialize_pipeline(
23
+ model: str,
24
+ device: str = "cuda",
25
+ xformers: bool = False,
26
+ sdp: bool = False,
27
+ lora_path: str = "",
28
+ lora_rank: int = 64,
29
+ lora_scale: float = 1.0,
30
+ ):
31
+ with warnings.catch_warnings():
32
+ warnings.simplefilter("ignore")
33
+
34
+ scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(model)
35
+
36
+ # Freeze any necessary models
37
+ freeze_models([vae, text_encoder, unet])
38
+
39
+ # Enable xformers if available
40
+ handle_memory_attention(xformers, sdp, unet)
41
+
42
+ lora_manager_temporal = LoraHandler(
43
+ version="cloneofsimo",
44
+ use_unet_lora=True,
45
+ use_text_lora=False,
46
+ save_for_webui=False,
47
+ only_for_webui=False,
48
+ unet_replace_modules=["TransformerTemporalModel"],
49
+ text_encoder_replace_modules=None,
50
+ lora_bias=None
51
+ )
52
+
53
+ unet_lora_params, unet_negation = lora_manager_temporal.add_lora_to_model(
54
+ True, unet, lora_manager_temporal.unet_replace_modules, 0, lora_path, r=lora_rank, scale=lora_scale)
55
+
56
+ unet.eval()
57
+ text_encoder.eval()
58
+ unet_and_text_g_c(unet, text_encoder, False, False)
59
+
60
+ pipe = TextToVideoSDPipeline.from_pretrained(
61
+ pretrained_model_name_or_path=model,
62
+ scheduler=scheduler,
63
+ tokenizer=tokenizer,
64
+ text_encoder=text_encoder.to(device=device, dtype=torch.half),
65
+ vae=vae.to(device=device, dtype=torch.half),
66
+ unet=unet.to(device=device, dtype=torch.half),
67
+ )
68
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
69
+
70
+ return pipe
71
+
72
+
73
+ def inverse_video(pipe, latents, num_steps):
74
+ ddim_inv_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
75
+ ddim_inv_scheduler.set_timesteps(num_steps)
76
+
77
+ ddim_inv_latent = ddim_inversion(
78
+ pipe, ddim_inv_scheduler, video_latent=latents.to(pipe.device),
79
+ num_inv_steps=num_steps, prompt="")[-1]
80
+ return ddim_inv_latent
81
+
82
+
83
+ def prepare_input_latents(
84
+ pipe: TextToVideoSDPipeline,
85
+ batch_size: int,
86
+ num_frames: int,
87
+ height: int,
88
+ width: int,
89
+ latents_path:str,
90
+ noise_prior: float
91
+ ):
92
+ # initialize with random gaussian noise
93
+ scale = pipe.vae_scale_factor
94
+ shape = (batch_size, pipe.unet.config.in_channels, num_frames, height // scale, width // scale)
95
+ if noise_prior > 0.:
96
+ cached_latents = torch.load(latents_path)
97
+ if 'inversion_noise' not in cached_latents:
98
+ latents = inverse_video(pipe, cached_latents['latents'].unsqueeze(0), 50).squeeze(0)
99
+ else:
100
+ latents = cached_latents['inversion_noise'].unsqueeze(0)
101
+ if latents.shape[0] != batch_size:
102
+ latents = latents.repeat(batch_size, 1, 1, 1, 1)
103
+ if latents.shape != shape:
104
+ latents = interpolate(rearrange(latents, "b c f h w -> (b f) c h w", b=batch_size), (height // scale, width // scale), mode='bilinear')
105
+ latents = rearrange(latents, "(b f) c h w -> b c f h w", b=batch_size)
106
+ noise = torch.randn_like(latents, dtype=torch.half)
107
+ latents = (noise_prior) ** 0.5 * latents + (1 - noise_prior) ** 0.5 * noise
108
+ else:
109
+ latents = torch.randn(shape, dtype=torch.half)
110
+
111
+ return latents
112
+
113
+
114
+ def encode(pipe: TextToVideoSDPipeline, pixels: Tensor, batch_size: int = 8):
115
+ nf = pixels.shape[2]
116
+ pixels = rearrange(pixels, "b c f h w -> (b f) c h w")
117
+
118
+ latents = []
119
+ for idx in trange(
120
+ 0, pixels.shape[0], batch_size, desc="Encoding to latents...", unit_scale=batch_size, unit="frame"
121
+ ):
122
+ pixels_batch = pixels[idx : idx + batch_size].to(pipe.device, dtype=torch.half)
123
+ latents_batch = pipe.vae.encode(pixels_batch).latent_dist.sample()
124
+ latents_batch = latents_batch.mul(pipe.vae.config.scaling_factor).cpu()
125
+ latents.append(latents_batch)
126
+ latents = torch.cat(latents)
127
+
128
+ latents = rearrange(latents, "(b f) c h w -> b c f h w", f=nf)
129
+
130
+ return latents
131
+
132
+
133
+ @torch.inference_mode()
134
+ def inference(
135
+ model: str,
136
+ prompt: str,
137
+ negative_prompt: Optional[str] = None,
138
+ width: int = 256,
139
+ height: int = 256,
140
+ num_frames: int = 24,
141
+ num_steps: int = 50,
142
+ guidance_scale: float = 15,
143
+ device: str = "cuda",
144
+ xformers: bool = False,
145
+ sdp: bool = False,
146
+ lora_path: str = "",
147
+ lora_rank: int = 64,
148
+ lora_scale: float = 1.0,
149
+ seed: Optional[int] = None,
150
+ latents_path: str="",
151
+ noise_prior: float = 0.,
152
+ repeat_num: int = 1,
153
+ ):
154
+ if seed is not None:
155
+ random_seed = seed
156
+ torch.manual_seed(seed)
157
+
158
+ with torch.autocast(device, dtype=torch.half):
159
+ # prepare models
160
+ pipe = initialize_pipeline(model, device, xformers, sdp, lora_path, lora_rank, lora_scale)
161
+
162
+ for i in range(repeat_num):
163
+ if seed is None:
164
+ random_seed = random.randint(100, 10000000)
165
+ torch.manual_seed(random_seed)
166
+
167
+ # prepare input latents
168
+ init_latents = prepare_input_latents(
169
+ pipe=pipe,
170
+ batch_size=len(prompt),
171
+ num_frames=num_frames,
172
+ height=height,
173
+ width=width,
174
+ latents_path=latents_path,
175
+ noise_prior=noise_prior
176
+ )
177
+
178
+ with torch.no_grad():
179
+ video_frames = pipe(
180
+ prompt=prompt,
181
+ negative_prompt=negative_prompt,
182
+ width=width,
183
+ height=height,
184
+ num_frames=num_frames,
185
+ num_inference_steps=num_steps,
186
+ guidance_scale=guidance_scale,
187
+ latents=init_latents
188
+ ).frames
189
+
190
+ # =========================================
191
+ # ========= write outputs to file =========
192
+ # =========================================
193
+ os.makedirs(args.output_dir, exist_ok=True)
194
+
195
+ # save to mp4
196
+ export_to_video(video_frames, f"{out_name}_{random_seed}.mp4", args.fps)
197
+
198
+ # save to gif
199
+ file_name = f"{out_name}_{random_seed}.gif"
200
+ imageio.mimsave(file_name, video_frames, 'GIF', duration=1000 * 1 / args.fps, loop=0)
201
+
202
+ return video_frames
203
+
204
+
205
+ if __name__ == "__main__":
206
+ import decord
207
+
208
+ decord.bridge.set_bridge("torch")
209
+
210
+ # fmt: off
211
+ parser = argparse.ArgumentParser()
212
+ parser.add_argument("-m", "--model", type=str, required=True,
213
+ help="HuggingFace repository or path to model checkpoint directory")
214
+ parser.add_argument("-p", "--prompt", type=str, required=True, help="Text prompt to condition on")
215
+ parser.add_argument("-n", "--negative-prompt", type=str, default=None, help="Text prompt to condition against")
216
+ parser.add_argument("-o", "--output_dir", type=str, default="./outputs/inference", help="Directory to save output video to")
217
+ parser.add_argument("-B", "--batch-size", type=int, default=1, help="Batch size for inference")
218
+ parser.add_argument("-W", "--width", type=int, default=384, help="Width of output video")
219
+ parser.add_argument("-H", "--height", type=int, default=384, help="Height of output video")
220
+ parser.add_argument("-T", "--num-frames", type=int, default=16, help="Total number of frames to generate")
221
+ parser.add_argument("-s", "--num-steps", type=int, default=30, help="Number of denoising steps to run.")
222
+ parser.add_argument("-g", "--guidance-scale", type=float, default=12, help="Scale for guidance loss (higher values = more guidance, but possibly more artifacts).")
223
+ parser.add_argument("-f", "--fps", type=int, default=8, help="FPS of output video")
224
+ parser.add_argument("-d", "--device", type=str, default="cuda", help="Device to run inference on (defaults to cuda).")
225
+ parser.add_argument("-x", "--xformers", action="store_true", help="Use XFormers attention, a memory-efficient attention implementation (requires `pip install xformers`).")
226
+ parser.add_argument("-S", "--sdp", action="store_true", help="Use SDP attention, PyTorch's built-in memory-efficient attention implementation.")
227
+ parser.add_argument("-cf", "--checkpoint_folder", type=str, default=None, help="Path to the checkpoint folder, which must contain checkpoint-INDEX/temporal/lora and cached_latents.")
228
+ parser.add_argument("-lr", "--lora_rank", type=int, default=32, help="Size of the LoRA checkpoint's projection matrix (defaults to 32).")
229
+ parser.add_argument("-ls", "--lora_scale", type=float, default=1.0, help="Scale of LoRAs.")
230
+ parser.add_argument("-r", "--seed", type=int, default=None, help="Random seed to make generations reproducible.")
231
+ parser.add_argument("-np", "--noise_prior", type=float, default=0., help="Scale of the influence of inversion noise.")
232
+ parser.add_argument("-ci", "--checkpoint_index", type=int, required=True,
233
+ help="The index of checkpoint, such as 300.")
234
+ parser.add_argument("-rn", "--repeat_num", type=int, default=1,
235
+ help="How many results to generate with the same prompt.")
236
+
237
+ args = parser.parse_args()
238
+ # fmt: on
239
+
240
+ # =========================================
241
+ # ====== validate and prepare inputs ======
242
+ # =========================================
243
+
244
+ out_name = f"{args.output_dir}/"
245
+ prompt = re.sub(r'[<>:"/\\|?*\x00-\x1F]', "_", args.prompt) if platform.system() == "Windows" else args.prompt
246
+ out_name += f"{prompt}".replace(' ','_').replace(',', '').replace('.', '')
247
+
248
+ args.prompt = [prompt] * args.batch_size
249
+ if args.negative_prompt is not None:
250
+ args.negative_prompt = [args.negative_prompt] * args.batch_size
251
+
252
+ # =========================================
253
+ # ============= sample videos =============
254
+ # =========================================
255
+ if args.checkpoint_index is not None:
256
+ lora_path = f"{args.checkpoint_folder}/checkpoint-{args.checkpoint_index}/temporal/lora"
257
+ else:
258
+ lora_path = f"{args.checkpoint_folder}/checkpoint-default/temporal/lora"
259
+ latents_folder = f"{args.checkpoint_folder}/cached_latents"
260
+ latents_path = f"{latents_folder}/{random.choice(os.listdir(latents_folder))}"
261
+ assert os.path.exists(lora_path)
262
+ video_frames = inference(
263
+ model=args.model,
264
+ prompt=args.prompt,
265
+ negative_prompt=args.negative_prompt,
266
+ width=args.width,
267
+ height=args.height,
268
+ num_frames=args.num_frames,
269
+ num_steps=args.num_steps,
270
+ guidance_scale=args.guidance_scale,
271
+ device=args.device,
272
+ xformers=args.xformers,
273
+ sdp=args.sdp,
274
+ lora_path=lora_path,
275
+ lora_rank=args.lora_rank,
276
+ lora_scale = args.lora_scale,
277
+ seed=args.seed,
278
+ latents_path=latents_path,
279
+ noise_prior=args.noise_prior,
280
+ repeat_num=args.repeat_num
281
+ )
282
+
283
+
284
+
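Note on `prepare_input_latents` above: when `noise_prior > 0`, the script blends the inversion latents with fresh Gaussian noise using weights `noise_prior ** 0.5` and `(1 - noise_prior) ** 0.5`. The squared weights sum to 1, so two independent unit-variance inputs give a unit-variance result. The following is a minimal, torch-free sketch of that blend, not the script's own code. `blend_with_noise_prior` and `variance` are hypothetical helper names, and the range check is an added assumption that the script itself does not enforce.

```python
import math
import random


def blend_with_noise_prior(inverted, noise, noise_prior):
    """Return sqrt(p) * inverted + sqrt(1 - p) * noise, element by element."""
    # Assumption: restrict p to [0, 1]; the original script has no such check.
    if not 0.0 <= noise_prior <= 1.0:
        raise ValueError("noise_prior must lie in [0, 1]")
    w_inverted = math.sqrt(noise_prior)
    w_noise = math.sqrt(1.0 - noise_prior)
    return [w_inverted * a + w_noise * b for a, b in zip(inverted, noise)]


def variance(values):
    """Population variance of a list of floats."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)


# Stand-ins for the inversion latents and the fresh noise (plain Python floats).
rng = random.Random(0)
inverted = [rng.gauss(0.0, 1.0) for _ in range(20000)]
noise = [rng.gauss(0.0, 1.0) for _ in range(20000)]

mixed = blend_with_noise_prior(inverted, noise, 0.5)
print("variance of blend:", round(variance(mixed), 2))
```

With `noise_prior = 0` the blend reduces to pure noise and with `noise_prior = 1` to the inversion latents. Intermediate values trade how closely a sample follows the cached video's inverted noise against how much fresh randomness it receives.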
MotionDirector_inference_batch.py ADDED
@@ -0,0 +1,290 @@
1
+ import argparse
2
+ import os
3
+ import platform
4
+ import re
5
+ import warnings
6
+ from typing import Optional
7
+
8
+ import torch
9
+ from diffusers import DDIMScheduler, TextToVideoSDPipeline
10
+ from einops import rearrange
11
+ from torch import Tensor
12
+ from torch.nn.functional import interpolate
13
+ from tqdm import trange
14
+ import random
15
+
16
+ from MotionDirector_train import export_to_video, handle_memory_attention, load_primary_models, unet_and_text_g_c, freeze_models
17
+ from utils.lora_handler import LoraHandler
18
+ from utils.ddim_utils import ddim_inversion
19
+ import imageio
20
+
21
+
22
+ def initialize_pipeline(
23
+ model: str,
24
+ device: str = "cuda",
25
+ xformers: bool = False,
26
+ sdp: bool = False,
27
+ lora_path: str = "",
28
+ lora_rank: int = 64,
29
+ lora_scale: float = 1.0,
30
+ ):
31
+ with warnings.catch_warnings():
32
+ warnings.simplefilter("ignore")
33
+
34
+ scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(model)
35
+
36
+ # Freeze any necessary models
37
+ freeze_models([vae, text_encoder, unet])
38
+
39
+ # Enable xformers if available
40
+ handle_memory_attention(xformers, sdp, unet)
41
+
42
+ lora_manager_temporal = LoraHandler(
43
+ version="cloneofsimo",
44
+ use_unet_lora=True,
45
+ use_text_lora=False,
46
+ save_for_webui=False,
47
+ only_for_webui=False,
48
+ unet_replace_modules=["TransformerTemporalModel"],
49
+ text_encoder_replace_modules=None,
50
+ lora_bias=None
51
+ )
52
+
53
+ unet_lora_params, unet_negation = lora_manager_temporal.add_lora_to_model(
54
+ True, unet, lora_manager_temporal.unet_replace_modules, 0, lora_path, r=lora_rank, scale=lora_scale)
55
+
56
+ unet.eval()
57
+ text_encoder.eval()
58
+ unet_and_text_g_c(unet, text_encoder, False, False)
59
+
60
+ pipe = TextToVideoSDPipeline.from_pretrained(
61
+ pretrained_model_name_or_path=model,
62
+ scheduler=scheduler,
63
+ tokenizer=tokenizer,
64
+ text_encoder=text_encoder.to(device=device, dtype=torch.half),
65
+ vae=vae.to(device=device, dtype=torch.half),
66
+ unet=unet.to(device=device, dtype=torch.half),
67
+ )
68
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
69
+
70
+ return pipe
71
+
72
+
73
+ def inverse_video(pipe, latents, num_steps):
74
+ ddim_inv_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
75
+ ddim_inv_scheduler.set_timesteps(num_steps)
76
+
77
+ ddim_inv_latent = ddim_inversion(
78
+ pipe, ddim_inv_scheduler, video_latent=latents.to(pipe.device),
79
+ num_inv_steps=num_steps, prompt="")[-1]
80
+ return ddim_inv_latent
81
+
82
+
83
+ def prepare_input_latents(
84
+ pipe: TextToVideoSDPipeline,
85
+ batch_size: int,
86
+ num_frames: int,
87
+ height: int,
88
+ width: int,
89
+ latents_path:str,
90
+ noise_prior: float
91
+ ):
92
+ # initialize with random gaussian noise
93
+ scale = pipe.vae_scale_factor
94
+ shape = (batch_size, pipe.unet.config.in_channels, num_frames, height // scale, width // scale)
95
+ if noise_prior > 0.:
96
+ cached_latents = torch.load(latents_path)
97
+ if 'inversion_noise' not in cached_latents:
98
+ latents = inverse_video(pipe, cached_latents['latents'].unsqueeze(0), 50).squeeze(0)
99
+ else:
100
+ latents = cached_latents['inversion_noise'].unsqueeze(0)
101
+ if latents.shape[0] != batch_size:
102
+ latents = latents.repeat(batch_size, 1, 1, 1, 1)
103
+ if latents.shape != shape:
104
+ latents = interpolate(rearrange(latents, "b c f h w -> (b f) c h w", b=batch_size), (height // scale, width // scale), mode='bilinear')
105
+ latents = rearrange(latents, "(b f) c h w -> b c f h w", b=batch_size)
106
+ noise = torch.randn_like(latents, dtype=torch.half)
107
+ latents = (noise_prior) ** 0.5 * latents + (1 - noise_prior) ** 0.5 * noise
108
+ else:
109
+ latents = torch.randn(shape, dtype=torch.half)
110
+
111
+ return latents
112
+
113
+
114
+ def encode(pipe: TextToVideoSDPipeline, pixels: Tensor, batch_size: int = 8):
115
+ nf = pixels.shape[2]
116
+ pixels = rearrange(pixels, "b c f h w -> (b f) c h w")
117
+
118
+ latents = []
119
+ for idx in trange(
120
+ 0, pixels.shape[0], batch_size, desc="Encoding to latents...", unit_scale=batch_size, unit="frame"
121
+ ):
122
+ pixels_batch = pixels[idx : idx + batch_size].to(pipe.device, dtype=torch.half)
123
+ latents_batch = pipe.vae.encode(pixels_batch).latent_dist.sample()
124
+ latents_batch = latents_batch.mul(pipe.vae.config.scaling_factor).cpu()
125
+ latents.append(latents_batch)
126
+ latents = torch.cat(latents)
127
+
128
+ latents = rearrange(latents, "(b f) c h w -> b c f h w", f=nf)
129
+
130
+ return latents
131
+
132
+
133
+
134
+
135
+ @torch.inference_mode()
136
+ def inference(
137
+ model: str,
138
+ prompt: str,
139
+ negative_prompt: Optional[str] = None,
140
+ width: int = 256,
141
+ height: int = 256,
142
+ num_frames: int = 24,
143
+ num_steps: int = 50,
144
+ guidance_scale: float = 15,
145
+ device: str = "cuda",
146
+ xformers: bool = False,
147
+ sdp: bool = False,
148
+ lora_path: str = "",
149
+ lora_rank: int = 64,
150
+ lora_scale: float = 1.0,
151
+ seed: Optional[int] = None,
152
+ latents_path: str="",
153
+ noise_prior: float = 0.,
154
+ repeat_num: int = 1,
155
+ ):
156
+
157
+ with torch.autocast(device, dtype=torch.half):
158
+ # prepare models
159
+ pipe = initialize_pipeline(model, device, xformers, sdp, lora_path, lora_rank, lora_scale)
160
+
161
+ for i in range(repeat_num):
162
+ if seed is not None:
163
+ random_seed = seed
164
+ torch.manual_seed(seed)
165
+ else:
166
+ random_seed = random.randint(100, 10000000)
167
+ torch.manual_seed(random_seed)
168
+
169
+ # prepare input latents
170
+ init_latents = prepare_input_latents(
171
+ pipe=pipe,
172
+ batch_size=len(prompt),
173
+ num_frames=num_frames,
174
+ height=height,
175
+ width=width,
176
+ latents_path=latents_path,
177
+ noise_prior=noise_prior
178
+ )
179
+
180
+ video_frames = pipe(
181
+ prompt=prompt,
182
+ negative_prompt=negative_prompt,
183
+ width=width,
184
+ height=height,
185
+ num_frames=num_frames,
186
+ num_inference_steps=num_steps,
187
+ guidance_scale=guidance_scale,
188
+ latents=init_latents
189
+ ).frames
190
+ # =========================================
191
+ # ========= write outputs to file =========
192
+ # =========================================
193
+ os.makedirs(args.output_dir, exist_ok=True)
194
+
195
+ # save to mp4
196
+ export_to_video(video_frames, f"{out_name}_{random_seed}.mp4", args.fps)
197
+
198
+ # save to gif
199
+ file_name = f"{out_name}_{random_seed}.gif"
200
+ imageio.mimsave(file_name, video_frames, 'GIF', duration=1000 * 1 / args.fps, loop=0)
201
+
202
+ return video_frames
203
+
204
+
205
+ if __name__ == "__main__":
206
+ import decord
207
+
208
+ decord.bridge.set_bridge("torch")
209
+
210
+ # fmt: off
211
+ parser = argparse.ArgumentParser()
212
+ parser.add_argument("-m", "--model", type=str, default='/Users/rui/data/models/zeroscope_v2_576w/',
213
+ help="HuggingFace repository or path to model checkpoint directory")
214
+ parser.add_argument("-p", "--prompt", type=str, default=None, help="Text prompt to condition on")
215
+ parser.add_argument("-n", "--negative-prompt", type=str, default=None, help="Text prompt to condition against")
216
+ parser.add_argument("-o", "--output_dir", type=str, default="./outputs/inference", help="Directory to save output video to")
217
+ parser.add_argument("-B", "--batch-size", type=int, default=1, help="Batch size for inference")
218
+ parser.add_argument("-W", "--width", type=int, default=384, help="Width of output video")
219
+ parser.add_argument("-H", "--height", type=int, default=384, help="Height of output video")
220
+ parser.add_argument("-T", "--num-frames", type=int, default=16, help="Total number of frames to generate")
221
+ parser.add_argument("-s", "--num-steps", type=int, default=30, help="Number of diffusion (denoising) steps to run.")
222
+ parser.add_argument("-g", "--guidance-scale", type=float, default=12, help="Scale for guidance loss (higher values = more guidance, but possibly more artifacts).")
223
+ parser.add_argument("-f", "--fps", type=int, default=8, help="FPS of output video")
224
+ parser.add_argument("-d", "--device", type=str, default="cuda", help="Device to run inference on (defaults to cuda).")
225
+ parser.add_argument("-x", "--xformers", action="store_true", help="Use XFormers attention, a memory-efficient attention implementation (requires `pip install xformers`).")
226
+ parser.add_argument("-S", "--sdp", action="store_true", help="Use SDP attention, PyTorch's built-in memory-efficient attention implementation.")
227
+ parser.add_argument("-cf", "--checkpoint_folder", type=str, default=None, help="Path to the training output folder that holds the LoRA checkpoints (checkpoint-N/temporal/lora) and the cached_latents directory.")
228
+ parser.add_argument("-lr", "--lora_rank", type=int, default=32, help="Size of the LoRA checkpoint's projection matrix (defaults to 32).")
229
+ parser.add_argument("-ls", "--lora_scale", type=float, default=1.0, help="Scale of LoRAs.")
230
+ parser.add_argument("-r", "--seed", type=int, default=None, help="Random seed to make generations reproducible.")
231
+ parser.add_argument("-np", "--noise_prior", type=float, default=0., help="Strength of the noise prior built from the cached latents (0 disables it).")
+ parser.add_argument("-ci", "--checkpoint_index", type=int, default=None,
+ help="Training step of the checkpoint to load, i.e. the N in checkpoint-N.")
+ parser.add_argument("-rn", "--repeat_num", type=int, default=None,
+ help="Number of times to repeat sampling.")
236
+
237
+ args = parser.parse_args()
238
+ # fmt: on
239
+
240
+ # =========================================
241
+ # ====== validate and prepare inputs ======
242
+ # =========================================
243
+
244
+ # args.prompt = ["A firefighter standing in front of a burning forest captured with a dolly zoom.",
245
+ # "A spaceman standing on the moon with earth behind him captured with a dolly zoom."]
246
+ args.prompt = "A person is riding a bicycle past the Eiffel Tower."
247
+ args.checkpoint_folder = './outputs/train/train_2023-12-02T11-45-22/'
248
+ args.checkpoint_index = 500
249
+ args.noise_prior = 0.
250
+ args.repeat_num = 10
251
+
252
+ out_name = f"{args.output_dir}/"
253
+ prompt = re.sub(r'[<>:"/\\|?*\x00-\x1F]', "_", args.prompt) if platform.system() == "Windows" else args.prompt
254
+ out_name += f"{prompt}".replace(' ','_').replace(',', '').replace('.', '')
255
+
256
+ args.prompt = [prompt] * args.batch_size
257
+ if args.negative_prompt is not None:
258
+ args.negative_prompt = [args.negative_prompt] * args.batch_size
259
+
260
+ # =========================================
261
+ # ============= sample videos =============
262
+ # =========================================
263
+
264
+ lora_path = f"{args.checkpoint_folder}/checkpoint-{args.checkpoint_index}/temporal/lora"
265
+ latents_folder = f"{args.checkpoint_folder}/cached_latents"
266
+ latents_path = f"{latents_folder}/{random.choice(os.listdir(latents_folder))}"
267
+ # if args.seed is None:
268
+ # args.seed = random.randint(100, 10000000)
269
+ assert os.path.exists(lora_path)
270
+ video_frames = inference(
271
+ model=args.model,
272
+ prompt=args.prompt,
273
+ negative_prompt=args.negative_prompt,
274
+ width=args.width,
275
+ height=args.height,
276
+ num_frames=args.num_frames,
277
+ num_steps=args.num_steps,
278
+ guidance_scale=args.guidance_scale,
279
+ device=args.device,
280
+ xformers=args.xformers,
281
+ sdp=args.sdp,
282
+ lora_path=lora_path,
283
+ lora_rank=args.lora_rank,
284
+ lora_scale = args.lora_scale,
285
+ seed=args.seed,
286
+ latents_path=latents_path,
287
+ noise_prior=args.noise_prior,
288
+ repeat_num=args.repeat_num
289
+ )
290
+
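As an illustration of how the script above turns the prompt into an output file name, here is a minimal sketch. The helper name `build_out_name` and its `seed` and `windows` parameters are hypothetical; the script does this inline and decides on the Windows branch via `platform.system()`.

```python
import re


def build_out_name(output_dir, prompt, seed, windows=False):
    """Hypothetical helper mirroring the inline file-name logic of the script."""
    if windows:
        # Replace characters that are not allowed in Windows file names.
        prompt = re.sub(r'[<>:"/\\|?*\x00-\x1F]', "_", prompt)
    # Spaces become underscores; commas and periods are dropped.
    stem = prompt.replace(' ', '_').replace(',', '').replace('.', '')
    return f"{output_dir}/{stem}_{seed}.mp4"


name = build_out_name(
    "./outputs/inference",
    "A person is riding a bicycle past the Eiffel Tower.",
    seed=42,
)
print(name)
```

Only spaces, commas and periods are handled on non-Windows systems, so a prompt containing `/` would still produce a nested path there.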
MotionDirector_train.py ADDED
@@ -0,0 +1,1021 @@
1
+ import argparse
2
+ import datetime
3
+ import logging
4
+ import inspect
5
+ import math
6
+ import os
7
+ import random
8
+ import gc
9
+ import copy
10
+
11
+ from typing import Dict, Optional, Tuple
12
+ from omegaconf import OmegaConf
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import torch.utils.checkpoint
17
+ import diffusers
18
+ import transformers
19
+
20
+ from torchvision import transforms
21
+ from tqdm.auto import tqdm
22
+
23
+ from accelerate import Accelerator
24
+ from accelerate.logging import get_logger
25
+ from accelerate.utils import set_seed
26
+
27
+ from models.unet_3d_condition import UNet3DConditionModel
28
+ from diffusers.models import AutoencoderKL
29
+ from diffusers import DDIMScheduler, TextToVideoSDPipeline
30
+ from diffusers.optimization import get_scheduler
31
+ from diffusers.utils.import_utils import is_xformers_available
32
+ from diffusers.models.attention_processor import AttnProcessor2_0, Attention
33
+ from diffusers.models.attention import BasicTransformerBlock
34
+
35
+ from transformers import CLIPTextModel, CLIPTokenizer
36
+ from transformers.models.clip.modeling_clip import CLIPEncoder
37
+ from utils.dataset import VideoJsonDataset, SingleVideoDataset, \
38
+ ImageDataset, VideoFolderDataset, CachedDataset
39
+ from einops import rearrange, repeat
40
+ from utils.lora_handler import LoraHandler
41
+ from utils.lora import extract_lora_child_module
42
+ from utils.ddim_utils import ddim_inversion
43
+ import imageio
44
+ import numpy as np
45
+
46
+
47
+ already_printed_trainables = False
48
+
49
+ logger = get_logger(__name__, log_level="INFO")
50
+
51
+
52
+ def create_logging(logging, logger, accelerator):
53
+ logging.basicConfig(
54
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
55
+ datefmt="%m/%d/%Y %H:%M:%S",
56
+ level=logging.INFO,
57
+ )
58
+ logger.info(accelerator.state, main_process_only=False)
59
+
60
+
61
+ def accelerate_set_verbose(accelerator):
62
+ if accelerator.is_local_main_process:
63
+ transformers.utils.logging.set_verbosity_warning()
64
+ diffusers.utils.logging.set_verbosity_info()
65
+ else:
66
+ transformers.utils.logging.set_verbosity_error()
67
+ diffusers.utils.logging.set_verbosity_error()
68
+
69
+
70
+ def get_train_dataset(dataset_types, train_data, tokenizer):
71
+ train_datasets = []
72
+
73
+ # Loop through all available datasets, get the name, then add to list of data to process.
74
+ for DataSet in [VideoJsonDataset, SingleVideoDataset, ImageDataset, VideoFolderDataset]:
75
+ for dataset in dataset_types:
76
+ if dataset == DataSet.__getname__():
77
+ train_datasets.append(DataSet(**train_data, tokenizer=tokenizer))
78
+
79
+ if len(train_datasets) > 0:
80
+ return train_datasets
81
+ else:
82
+ raise ValueError("Dataset type not found: 'json', 'single_video', 'folder', 'image'")
83
+
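Because `get_train_dataset` iterates over `dataset_types`, the argument has to be a real sequence of names. The sketch below uses a hypothetical `select_types` helper (not part of the repository) to show what happens when a bare string in parentheses is passed instead of a one-element tuple.

```python
# Type names taken from the error message in get_train_dataset.
KNOWN_TYPES = ("json", "single_video", "folder", "image")


def select_types(dataset_types):
    """Hypothetical stand-in for the nested loop: keep the recognised names."""
    return [t for t in dataset_types if t in KNOWN_TYPES]


as_string = select_types(("json"))   # parentheses alone do not make a tuple
as_tuple = select_types(("json",))   # the trailing comma does

print(as_string)  # iterates over the characters 'j', 's', 'o', 'n'
print(as_tuple)
```

With the string form no dataset matches, which in the training script would end in the "Dataset type not found" error.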
84
+
85
+ def extend_datasets(datasets, dataset_items, extend=False):
86
+ biggest_data_len = max(x.__len__() for x in datasets)
87
+ extended = []
88
+ for dataset in datasets:
89
+ if dataset.__len__() == 0:
90
+ del dataset
91
+ continue
92
+ if dataset.__len__() < biggest_data_len:
93
+ for item in dataset_items:
94
+ if extend and item not in extended and hasattr(dataset, item):
95
+ print(f"Extending {item}")
96
+
97
+ value = getattr(dataset, item)
98
+ value *= biggest_data_len
99
+ value = value[:biggest_data_len]
100
+
101
+ setattr(dataset, item, value)
102
+
103
+ print(f"New {item} dataset length: {dataset.__len__()}")
104
+ extended.append(item)
105
+
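The core of `extend_datasets` is a repeat-then-truncate step on a list attribute. A minimal sketch of that step on a plain list follows; `extend_to_length` and the file names are hypothetical.

```python
def extend_to_length(items, target_len):
    """Repeat a non-empty list and cut it to exactly target_len entries."""
    repeated = items * target_len   # always at least target_len entries
    return repeated[:target_len]


frames = ["a.mp4", "b.mp4", "c.mp4"]
extended = extend_to_length(frames, 5)
print(extended)
```

The shorter dataset simply cycles through its items until it is as long as the largest one.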
106
+
107
+ def export_to_video(video_frames, output_video_path, fps):
108
+ video_writer = imageio.get_writer(output_video_path, fps=fps)
109
+ for img in video_frames:
110
+ video_writer.append_data(np.array(img))
111
+ video_writer.close()
112
+
113
+
114
+ def create_output_folders(output_dir, config):
115
+ now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
116
+ out_dir = os.path.join(output_dir, f"train_{now}")
117
+
118
+ os.makedirs(out_dir, exist_ok=True)
119
+ os.makedirs(f"{out_dir}/samples", exist_ok=True)
120
+ # OmegaConf.save(config, os.path.join(out_dir, 'config.yaml'))
121
+
122
+ return out_dir
123
+
124
+
125
+ def load_primary_models(pretrained_model_path):
126
+ noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
127
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
128
+ text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
129
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
130
+ unet = UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")
131
+
132
+ return noise_scheduler, tokenizer, text_encoder, vae, unet
133
+
134
+
135
+ def unet_and_text_g_c(unet, text_encoder, unet_enable, text_enable):
136
+ unet._set_gradient_checkpointing(value=unet_enable)
137
+ text_encoder._set_gradient_checkpointing(CLIPEncoder, value=text_enable)
138
+
139
+
140
+ def freeze_models(models_to_freeze):
141
+ for model in models_to_freeze:
142
+ if model is not None: model.requires_grad_(False)
143
+
144
+
145
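+ def is_attn(name):
+     # NOTE: `'attn1' or ...` short-circuits to the non-empty string 'attn1', so this
+     # returns a truthy value for every name. set_torch_2_attn() still works because it
+     # filters by module type (a ModuleList of BasicTransformerBlock) rather than by name.
+     return ('attn1' or 'attn2' == name.split('.')[-1])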
147
+
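The expression in `is_attn` is worth a second look, since `or` binds more loosely than `==`. The sketch below compares the function as written with a strict membership test; `is_attn_strict` is a hypothetical alternative, not code from the repository.

```python
def is_attn_as_written(name):
    # Parsed as: 'attn1' or ('attn2' == last_component)
    return ('attn1' or 'attn2' == name.split('.')[-1])


def is_attn_strict(name):
    # Hypothetical strict variant: true only for names ending in attn1 or attn2.
    return name.split('.')[-1] in ('attn1', 'attn2')


print(is_attn_as_written("conv_in"))                 # truthy for any name
print(is_attn_strict("conv_in"))
print(is_attn_strict("transformer_blocks.0.attn2"))
```

Swapping the strict variant into `set_torch_2_attn` would change which modules are visited, so it should not be treated as a drop-in replacement without checking that loop.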
148
+
149
+ def set_processors(attentions):
150
+ for attn in attentions: attn.set_processor(AttnProcessor2_0())
151
+
152
+
153
+ def set_torch_2_attn(unet):
154
+ optim_count = 0
155
+
156
+ for name, module in unet.named_modules():
157
+ if is_attn(name):
158
+ if isinstance(module, torch.nn.ModuleList):
159
+ for m in module:
160
+ if isinstance(m, BasicTransformerBlock):
161
+ set_processors([m.attn1, m.attn2])
162
+ optim_count += 1
163
+ if optim_count > 0:
164
+ print(f"{optim_count} Attention layers using Scaled Dot Product Attention.")
165
+
166
+
167
+ def handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet):
168
+ try:
169
+ is_torch_2 = hasattr(F, 'scaled_dot_product_attention')
170
+ enable_torch_2 = is_torch_2 and enable_torch_2_attn
171
+
172
+ if enable_xformers_memory_efficient_attention and not enable_torch_2:
173
+ if is_xformers_available():
174
+ from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
175
+ unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
176
+ else:
177
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
178
+
179
+ if enable_torch_2:
180
+ set_torch_2_attn(unet)
181
+
182
+ except Exception as e:
+ print(f"Could not enable memory efficient attention for xformers or Torch 2.0: {e}")
184
+
185
+
186
+ def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
+     # Treat a missing or empty dict the same way; the default None has no .keys().
+     extra_params = extra_params if extra_params else None
+     return {
+         "model": model,
+         "condition": condition,
+         "extra_params": extra_params,
+         "is_lora": is_lora,
+         "negation": negation
+     }
195
+
196
+
197
+ def create_optim_params(name='param', params=None, lr=5e-6, extra_params=None):
198
+ params = {
199
+ "name": name,
200
+ "params": params,
201
+ "lr": lr
202
+ }
203
+ if extra_params is not None:
204
+ for k, v in extra_params.items():
205
+ params[k] = v
206
+
207
+ return params
208
+
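One detail of `create_optim_params` is that `extra_params` is merged last, so it can override the `lr` argument. The sketch below restates the function with proper indentation and shows that effect; the group name and values are made up for illustration.

```python
def create_optim_params(name='param', params=None, lr=5e-6, extra_params=None):
    # Build one optimizer parameter group as a plain dict.
    group = {"name": name, "params": params, "lr": lr}
    if extra_params is not None:
        for k, v in extra_params.items():
            group[k] = v   # later keys win, including "lr"
    return group


group = create_optim_params(
    "lora",
    params=[],
    lr=1e-4,
    extra_params={"lr": 5e-5, "weight_decay": 0.0},
)
print(group["lr"], group["weight_decay"])
```

This is why the training script can pass `{"lr": learning_rate}` through `extra_params` even where `lr` is left at its default.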
209
+
210
+ def negate_params(name, negation):
211
+ # We have to do this if we are co-training with LoRA.
212
+ # This ensures that parameter groups aren't duplicated.
213
+ if negation is None: return False
214
+ for n in negation:
215
+ if n in name and 'temp' not in name:
216
+ return True
217
+ return False
218
+
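To make the matching rule of `negate_params` concrete, here is the function again with indentation restored and a few calls. The parameter names in the calls are hypothetical examples, not names read from the model.

```python
def negate_params(name, negation):
    if negation is None:
        return False
    for n in negation:
        # A match counts only if the name does not contain 'temp'.
        if n in name and 'temp' not in name:
            return True
    return False


spatial = negate_params("down_blocks.0.attentions.0.proj_in.weight", ["attentions"])
temporal = negate_params("down_blocks.0.temp_attentions.0.proj_in.weight", ["attentions"])
no_list = negate_params("down_blocks.0.attentions.0.proj_in.weight", None)
print(spatial, temporal, no_list)
```

Any name containing the substring `temp` is never negated, regardless of the entries in `negation`.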
219
+
220
+ def create_optimizer_params(model_list, lr):
221
+ import itertools
222
+ optimizer_params = []
223
+
224
+ for optim in model_list:
225
+ model, condition, extra_params, is_lora, negation = optim.values()
226
+ # Check if we are doing LoRA training.
227
+ if is_lora and condition and isinstance(model, list):
228
+ params = create_optim_params(
229
+ params=itertools.chain(*model),
230
+ extra_params=extra_params
231
+ )
232
+ optimizer_params.append(params)
233
+ continue
234
+
235
+ if is_lora and condition and not isinstance(model, list):
236
+ for n, p in model.named_parameters():
237
+ if 'lora' in n:
238
+ params = create_optim_params(n, p, lr, extra_params)
239
+ optimizer_params.append(params)
240
+ continue
241
+
242
+ # If this is true, we can train it.
243
+ if condition:
244
+ for n, p in model.named_parameters():
245
+ should_negate = 'lora' in n and not is_lora
246
+ if should_negate: continue
247
+
248
+ params = create_optim_params(n, p, lr, extra_params)
249
+ optimizer_params.append(params)
250
+
251
+ return optimizer_params
252
+
253
+
254
+ def get_optimizer(use_8bit_adam):
255
+ if use_8bit_adam:
256
+ try:
257
+ import bitsandbytes as bnb
258
+ except ImportError:
259
+ raise ImportError(
260
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
261
+ )
262
+
263
+ return bnb.optim.AdamW8bit
264
+ else:
265
+ return torch.optim.AdamW
266
+
267
+
268
+ def is_mixed_precision(accelerator):
269
+ weight_dtype = torch.float32
270
+
271
+ if accelerator.mixed_precision == "fp16":
272
+ weight_dtype = torch.float16
273
+
274
+ elif accelerator.mixed_precision == "bf16":
275
+ weight_dtype = torch.bfloat16
276
+
277
+ return weight_dtype
278
+
279
+
280
+ def cast_to_gpu_and_type(model_list, accelerator, weight_dtype):
281
+ for model in model_list:
282
+ if model is not None: model.to(accelerator.device, dtype=weight_dtype)
283
+
284
+
285
+ def inverse_video(pipe, latents, num_steps):
286
+ ddim_inv_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
287
+ ddim_inv_scheduler.set_timesteps(num_steps)
288
+
289
+ ddim_inv_latent = ddim_inversion(
290
+ pipe, ddim_inv_scheduler, video_latent=latents.to(pipe.device),
291
+ num_inv_steps=num_steps, prompt="")[-1]
292
+ return ddim_inv_latent
293
+
294
+
295
+ def handle_cache_latents(
296
+ should_cache,
297
+ output_dir,
298
+ train_dataloader,
299
+ train_batch_size,
300
+ vae,
301
+ unet,
302
+ pretrained_model_path,
303
+ noise_prior,
304
+ cached_latent_dir=None,
305
+ ):
306
+ # Cache latents by encoding each batch once and saving it to disk.
+ # Speeds up training and saves memory by not encoding during the train loop.
308
+ if not should_cache: return None
309
+ vae.to('cuda', dtype=torch.float16)
310
+ vae.enable_slicing()
311
+
312
+ pipe = TextToVideoSDPipeline.from_pretrained(
313
+ pretrained_model_path,
314
+ vae=vae,
315
+ unet=copy.deepcopy(unet).to('cuda', dtype=torch.float16)
316
+ )
317
+ pipe.text_encoder.to('cuda', dtype=torch.float16)
318
+
319
+ cached_latent_dir = (
320
+ os.path.abspath(cached_latent_dir) if cached_latent_dir is not None else None
321
+ )
322
+
323
+ if cached_latent_dir is None:
324
+ cache_save_dir = f"{output_dir}/cached_latents"
325
+ os.makedirs(cache_save_dir, exist_ok=True)
326
+
327
+ for i, batch in enumerate(tqdm(train_dataloader, desc="Caching Latents.")):
328
+
329
+ save_name = f"cached_{i}"
330
+ full_out_path = f"{cache_save_dir}/{save_name}.pt"
331
+
332
+ pixel_values = batch['pixel_values'].to('cuda', dtype=torch.float16)
333
+ batch['latents'] = tensor_to_vae_latent(pixel_values, vae)
334
+ if noise_prior > 0.:
335
+ batch['inversion_noise'] = inverse_video(pipe, batch['latents'], 50)
336
+ for k, v in batch.items(): batch[k] = v[0]
337
+
338
+ torch.save(batch, full_out_path)
339
+ del pixel_values
340
+ del batch
341
+
342
+ # We do this to avoid fragmentation from casting latents between devices.
343
+ torch.cuda.empty_cache()
344
+ else:
345
+ cache_save_dir = cached_latent_dir
346
+
347
+ return torch.utils.data.DataLoader(
348
+ CachedDataset(cache_dir=cache_save_dir),
349
+ batch_size=train_batch_size,
350
+ shuffle=True,
351
+ num_workers=0
352
+ )
353
+
354
+
355
+ def handle_trainable_modules(model, trainable_modules=None, is_enabled=True, negation=None):
356
+ global already_printed_trainables
357
+
358
+ # This can most definitely be refactored :-)
359
+ unfrozen_params = 0
360
+ if trainable_modules is not None:
361
+ for name, module in model.named_modules():
362
+ for tm in tuple(trainable_modules):
363
+ if tm == 'all':
364
+ model.requires_grad_(is_enabled)
365
+ unfrozen_params = len(list(model.parameters()))
366
+ break
367
+
368
+ if tm in name and 'lora' not in name:
369
+ for m in module.parameters():
370
+ m.requires_grad_(is_enabled)
371
+ if is_enabled: unfrozen_params += 1
372
+
373
+ if unfrozen_params > 0 and not already_printed_trainables:
374
+ already_printed_trainables = True
375
+ print(f"{unfrozen_params} params have been unfrozen for training.")
376
+
377
+
378
+ def tensor_to_vae_latent(t, vae):
379
+ video_length = t.shape[1]
380
+
381
+ t = rearrange(t, "b f c h w -> (b f) c h w")
382
+ latents = vae.encode(t).latent_dist.sample()
383
+ latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
384
+ latents = latents * 0.18215
385
+
386
+ return latents
387
+
388
+
389
+ def sample_noise(latents, noise_strength, use_offset_noise=False):
390
+ b, c, f, *_ = latents.shape
391
+ noise_latents = torch.randn_like(latents, device=latents.device)
392
+
393
+ if use_offset_noise:
394
+ offset_noise = torch.randn(b, c, f, 1, 1, device=latents.device)
395
+ noise_latents = noise_latents + noise_strength * offset_noise
396
+
397
+ return noise_latents
398
+
399
+
400
+ def enforce_zero_terminal_snr(betas):
401
+ """
402
+ Corrects noise in diffusion schedulers.
403
+ From: Common Diffusion Noise Schedules and Sample Steps are Flawed
404
+ https://arxiv.org/pdf/2305.08891.pdf
405
+ """
406
+ # Convert betas to alphas_bar_sqrt
407
+ alphas = 1 - betas
408
+ alphas_bar = alphas.cumprod(0)
409
+ alphas_bar_sqrt = alphas_bar.sqrt()
410
+
411
+ # Store old values.
412
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
413
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
414
+
415
+ # Shift so the last timestep is zero.
416
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
417
+
418
+ # Scale so the first timestep is back to the old value.
419
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (
420
+ alphas_bar_sqrt_0 - alphas_bar_sqrt_T
421
+ )
422
+
423
+ # Convert alphas_bar_sqrt to betas
424
+ alphas_bar = alphas_bar_sqrt ** 2
425
+ alphas = alphas_bar[1:] / alphas_bar[:-1]
426
+ alphas = torch.cat([alphas_bar[0:1], alphas])
427
+ betas = 1 - alphas
428
+
429
+ return betas
430
+
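The effect of `enforce_zero_terminal_snr` can be checked without PyTorch. The following is a pure-Python sketch of the same shift-and-scale steps on a list of floats; the function name `rescale_zero_terminal_snr` and the linear schedule values are assumptions chosen for illustration.

```python
import math


def rescale_zero_terminal_snr(betas):
    """List-based version of the rescaling: shift, scale, convert back."""
    # Convert betas to sqrt(alpha_bar_t) via a running product of alphas.
    alphas_bar_sqrt = []
    running = 1.0
    for beta in betas:
        running *= 1.0 - beta
        alphas_bar_sqrt.append(math.sqrt(running))

    first, last = alphas_bar_sqrt[0], alphas_bar_sqrt[-1]
    # Shift so the last value is zero, then scale so the first value is kept.
    scaled = [(a - last) * first / (first - last) for a in alphas_bar_sqrt]

    # Convert sqrt(alpha_bar_t) back to betas.
    alphas_bar = [a * a for a in scaled]
    new_betas = [1.0 - alphas_bar[0]]
    for prev, cur in zip(alphas_bar, alphas_bar[1:]):
        new_betas.append(1.0 - cur / prev)
    return new_betas


# Hypothetical linear schedule with 1000 steps.
betas = [1e-4 + i * (0.02 - 1e-4) / 999 for i in range(1000)]
new_betas = rescale_zero_terminal_snr(betas)
print(new_betas[0], new_betas[-1])
```

The first beta is unchanged up to rounding, while the last one becomes 1, meaning the final timestep carries no signal.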
431
+
432
+ def should_sample(global_step, validation_steps, validation_data):
433
+ return global_step % validation_steps == 0 and validation_data.sample_preview
434
+
435
+
436
+ def save_pipe(
437
+ path,
438
+ global_step,
439
+ accelerator,
440
+ unet,
441
+ text_encoder,
442
+ vae,
443
+ output_dir,
444
+ lora_manager_spatial: LoraHandler,
445
+ lora_manager_temporal: LoraHandler,
446
+ unet_target_replace_module=None,
447
+ text_target_replace_module=None,
448
+ is_checkpoint=False,
449
+ save_pretrained_model=True
450
+ ):
451
+ if is_checkpoint:
452
+ save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
453
+ os.makedirs(save_path, exist_ok=True)
454
+ else:
455
+ save_path = output_dir
456
+
457
+ # Save the dtypes so we can continue training at the same precision.
458
+ u_dtype, t_dtype, v_dtype = unet.dtype, text_encoder.dtype, vae.dtype
459
+
460
+ # Copy the model without creating a reference to it. This allows keeping the state of our lora training if enabled.
461
+ unet_out = copy.deepcopy(accelerator.unwrap_model(unet.cpu(), keep_fp32_wrapper=False))
462
+ text_encoder_out = copy.deepcopy(accelerator.unwrap_model(text_encoder.cpu(), keep_fp32_wrapper=False))
463
+
464
+ pipeline = TextToVideoSDPipeline.from_pretrained(
465
+ path,
466
+ unet=unet_out,
467
+ text_encoder=text_encoder_out,
468
+ vae=vae,
469
+ ).to(torch_dtype=torch.float32)
470
+
471
+ lora_manager_spatial.save_lora_weights(model=copy.deepcopy(pipeline), save_path=save_path+'/spatial', step=global_step)
472
+ lora_manager_temporal.save_lora_weights(model=copy.deepcopy(pipeline), save_path=save_path+'/temporal', step=global_step)
473
+
474
+ if save_pretrained_model:
475
+ pipeline.save_pretrained(save_path)
476
+
477
+ if is_checkpoint:
478
+ unet, text_encoder = accelerator.prepare(unet, text_encoder)
479
+ models_to_cast_back = [(unet, u_dtype), (text_encoder, t_dtype), (vae, v_dtype)]
480
+ [x[0].to(accelerator.device, dtype=x[1]) for x in models_to_cast_back]
481
+
482
+ logger.info(f"Saved model at {save_path} on step {global_step}")
483
+
484
+ del pipeline
485
+ del unet_out
486
+ del text_encoder_out
487
+ torch.cuda.empty_cache()
488
+ gc.collect()
489
+
490
+
491
+ def main(
492
+ pretrained_model_path: str,
493
+ output_dir: str,
494
+ train_data: Dict,
495
+ validation_data: Dict,
496
+ extra_train_data: list = [],
497
+ dataset_types: Tuple[str] = ('json',),
498
+ validation_steps: int = 100,
499
+ trainable_modules: Tuple[str] = None, # Eg: ("attn1", "attn2")
500
+ extra_unet_params=None,
501
+ train_batch_size: int = 1,
502
+ max_train_steps: int = 500,
503
+ learning_rate: float = 5e-5,
504
+ lr_scheduler: str = "constant",
505
+ lr_warmup_steps: int = 0,
506
+ adam_beta1: float = 0.9,
507
+ adam_beta2: float = 0.999,
508
+ adam_weight_decay: float = 1e-2,
509
+ adam_epsilon: float = 1e-08,
510
+ gradient_accumulation_steps: int = 1,
511
+ gradient_checkpointing: bool = False,
512
+ text_encoder_gradient_checkpointing: bool = False,
513
+ checkpointing_steps: int = 500,
514
+ resume_from_checkpoint: Optional[str] = None,
515
+ resume_step: Optional[int] = None,
516
+ mixed_precision: Optional[str] = "fp16",
517
+ use_8bit_adam: bool = False,
518
+ enable_xformers_memory_efficient_attention: bool = True,
519
+ enable_torch_2_attn: bool = False,
520
+ seed: Optional[int] = None,
521
+ use_offset_noise: bool = False,
522
+ rescale_schedule: bool = False,
523
+ offset_noise_strength: float = 0.1,
524
+ extend_dataset: bool = False,
525
+ cache_latents: bool = False,
526
+ cached_latent_dir=None,
527
+ use_unet_lora: bool = False,
528
+ unet_lora_modules: Tuple[str] = [],
529
+ text_encoder_lora_modules: Tuple[str] = [],
530
+ save_pretrained_model: bool = True,
531
+ lora_rank: int = 16,
532
+ lora_path: str = '',
533
+ lora_unet_dropout: float = 0.1,
534
+ logger_type: str = 'tensorboard',
535
+ **kwargs
536
+ ):
537
+ *_, config = inspect.getargvalues(inspect.currentframe())
538
+
539
+ accelerator = Accelerator(
540
+ gradient_accumulation_steps=gradient_accumulation_steps,
541
+ mixed_precision=mixed_precision,
542
+ log_with=logger_type,
543
+ project_dir=output_dir
544
+ )
545
+
546
+ # Make one log on every process with the configuration for debugging.
547
+ create_logging(logging, logger, accelerator)
548
+
549
+ # Initialize accelerate, transformers, and diffusers warnings
550
+ accelerate_set_verbose(accelerator)
551
+
552
+ # Handle the output folder creation
553
+ if accelerator.is_main_process:
554
+ output_dir = create_output_folders(output_dir, config)
555
+
556
+ # Load scheduler, tokenizer and models.
557
+ noise_scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(pretrained_model_path)
558
+
559
+ # Freeze any necessary models
560
+ freeze_models([vae, text_encoder, unet])
561
+
562
+ # Enable xformers if available
563
+ handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet)
564
+
565
+ # Initialize the optimizer
566
+ optimizer_cls = get_optimizer(use_8bit_adam)
567
+
568
+ # Get the training dataset based on types (json, single_video, folder, image)
569
+ train_datasets = get_train_dataset(dataset_types, train_data, tokenizer)
570
+
571
+ # If you have extra train data, you can add a list with as many entries as you like.
+ # E.g. extra_train_data: [{dataset_types: [...], train_data: {...}}, ...]
573
+ try:
574
+ if extra_train_data is not None and len(extra_train_data) > 0:
575
+ for dataset in extra_train_data:
576
+ d_t, t_d = dataset['dataset_types'], dataset['train_data']
577
+ train_datasets += get_train_dataset(d_t, t_d, tokenizer)
578
+
579
+ except Exception as e:
580
+ print(f"Could not process extra train datasets due to an error : {e}")
581
+
582
+ # Extend datasets that are shorter than the longest one. This allows for more balanced training.
583
+ attrs = ['train_data', 'frames', 'image_dir', 'video_files']
584
+ extend_datasets(train_datasets, attrs, extend=extend_dataset)
585
+
586
+ # Process one dataset
587
+ if len(train_datasets) == 1:
588
+ train_dataset = train_datasets[0]
589
+
590
+ # Process many datasets
591
+ else:
592
+ train_dataset = torch.utils.data.ConcatDataset(train_datasets)
593
+
594
+ # Create parameters to optimize over with a condition (if "condition" is true, optimize it)
595
+ extra_unet_params = extra_unet_params if extra_unet_params is not None else {}
596
+ extra_text_encoder_params = extra_unet_params if extra_unet_params is not None else {}
597
+
598
+ # Use LoRA if enabled.
599
+ # one temporal lora
600
+ lora_manager_temporal = LoraHandler(use_unet_lora=use_unet_lora, unet_replace_modules=["TransformerTemporalModel"])
601
+
602
+ unet_lora_params_temporal, unet_negation_temporal = lora_manager_temporal.add_lora_to_model(
603
+ use_unet_lora, unet, lora_manager_temporal.unet_replace_modules, lora_unet_dropout,
604
+ lora_path + '/temporal/lora/', r=lora_rank)
605
+
606
+ optimizer_temporal = optimizer_cls(
607
+ create_optimizer_params([param_optim(unet_lora_params_temporal, use_unet_lora, is_lora=True,
608
+ extra_params={**{"lr": learning_rate}, **extra_text_encoder_params}
609
+ )], learning_rate),
610
+ lr=learning_rate,
611
+ betas=(adam_beta1, adam_beta2),
612
+ weight_decay=adam_weight_decay,
613
+ eps=adam_epsilon,
614
+ )
615
+
616
+ lr_scheduler_temporal = get_scheduler(
617
+ lr_scheduler,
618
+ optimizer=optimizer_temporal,
619
+ num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
620
+ num_training_steps=max_train_steps * gradient_accumulation_steps,
621
+ )
622
+
623
+ # one spatial lora for each video
624
+ if 'folder' in dataset_types:
625
+ spatial_lora_num = train_dataset.__len__()
626
+ else:
627
+ spatial_lora_num = 1
628
+
629
+ lora_manager_spatials = []
630
+ unet_lora_params_spatial_list = []
631
+ optimizer_spatial_list = []
632
+ lr_scheduler_spatial_list = []
633
+ for i in range(spatial_lora_num):
634
+ lora_manager_spatial = LoraHandler(use_unet_lora=use_unet_lora, unet_replace_modules=["Transformer2DModel"])
635
+ lora_manager_spatials.append(lora_manager_spatial)
636
+ unet_lora_params_spatial, unet_negation_spatial = lora_manager_spatial.add_lora_to_model(
637
+ use_unet_lora, unet, lora_manager_spatial.unet_replace_modules, lora_unet_dropout,
638
+ lora_path + '/spatial/lora/', r=lora_rank)
639
+
640
+ unet_lora_params_spatial_list.append(unet_lora_params_spatial)
641
+
642
+ optimizer_spatial = optimizer_cls(
643
+ create_optimizer_params([param_optim(unet_lora_params_spatial, use_unet_lora, is_lora=True,
644
+ extra_params={**{"lr": learning_rate}, **extra_text_encoder_params}
645
+ )], learning_rate),
646
+ lr=learning_rate,
647
+ betas=(adam_beta1, adam_beta2),
648
+ weight_decay=adam_weight_decay,
649
+ eps=adam_epsilon,
650
+ )
651
+
652
+ optimizer_spatial_list.append(optimizer_spatial)
653
+
654
+ # Scheduler
655
+ lr_scheduler_spatial = get_scheduler(
656
+ lr_scheduler,
657
+ optimizer=optimizer_spatial,
658
+ num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
659
+ num_training_steps=max_train_steps * gradient_accumulation_steps,
660
+ )
661
+ lr_scheduler_spatial_list.append(lr_scheduler_spatial)
662
+
663
+ unet_negation_all = unet_negation_spatial + unet_negation_temporal
664
+
665
+ # DataLoaders creation:
666
+ train_dataloader = torch.utils.data.DataLoader(
667
+ train_dataset,
668
+ batch_size=train_batch_size,
669
+ shuffle=True
670
+ )
671
+
672
+ # Latents caching
673
+ cached_data_loader = handle_cache_latents(
674
+ cache_latents,
675
+ output_dir,
676
+ train_dataloader,
677
+ train_batch_size,
678
+ vae,
679
+ unet,
680
+ pretrained_model_path,
681
+ validation_data.noise_prior,
682
+ cached_latent_dir,
683
+ )
684
+
685
+ if cached_data_loader is not None:
686
+ train_dataloader = cached_data_loader
687
+
688
+ # Prepare everything with our `accelerator`.
689
+ unet, optimizer_spatial_list, optimizer_temporal, train_dataloader, lr_scheduler_spatial_list, lr_scheduler_temporal, text_encoder = accelerator.prepare(
690
+ unet,
691
+ optimizer_spatial_list, optimizer_temporal,
692
+ train_dataloader,
693
+ lr_scheduler_spatial_list, lr_scheduler_temporal,
694
+ text_encoder
695
+ )
696
+
697
+ # Use Gradient Checkpointing if enabled.
698
+ unet_and_text_g_c(
699
+ unet,
700
+ text_encoder,
701
+ gradient_checkpointing,
702
+ text_encoder_gradient_checkpointing
703
+ )
704
+
705
+ # Enable VAE slicing to save memory.
706
+ vae.enable_slicing()
707
+
708
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
709
+ # as these models are only used for inference, keeping weights in full precision is not required.
710
+ weight_dtype = is_mixed_precision(accelerator)
711
+
712
+ # Move text encoders, and VAE to GPU
713
+ models_to_cast = [text_encoder, vae]
714
+ cast_to_gpu_and_type(models_to_cast, accelerator, weight_dtype)
715
+
716
+ # Fix noise schedules to predict light and dark areas if available.
717
+ if not use_offset_noise and rescale_schedule:
718
+ noise_scheduler.betas = enforce_zero_terminal_snr(noise_scheduler.betas)
719
+
720
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
721
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
722
+
723
+ # Afterwards we recalculate our number of training epochs
724
+ num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
725
+
726
+ # We need to initialize the trackers we use, and also store our configuration.
727
+ # The trackers initialize automatically on the main process.
728
+ if accelerator.is_main_process:
729
+ accelerator.init_trackers("text2video-fine-tune")
730
+
731
+ # Train!
732
+ total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
733
+
734
+ logger.info("***** Running training *****")
735
+ logger.info(f" Num examples = {len(train_dataset)}")
736
+ logger.info(f" Num Epochs = {num_train_epochs}")
737
+ logger.info(f" Instantaneous batch size per device = {train_batch_size}")
738
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
739
+ logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
740
+ logger.info(f" Total optimization steps = {max_train_steps}")
741
+ global_step = 0
742
+ first_epoch = 0
743
+
744
+ # Only show the progress bar once on each machine.
745
+ progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
746
+ progress_bar.set_description("Steps")
747
+
748
+ def finetune_unet(batch, step, mask_spatial_lora=False, mask_temporal_lora=False):
749
+ nonlocal use_offset_noise
750
+ nonlocal rescale_schedule
751
+
752
+ # Unfreeze UNET Layers
753
+ if global_step == 0:
754
+ already_printed_trainables = False
755
+ unet.train()
756
+ handle_trainable_modules(
757
+ unet,
758
+ trainable_modules,
759
+ is_enabled=True,
760
+ negation=unet_negation_all
761
+ )
762
+
763
+ # Convert videos to latent space
764
+ if not cache_latents:
765
+ latents = tensor_to_vae_latent(batch["pixel_values"], vae)
766
+ else:
767
+ latents = batch["latents"]
768
+
769
+ # Sample noise that we'll add to the latents
770
+ use_offset_noise = use_offset_noise and not rescale_schedule
771
+ noise = sample_noise(latents, offset_noise_strength, use_offset_noise)
772
+ bsz = latents.shape[0]
773
+
774
+ # Sample a random timestep for each video
775
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
776
+ timesteps = timesteps.long()
777
+
778
+ # Add noise to the latents according to the noise magnitude at each timestep
779
+ # (this is the forward diffusion process)
780
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
781
+
782
+ # *Potentially* Fixes gradient checkpointing training.
783
+ # See: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb
784
+ if kwargs.get('eval_train', False):
785
+ unet.eval()
786
+ text_encoder.eval()
787
+
788
+ # Encode text embeddings
789
+ token_ids = batch['prompt_ids']
790
+ encoder_hidden_states = text_encoder(token_ids)[0]
791
+ detached_encoder_state = encoder_hidden_states.clone().detach()
792
+
793
+ # Get the target for loss depending on the prediction type
794
+ if noise_scheduler.config.prediction_type == "epsilon":
795
+ target = noise
796
+
797
+ elif noise_scheduler.config.prediction_type == "v_prediction":
798
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
799
+
800
+ else:
801
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
802
+
803
+ encoder_hidden_states = detached_encoder_state
804
+
805
+ if mask_spatial_lora:
806
+ loras = extract_lora_child_module(unet, target_replace_module=["Transformer2DModel"])
807
+ for lora_i in loras:
808
+ lora_i.scale = 0.
809
+ loss_spatial = None
810
+ else:
811
+ loras = extract_lora_child_module(unet, target_replace_module=["Transformer2DModel"])
812
+ for lora_i in loras:
813
+ lora_i.scale = 1.
814
+
815
+ for lora_idx in range(0, len(loras), spatial_lora_num):
816
+ loras[lora_idx + step].scale = 1.
817
+
818
+ loras = extract_lora_child_module(unet, target_replace_module=["TransformerTemporalModel"])
819
+ for lora_i in loras:
820
+ lora_i.scale = 0.
821
+
822
+ ran_idx = torch.randint(0, noisy_latents.shape[2], (1,)).item()
823
+
824
+ if random.uniform(0, 1) < -0.5:
825
+ pixel_values_spatial = transforms.functional.hflip(
826
+ batch["pixel_values"][:, ran_idx, :, :, :]).unsqueeze(1)
827
+ latents_spatial = tensor_to_vae_latent(pixel_values_spatial, vae)
828
+ noise_spatial = sample_noise(latents_spatial, offset_noise_strength, use_offset_noise)
829
+ noisy_latents_input = noise_scheduler.add_noise(latents_spatial, noise_spatial, timesteps)
830
+ target_spatial = noise_spatial
831
+ model_pred_spatial = unet(noisy_latents_input, timesteps,
832
+ encoder_hidden_states=encoder_hidden_states).sample
833
+ loss_spatial = F.mse_loss(model_pred_spatial[:, :, 0, :, :].float(),
834
+ target_spatial[:, :, 0, :, :].float(), reduction="mean")
835
+ else:
836
+ noisy_latents_input = noisy_latents[:, :, ran_idx, :, :]
837
+ target_spatial = target[:, :, ran_idx, :, :]
838
+ model_pred_spatial = unet(noisy_latents_input.unsqueeze(2), timesteps,
839
+ encoder_hidden_states=encoder_hidden_states).sample
840
+ loss_spatial = F.mse_loss(model_pred_spatial[:, :, 0, :, :].float(),
841
+ target_spatial.float(), reduction="mean")
842
+
843
+ if mask_temporal_lora:
844
+ loras = extract_lora_child_module(unet, target_replace_module=["TransformerTemporalModel"])
845
+ for lora_i in loras:
846
+ lora_i.scale = 0.
847
+ loss_temporal = None
848
+ else:
849
+ loras = extract_lora_child_module(unet, target_replace_module=["TransformerTemporalModel"])
850
+ for lora_i in loras:
851
+ lora_i.scale = 1.
852
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
853
+ loss_temporal = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
854
+
855
+ beta = 1
856
+ alpha = (beta ** 2 + 1) ** 0.5
857
+ ran_idx = torch.randint(0, model_pred.shape[2], (1,)).item()
858
+ model_pred_decent = alpha * model_pred - beta * model_pred[:, :, ran_idx, :, :].unsqueeze(2)
859
+ target_decent = alpha * target - beta * target[:, :, ran_idx, :, :].unsqueeze(2)
860
+ loss_ad_temporal = F.mse_loss(model_pred_decent.float(), target_decent.float(), reduction="mean")
861
+ loss_temporal = loss_temporal + loss_ad_temporal
862
+
863
+ return loss_spatial, loss_temporal, latents, noise
864
+
865
+ for epoch in range(first_epoch, num_train_epochs):
866
+ train_loss_spatial = 0.0
867
+ train_loss_temporal = 0.0
868
+
869
+ for step, batch in enumerate(train_dataloader):
870
+ # Skip steps until we reach the resumed step
871
+ if resume_from_checkpoint and epoch == first_epoch and step < resume_step:
872
+ if step % gradient_accumulation_steps == 0:
873
+ progress_bar.update(1)
874
+ continue
875
+
876
+ with accelerator.accumulate(unet), accelerator.accumulate(text_encoder):
877
+
878
+ text_prompt = batch['text_prompt'][0]
879
+
880
+ for optimizer_spatial in optimizer_spatial_list:
881
+ optimizer_spatial.zero_grad(set_to_none=True)
882
+
883
+ optimizer_temporal.zero_grad(set_to_none=True)
884
+
885
+ mask_temporal_lora = False
886
+ # mask_spatial_lora = False
887
+ mask_spatial_lora = random.uniform(0, 1) < 0.1 and not mask_temporal_lora
888
+
889
+ with accelerator.autocast():
890
+ loss_spatial, loss_temporal, latents, init_noise = finetune_unet(batch, step, mask_spatial_lora=mask_spatial_lora, mask_temporal_lora=mask_temporal_lora)
891
+
892
+ # Gather the losses across all processes for logging (if we use distributed training).
893
+ if not mask_spatial_lora:
894
+ avg_loss_spatial = accelerator.gather(loss_spatial.repeat(train_batch_size)).mean()
895
+ train_loss_spatial += avg_loss_spatial.item() / gradient_accumulation_steps
896
+
897
+ if not mask_temporal_lora:
898
+ avg_loss_temporal = accelerator.gather(loss_temporal.repeat(train_batch_size)).mean()
899
+ train_loss_temporal += avg_loss_temporal.item() / gradient_accumulation_steps
900
+
901
+ # Backpropagate
902
+ if not mask_spatial_lora:
903
+ accelerator.backward(loss_spatial, retain_graph = True)
904
+ optimizer_spatial_list[step].step()
905
+
906
+ if not mask_temporal_lora:
907
+ accelerator.backward(loss_temporal)
908
+ optimizer_temporal.step()
909
+
910
+ lr_scheduler_spatial_list[step].step()
911
+ lr_scheduler_temporal.step()
912
+
913
+ # Checks if the accelerator has performed an optimization step behind the scenes
914
+ if accelerator.sync_gradients:
915
+ progress_bar.update(1)
916
+ global_step += 1
917
+ accelerator.log({"train_loss": train_loss_temporal}, step=global_step)
918
+ train_loss_temporal = 0.0
919
+ if global_step % checkpointing_steps == 0 and global_step > 0:
920
+ save_pipe(
921
+ pretrained_model_path,
922
+ global_step,
923
+ accelerator,
924
+ unet,
925
+ text_encoder,
926
+ vae,
927
+ output_dir,
928
+ lora_manager_spatial,
929
+ lora_manager_temporal,
930
+ unet_lora_modules,
931
+ text_encoder_lora_modules,
932
+ is_checkpoint=True,
933
+ save_pretrained_model=save_pretrained_model
934
+ )
935
+
936
+ if should_sample(global_step, validation_steps, validation_data):
937
+ if accelerator.is_main_process:
938
+ with accelerator.autocast():
939
+ unet.eval()
940
+ text_encoder.eval()
941
+ unet_and_text_g_c(unet, text_encoder, False, False)
942
+ loras = extract_lora_child_module(unet, target_replace_module=["Transformer2DModel"])
943
+ for lora_i in loras:
944
+ lora_i.scale = validation_data.spatial_scale
945
+
946
+ if validation_data.noise_prior > 0:
947
+ preset_noise = (validation_data.noise_prior) ** 0.5 * batch['inversion_noise'] + (
948
+ 1-validation_data.noise_prior) ** 0.5 * torch.randn_like(batch['inversion_noise'])
949
+ else:
950
+ preset_noise = None
951
+
952
+ pipeline = TextToVideoSDPipeline.from_pretrained(
953
+ pretrained_model_path,
954
+ text_encoder=text_encoder,
955
+ vae=vae,
956
+ unet=unet
957
+ )
958
+
959
+ diffusion_scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
960
+ pipeline.scheduler = diffusion_scheduler
961
+
962
+ prompt_list = text_prompt if len(validation_data.prompt) <= 0 else validation_data.prompt
963
+ for prompt in prompt_list:
964
+ save_filename = f"{global_step}_{prompt.replace('.', '')}"
965
+
966
+ out_file = f"{output_dir}/samples/{save_filename}.mp4"
967
+
968
+ with torch.no_grad():
969
+ video_frames = pipeline(
970
+ prompt,
971
+ width=validation_data.width,
972
+ height=validation_data.height,
973
+ num_frames=validation_data.num_frames,
974
+ num_inference_steps=validation_data.num_inference_steps,
975
+ guidance_scale=validation_data.guidance_scale,
976
+ latents=preset_noise
977
+ ).frames
978
+ export_to_video(video_frames, out_file, train_data.get('fps', 8))
979
+ logger.info(f"Saved a new sample to {out_file}")
980
+ del pipeline
981
+ torch.cuda.empty_cache()
982
+
983
+ unet_and_text_g_c(
984
+ unet,
985
+ text_encoder,
986
+ gradient_checkpointing,
987
+ text_encoder_gradient_checkpointing
988
+ )
989
+
990
+ accelerator.log({"loss_temporal": loss_temporal.detach().item()}, step=step)
991
+
992
+ if global_step >= max_train_steps:
993
+ break
994
+
995
+ # Create the pipeline using the trained modules and save it.
996
+ accelerator.wait_for_everyone()
997
+ if accelerator.is_main_process:
998
+ save_pipe(
999
+ pretrained_model_path,
1000
+ global_step,
1001
+ accelerator,
1002
+ unet,
1003
+ text_encoder,
1004
+ vae,
1005
+ output_dir,
1006
+ lora_manager_spatial,
1007
+ lora_manager_temporal,
1008
+ unet_lora_modules,
1009
+ text_encoder_lora_modules,
1010
+ is_checkpoint=False,
1011
+ save_pretrained_model=save_pretrained_model
1012
+ )
1013
+ accelerator.end_training()
1014
+
1015
+
1016
+ if __name__ == "__main__":
1017
+ parser = argparse.ArgumentParser()
1018
+ parser.add_argument("--config", type=str, default='./configs/config_multi_videos.yaml')
1019
+ args = parser.parse_args()
1020
+ main(**OmegaConf.load(args.config))
1021
+
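The temporal branch of `finetune_unet` in the script above adds a second, "debiased" term to the plain MSE: it subtracts one randomly chosen anchor frame (scaled by `beta`) from every frame of both the prediction and the target before comparing them. The following is a minimal sketch of that arithmetic only, assuming each frame has been flattened to a plain list of floats; the helper names `mse`, `debias`, and `temporal_loss` are hypothetical and do not appear in the script, which operates on 5-D tensors instead.

```python
import math


def mse(a, b):
    """Mean squared error over two equally shaped lists of frames."""
    diffs = [(x - y) ** 2 for fa, fb in zip(a, b) for x, y in zip(fa, fb)]
    return sum(diffs) / len(diffs)


def debias(frames, anchor_idx, beta=1.0):
    """Return alpha * frame - beta * anchor_frame for every frame.

    alpha = sqrt(beta^2 + 1), mirroring the constants in the script.
    """
    alpha = math.sqrt(beta ** 2 + 1.0)
    anchor = frames[anchor_idx]
    return [[alpha * x - beta * a for x, a in zip(frame, anchor)] for frame in frames]


def temporal_loss(pred, target, anchor_idx, beta=1.0):
    """Plain MSE plus the MSE between the debiased prediction and target."""
    plain = mse(pred, target)
    debiased = mse(debias(pred, anchor_idx, beta), debias(target, anchor_idx, beta))
    return plain + debiased


# Two frames of two values each (hypothetical toy data).
target = [[0.0, 1.0], [2.0, 3.0]]
pred = [[1.0, 2.0], [3.0, 4.0]]  # target shifted by +1 everywhere
print(temporal_loss(pred, target, anchor_idx=0))
```

The same anchor index is used for prediction and target, as in the script, so an error shared by all frames is attenuated in the debiased term while frame-to-frame differences keep their full weight.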
README.md CHANGED
@@ -1,13 +1,364 @@
1
- ---
2
- title: MotionDirector
3
- emoji: 🏆
4
- colorFrom: green
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 4.8.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # MotionDirector
2
+
3
+ This is the official repository of [MotionDirector](https://showlab.github.io/MotionDirector).
4
+
5
+ **MotionDirector: Motion Customization of Text-to-Video Diffusion Models.**
6
+ <br/>
7
+ [Rui Zhao](https://ruizhaocv.github.io/),
8
+ [Yuchao Gu](https://ycgu.site/),
9
+ [Jay Zhangjie Wu](https://zhangjiewu.github.io/),
10
+ [David Junhao Zhang](https://junhaozhang98.github.io/),
11
+ [Jiawei Liu](https://jia-wei-liu.github.io/),
12
+ [Weijia Wu](https://weijiawu.github.io/),
13
+ [Jussi Keppo](https://www.jussikeppo.com/),
14
+ [Mike Zheng Shou](https://sites.google.com/view/showlab)
15
+ <br/>
16
+
17
+ [![Project Page](https://img.shields.io/badge/Project-Website-orange)](https://showlab.github.io/MotionDirector)
18
+ [![arXiv](https://img.shields.io/badge/arXiv-MotionDirector-b31b1b.svg)](https://arxiv.org/abs/2310.08465)
19
+
20
+ <p align="center">
21
+ <img src="https://github.com/showlab/MotionDirector/blob/page/assets/teaser.gif" width="1080px"/>
22
+ <br>
23
+ <em>MotionDirector can customize text-to-video diffusion models to generate videos with desired motions.</em>
24
+ </p>
25
+
26
+ <table class="center">
27
+ <tr>
28
+ <td style="text-align:center;" colspan="4"><b>Astronaut's daily life on Mars (Motion concepts learned by MotionDirector)</b></td>
29
+ </tr>
30
+ <tr>
31
+ <td style="text-align:center;"><b>Lifting Weights</b></td>
32
+ <td style="text-align:center;"><b>Playing Golf</b></td>
33
+ <td style="text-align:center;"><b>Riding Horse</b></td>
34
+ <td style="text-align:center;"><b>Riding Bicycle</b></td>
35
+ </tr>
36
+ <tr>
37
+ <td><img src=assets/astronaut_mars/An_astronaut_is_lifting_weights_on_Mars_4K_high_quailty_highly_detailed_4008521.gif></td>
38
+ <td><img src=assets/astronaut_mars/Astronaut_playing_golf_on_Mars_659514.gif></td>
39
+ <td><img src=assets/astronaut_mars/An_astronaut_is_riding_a_horse_on_Mars_4K_high_quailty_highly_detailed_1913261.gif></td>
40
+ <td><img src=assets/astronaut_mars/An_astronaut_is_riding_a_bicycle_past_the_pyramids_Mars_4K_high_quailty_highly_detailed_5532778.gif></td>
41
+ </tr>
42
+ <tr>
43
+ <td width=25% style="text-align:center;">"An astronaut is lifting weights on Mars, 4K, high quailty, highly detailed." </br> seed: 4008521</td>
44
+ <td width=25% style="text-align:center;">"Astronaut playing golf on Mars" </br> seed: 659514</td>
45
+ <td width=25% style="text-align:center;">"An astronaut is riding a horse on Mars, 4K, high quailty, highly detailed." </br> seed: 1913261</td>
46
+ <td width=25% style="text-align:center;">"An astronaut is riding a bicycle past the pyramids Mars, 4K, high quailty, highly detailed." </br> seed: 5532778</td>
47
+ </tr>
48
+ </table>
49
+
50
+ ## News
51
+ - [2023.12.06] [MotionDirector for Sports](#MotionDirector_for_Sports) released! Lifting weights, riding horse, playing golf, etc.
52
+ - [2023.12.05] [Colab demo](https://github.com/camenduru/MotionDirector-colab) is available. Thanks to [Camenduru](https://twitter.com/camenduru).
53
+ - [2023.12.04] [MotionDirector for Cinematic Shots](#MotionDirector_for_Cinematic_Shots) released. Now, you can make AI films with professional cinematic shots!
54
+ - [2023.12.02] Code and model weights released!
55
+
56
+ ## ToDo
57
+ - [ ] Gradio Demo
58
+ - [ ] More trained weights of MotionDirector
59
+
60
+ ## Setup
61
+ ### Requirements
62
+
63
+ ```shell
64
+ # create virtual environment
65
+ conda create -n motiondirector python=3.8
66
+ conda activate motiondirector
67
+ # install packages
68
+ pip install -r requirements.txt
69
+ ```
70
+
71
+ ### Weights of Foundation Models
72
+ ```shell
73
+ git lfs install
74
+ ## You can choose the ModelScopeT2V or ZeroScope, etc., as the foundation model.
75
+ ## ZeroScope
76
+ git clone https://huggingface.co/cerspense/zeroscope_v2_576w ./models/zeroscope_v2_576w/
77
+ ## ModelScopeT2V
78
+ git clone https://huggingface.co/damo-vilab/text-to-video-ms-1.7b ./models/model_scope/
79
+ ```
80
+ ### Weights of trained MotionDirector <a name="download_weights"></a>
81
+ ```shell
82
+ # Make sure you have git-lfs installed (https://git-lfs.com)
83
+ git lfs install
84
+ git clone https://huggingface.co/ruizhaocv/MotionDirector_weights ./outputs
85
+ ```
86
+
87
+ ## Usage
88
+ ### Training
89
+
90
+ #### Train MotionDirector on multiple videos:
91
+ ```bash
92
+ python MotionDirector_train.py --config ./configs/config_multi_videos.yaml
93
+ ```
94
+ #### Train MotionDirector on a single video:
95
+ ```bash
96
+ python MotionDirector_train.py --config ./configs/config_single_video.yaml
97
+ ```
98
+
99
+ Note:
100
+ - Before running the above command,
101
+ make sure you replace the paths to the foundation model weights and training data with your own in the config file `config_multi_videos.yaml` or `config_single_video.yaml`.
102
+ - Training on multiple 16-frame videos usually takes `300~500` steps, about `9~16` minutes on one A5000 GPU. Training on a single video takes `50~150` steps, about `1.5~4.5` minutes on one A5000 GPU. Training requires around `14GB` of VRAM.
103
+ - Reduce `n_sample_frames` if your GPU memory is limited.
104
+ - Reduce the learning rate and increase the training steps for better performance.
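The step counts in the notes above map onto epochs the same way the training script computes them: optimizer steps per epoch are the number of batches divided by the gradient accumulation steps (rounded up), and the number of epochs is `max_train_steps` divided by that (rounded up). The sketch below reproduces that arithmetic and adds a rough time estimate. The `SECONDS_PER_STEP` constant is an assumption derived from the figures quoted above (300 steps in about 9 minutes on one A5000), and the dataset sizes in the example are hypothetical.

```python
import math

# Rough throughput implied by the README figures; an assumption, not a measurement.
SECONDS_PER_STEP = 1.8


def steps_per_epoch(num_batches, gradient_accumulation_steps):
    """Optimizer updates per pass over the dataloader."""
    return math.ceil(num_batches / gradient_accumulation_steps)


def epochs_needed(max_train_steps, num_batches, gradient_accumulation_steps):
    """Epochs required to reach max_train_steps optimizer updates."""
    per_epoch = steps_per_epoch(num_batches, gradient_accumulation_steps)
    return math.ceil(max_train_steps / per_epoch)


def total_batch_size(train_batch_size, num_processes, gradient_accumulation_steps):
    """Effective batch size across devices and accumulation steps."""
    return train_batch_size * num_processes * gradient_accumulation_steps


def estimate_minutes(max_train_steps, seconds_per_step=SECONDS_PER_STEP):
    """Very rough wall-clock estimate in minutes."""
    return max_train_steps * seconds_per_step / 60.0


# Hypothetical run: 10 training videos, batch size 1, no accumulation, 300 steps.
print(epochs_needed(300, num_batches=10, gradient_accumulation_steps=1))
print(round(estimate_minutes(300), 1))
```

Lowering the learning rate and raising the step count, as suggested above, scales the estimated time linearly with the number of steps.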
105
+
106
+
107
+ ### Inference
108
+ ```bash
109
+ python MotionDirector_inference.py --model /path/to/the/foundation/model --prompt "Your prompt" --checkpoint_folder /path/to/the/trained/MotionDirector --checkpoint_index 300 --noise_prior 0.
110
+ ```
111
+ Note:
112
+ - Replace `/path/to/the/foundation/model` with your own path to the foundation model, like ZeroScope.
113
+ - The value of `checkpoint_index` selects which saved checkpoint to load, identified by the training step at which it was saved.
114
+ - The value of `noise_prior` indicates how much the inversion noise of the reference video affects the generation.
115
+ We recommend setting it to `0` for MotionDirector trained on multiple videos to achieve the most diverse generation, and to `0.1~0.5` for MotionDirector trained on a single video for faster convergence and better alignment with the reference video.
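In the validation branch of the training script, `noise_prior` blends the inversion noise of the reference video with fresh Gaussian noise using square-root weights, so the variance of the initial latents stays at one. The sketch below shows that blend on a flat list of floats; it is an illustration under that simplification, and the names `blend_weights` and `blend_noise` are hypothetical rather than functions from the repository.

```python
import math
import random


def blend_weights(noise_prior):
    """Return (weight on inversion noise, weight on fresh noise)."""
    if not 0.0 <= noise_prior <= 1.0:
        raise ValueError("noise_prior must lie in [0, 1]")
    return math.sqrt(noise_prior), math.sqrt(1.0 - noise_prior)


def blend_noise(inversion_noise, noise_prior, rng=random):
    """sqrt(p) * inversion_noise + sqrt(1 - p) * fresh Gaussian noise."""
    w_inv, w_new = blend_weights(noise_prior)
    return [w_inv * x + w_new * rng.gauss(0.0, 1.0) for x in inversion_noise]


# Toy inversion noise (hypothetical values).
reference = [0.3, -1.2, 0.8]
print(blend_weights(0.5))
print(blend_noise(reference, 0.5, rng=random.Random(0)))
```

With `noise_prior` at `0` the reference video contributes nothing to the starting noise (the script passes no preset latents in that case), while values closer to `1` tie the generated motion more tightly to the reference.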
116
+
117
+
118
+ ## Inference with pre-trained MotionDirector
119
+ All available weights are in the official [Hugging Face repo](https://huggingface.co/ruizhaocv/MotionDirector_weights).
120
+ Run the [download command](#download_weights) to download the weights to the folder `outputs`, then run the following inference commands to generate videos.
121
+
122
+ ### MotionDirector trained on multiple videos:
123
+ ```bash
124
+ python MotionDirector_inference.py --model /path/to/the/ZeroScope --prompt "A person is riding a bicycle past the Eiffel Tower." --checkpoint_folder ./outputs/train/riding_bicycle/ --checkpoint_index 300 --noise_prior 0. --seed 7192280
125
+ ```
126
+ Note:
127
+ - Replace `/path/to/the/ZeroScope` with your own path to the foundation model, i.e., ZeroScope.
128
+ - Change the `prompt` to generate different videos.
129
+ - The `seed` is set to a random value by default. Setting it to a specific value reproduces the corresponding result in the table below.
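To regenerate several prompt and seed pairs in one go, the command above can be assembled programmatically. The sketch below only builds and prints the command lines, using the flags shown in this README and the prompt and seed pairs listed in the results table of this section; the helper `build_inference_command` is a hypothetical name, and the model path is the same placeholder as above.

```python
import shlex


def build_inference_command(model_path, prompt, checkpoint_folder,
                            checkpoint_index, noise_prior, seed):
    """Assemble the argument list for MotionDirector_inference.py."""
    return [
        "python", "MotionDirector_inference.py",
        "--model", model_path,
        "--prompt", prompt,
        "--checkpoint_folder", checkpoint_folder,
        "--checkpoint_index", str(checkpoint_index),
        "--noise_prior", str(noise_prior),
        "--seed", str(seed),
    ]


# Prompt and seed pairs from the multi-video results table.
TABLE_RUNS = [
    ("A person is riding a bicycle past the Eiffel Tower.", 7192280),
    ("A panda is riding a bicycle in a garden.", 2178639),
    ("An alien is riding a bicycle on Mars.", 2390886),
]

commands = [
    build_inference_command(
        "/path/to/the/ZeroScope", prompt,
        "./outputs/train/riding_bicycle/", 300, 0.0, seed,
    )
    for prompt, seed in TABLE_RUNS
]

for cmd in commands:
    # shlex.join quotes the prompt so the line can be pasted into a shell.
    print(shlex.join(cmd))
```

Each list could also be passed to `subprocess.run` directly, which avoids shell quoting issues with prompts that contain punctuation.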
130
+
131
+ Results:
132
+
133
+ <table class="center">
134
+ <tr>
135
+ <td style="text-align:center;"><b>Reference Videos</b></td>
136
+ <td style="text-align:center;" colspan="3"><b>Videos Generated by MotionDirector</b></td>
137
+ </tr>
138
+ <tr>
139
+ <td><img src=assets/multi_videos_results/reference_videos.gif></td>
140
+ <td><img src=assets/multi_videos_results/A_person_is_riding_a_bicycle_past_the_Eiffel_Tower_7192280.gif></td>
141
+ <td><img src=assets/multi_videos_results/A_panda_is_riding_a_bicycle_in_a_garden_2178639.gif></td>
142
+ <td><img src=assets/multi_videos_results/An_alien_is_riding_a_bicycle_on_Mars_2390886.gif></td>
143
+ </tr>
144
+ <tr>
145
+ <td width=25% style="text-align:center;color:gray;">"A person is riding a bicycle."</td>
146
+ <td width=25% style="text-align:center;">"A person is riding a bicycle past the Eiffel Tower." </br> seed: 7192280</td>
147
+ <td width=25% style="text-align:center;">"A panda is riding a bicycle in a garden." </br> seed: 2178639</td>
148
+ <td width=25% style="text-align:center;">"An alien is riding a bicycle on Mars." </br> seed: 2390886</td>
+ </tr>
149
+ </table>
150
+
151
+ ### MotionDirector trained on a single video:
152
+ 16 frames:
153
+ ```bash
154
+ python MotionDirector_inference.py --model /path/to/the/ZeroScope --prompt "A tank is running on the moon." --checkpoint_folder ./outputs/train/car_16/ --checkpoint_index 150 --noise_prior 0.5 --seed 8551187
155
+ ```
156
+ <table class="center">
157
+ <tr>
158
+ <td style="text-align:center;"><b>Reference Video</b></td>
159
+ <td style="text-align:center;" colspan="3"><b>Videos Generated by MotionDirector</b></td>
160
+ </tr>
161
+ <tr>
162
+ <td><img src=assets/single_video_results/reference_video.gif></td>
163
+ <td><img src=assets/single_video_results/A_tank_is_running_on_the_moon_8551187.gif></td>
164
+ <td><img src=assets/single_video_results/A_lion_is_running_past_the_pyramids_431554.gif></td>
165
+ <td><img src=assets/single_video_results/A_spaceship_is_flying_past_Mars_8808231.gif></td>
166
+ </tr>
167
+ <tr>
168
+ <td width=25% style="text-align:center;color:gray;">"A car is running on the road."</td>
169
+ <td width=25% style="text-align:center;">"A tank is running on the moon." </br> seed: 8551187</td>
170
+ <td width=25% style="text-align:center;">"A lion is running past the pyramids." </br> seed: 431554</td>
171
+ <td width=25% style="text-align:center;">"A spaceship is flying past Mars." </br> seed: 8808231</td>
172
+ </tr>
173
+ </table>
174
+
175
+ 24 frames:
176
+ ```bash
177
+ python MotionDirector_inference.py --model /path/to/the/ZeroScope --prompt "A truck is running past the Arc de Triomphe." --checkpoint_folder ./outputs/train/car_24/ --checkpoint_index 150 --noise_prior 0.5 --width 576 --height 320 --num-frames 24 --seed 34543
178
+ ```
179
+ <table class="center">
180
+ <tr>
181
+ <td style="text-align:center;"><b>Reference Video</b></td>
182
+ <td style="text-align:center;" colspan="3"><b>Videos Generated by MotionDirector</b></td>
183
+ </tr>
184
+ <tr>
185
+ <td><img src=assets/single_video_results/24_frames/reference_video.gif></td>
186
+ <td><img src=assets/single_video_results/24_frames/A_truck_is_running_past_the_Arc_de_Triomphe_34543.gif></td>
187
+ <td><img src=assets/single_video_results/24_frames/An_elephant_is_running_in_a_forest_2171736.gif></td>
188
+ </tr>
189
+ <tr>
190
+ <td width=25% style="text-align:center;color:gray;">"A car is running on the road."</td>
191
+ <td width=25% style="text-align:center;">"A truck is running past the Arc de Triomphe." </br> seed: 34543</td>
192
+ <td width=25% style="text-align:center;">"An elephant is running in a forest." </br> seed: 2171736</td>
193
+ </tr>
194
+ <tr>
195
+ <td><img src=assets/single_video_results/24_frames/reference_video.gif></td>
196
+ <td><img src=assets/single_video_results/24_frames/A_person_on_a_camel_is_running_past_the_pyramids_4904126.gif></td>
197
+ <td><img src=assets/single_video_results/24_frames/A_spacecraft_is_flying_past_the_Milky_Way_galaxy_3235677.gif></td>
198
+ </tr>
199
+ <tr>
200
+ <td width=25% style="text-align:center;color:gray;">"A car is running on the road."</td>
201
+ <td width=25% style="text-align:center;">"A person on a camel is running past the pyramids." </br> seed: 4904126</td>
202
+ <td width=25% style="text-align:center;">"A spacecraft is flying past the Milky Way galaxy." </br> seed: 3235677</td>
203
+ </tr>
204
+ </table>
205
+
206
+ ## MotionDirector for Sports <a name="MotionDirector_for_Sports"></a>
207
+
208
+ ```bash
209
+ python MotionDirector_inference.py --model /path/to/the/ZeroScope --prompt "A panda is lifting weights in a garden." --checkpoint_folder ./outputs/train/lifting_weights/ --checkpoint_index 300 --noise_prior 0. --seed 1699276
210
+ ```
211
+ <table class="center">
212
+ <tr>
213
+ <td style="text-align:center;" colspan="4"><b>Videos Generated by MotionDirector</b></td>
214
+ </tr>
215
+ <tr>
216
+ <td style="text-align:center;" colspan="2"><b>Lifting Weights</b></td>
217
+ <td style="text-align:center;" colspan="2"><b>Riding Bicycle</b></td>
218
+ </tr>
219
+ <tr>
220
+ <td><img src=assets/sports_results/lifting_weights/A_panda_is_lifting_weights_in_a_garden_1699276.gif></td>
221
+ <td><img src=assets/sports_results/lifting_weights/A_police_officer_is_lifting_weights_in_front_of_the_police_station_6804745.gif></td>
222
+ <td><img src=assets/multi_videos_results/A_panda_is_riding_a_bicycle_in_a_garden_2178639.gif></td>
223
+ <td><img src=assets/multi_videos_results/An_alien_is_riding_a_bicycle_on_Mars_2390886.gif></td>
224
+ </tr>
225
+ <tr>
226
+ <td width=25% style="text-align:center;">"A panda is lifting weights in a garden." </br> seed: 1699276</td>
227
+ <td width=25% style="text-align:center;">"A police officer is lifting weights in front of the police station." </br> seed: 6804745</td>
228
+ <td width=25% style="text-align:center;">"A panda is riding a bicycle in a garden." </br> seed: 2178639</td>
229
+ <td width=25% style="text-align:center;">"An alien is riding a bicycle on Mars." </br> seed: 2390886</td>
230
+ </tr>
+ <tr>
231
+ <td style="text-align:center;" colspan="2"><b>Riding Horse</b></td>
232
+ <td style="text-align:center;" colspan="2"><b>Playing Golf</b></td>
233
+ </tr>
234
+ <tr>
235
+ <td><img src=assets/sports_results/riding_horse/A_Royal_Guard_riding_a_horse_in_front_of_Buckingham_Palace_4490970.gif></td>
236
+ <td><img src=assets/sports_results/riding_horse/A_man_riding_an_elephant_through_the_jungle_6230765.gif></td>
237
+ <td><img src=assets/sports_results/playing_golf/A_man_is_playing_golf_in_front_of_the_White_House_8870450.gif></td>
238
+ <td><img src=assets/sports_results/playing_golf/A_monkey_is_playing_golf_on_a_field_full_of_flowers_2989633.gif></td>
239
+ </tr>
240
+ <tr>
241
+ <td width=25% style="text-align:center;">"A Royal Guard riding a horse in front of Buckingham Palace." </br> seed: 4490970</td>
242
+ <td width=25% style="text-align:center;">"A man riding an elephant through the jungle." </br> seed: 6230765</td>
243
+ <td width=25% style="text-align:center;">"A man is playing golf in front of the White House." </br> seed: 8870450</td>
244
+ <td width=25% style="text-align:center;">"A monkey is playing golf on a field full of flowers." </br> seed: 2989633</td>
245
+ </tr>
246
+ </table>
247
+
248
+ More sports, to be continued ...
249
+
250
+ ## MotionDirector for Cinematic Shots <a name="MotionDirector_for_Cinematic_Shots"></a>
251
+
252
+ ### 1. Zoom
253
+ #### 1.1 Dolly Zoom (Hitchcockian Zoom)
254
+ ```bash
255
+ python MotionDirector_inference.py --model /path/to/the/ZeroScope --prompt "A firefighter standing in front of a burning forest captured with a dolly zoom." --checkpoint_folder ./outputs/train/dolly_zoom/ --checkpoint_index 150 --noise_prior 0.5 --seed 9365597
256
+ ```
257
+ <table class="center">
258
+ <tr>
259
+ <td style="text-align:center;"><b>Reference Video</b></td>
260
+ <td style="text-align:center;" colspan="3"><b>Videos Generated by MotionDirector</b></td>
261
+ </tr>
262
+ <tr>
263
+ <td><img src=assets/cinematic_shots_results/dolly_zoom_16.gif></td>
264
+ <td><img src=assets/cinematic_shots_results/A_firefighter_standing_in_front_of_a_burning_forest_captured_with_a_dolly_zoom_9365597.gif></td>
265
+ <td><img src=assets/cinematic_shots_results/A_lion_sitting_on_top_of_a_cliff_captured_with_a_dolly_zoom_1675932.gif></td>
266
+ <td><img src=assets/cinematic_shots_results/A_Roman_soldier_standing_in_front_of_the_Colosseum_captured_with_a_dolly_zoom_2310805.gif></td>
267
+ </tr>
268
+ <tr>
269
+ <td width=25% style="text-align:center;color:gray;">"A man standing in room captured with a dolly zoom."</td>
270
+ <td width=25% style="text-align:center;">"A firefighter standing in front of a burning forest captured with a dolly zoom." </br> seed: 9365597 </br> noise_prior: 0.5</td>
271
+ <td width=25% style="text-align:center;">"A lion sitting on top of a cliff captured with a dolly zoom." </br> seed: 1675932 </br> noise_prior: 0.5</td>
272
+ <td width=25% style="text-align:center;">"A Roman soldier standing in front of the Colosseum captured with a dolly zoom." </br> seed: 2310805 </br> noise_prior: 0.5 </td>
273
+ </tr>
274
+ <tr>
275
+ <td><img src=assets/cinematic_shots_results/dolly_zoom_16.gif></td>
276
+ <td><img src=assets/cinematic_shots_results/A_firefighter_standing_in_front_of_a_burning_forest_captured_with_a_dolly_zoom_4615820.gif></td>
277
+ <td><img src=assets/cinematic_shots_results/A_lion_sitting_on_top_of_a_cliff_captured_with_a_dolly_zoom_4114896.gif></td>
278
+ <td><img src=assets/cinematic_shots_results/A_Roman_soldier_standing_in_front_of_the_Colosseum_captured_with_a_dolly_zoom_7492004.gif></td>
279
+ </tr>
280
+ <tr>
281
+ <td width=25% style="text-align:center;color:gray;">"A man standing in room captured with a dolly zoom."</td>
282
+ <td width=25% style="text-align:center;">"A firefighter standing in front of a burning forest captured with a dolly zoom." </br> seed: 4615820 </br> noise_prior: 0.3</td>
283
+ <td width=25% style="text-align:center;">"A lion sitting on top of a cliff captured with a dolly zoom." </br> seed: 4114896 </br> noise_prior: 0.3</td>
284
+ <td width=25% style="text-align:center;">"A Roman soldier standing in front of the Colosseum captured with a dolly zoom." </br> seed: 7492004</td>
285
+ </tr>
286
+ </table>
287
+
288
+ #### 1.2 Zoom In
289
+ The reference video is shot with my own water cup. You can also pick up your cup or any other object to practice camera movements and turn it into imaginative videos. Create your AI films with customized camera movements!
290
+
291
+ ```bash
292
+ python MotionDirector_inference.py --model /path/to/the/ZeroScope --prompt "A firefighter standing in front of a burning forest captured with a zoom in." --checkpoint_folder ./outputs/train/zoom_in/ --checkpoint_index 150 --noise_prior 0.3 --seed 1429227
293
+ ```
294
+ <table class="center">
295
+ <tr>
296
+ <td style="text-align:center;"><b>Reference Video</b></td>
297
+ <td style="text-align:center;" colspan="3"><b>Videos Generated by MotionDirector</b></td>
298
+ </tr>
299
+ <tr>
300
+ <td><img src=assets/cinematic_shots_results/zoom_in_16.gif></td>
301
+ <td><img src=assets/cinematic_shots_results/A_firefighter_standing_in_front_of_a_burning_forest_captured_with_a_zoom_in_1429227.gif></td>
302
+ <td><img src=assets/cinematic_shots_results/A_lion_sitting_on_top_of_a_cliff_captured_with_a_zoom_in_487239.gif></td>
303
+ <td><img src=assets/cinematic_shots_results/A_Roman_soldier_standing_in_front_of_the_Colosseum_captured_with_a_zoom_in_1393184.gif></td>
304
+ </tr>
305
+ <tr>
306
+ <td width=25% style="text-align:center;color:gray;">"A cup in a lab captured with a zoom in."</td>
307
+ <td width=25% style="text-align:center;">"A firefighter standing in front of a burning forest captured with a zoom in." </br> seed: 1429227</td>
308
+ <td width=25% style="text-align:center;">"A lion sitting on top of a cliff captured with a zoom in." </br> seed: 487239 </td>
309
+ <td width=25% style="text-align:center;">"A Roman soldier standing in front of the Colosseum captured with a zoom in." </br> seed: 1393184</td>
310
+ </tr>
311
+ </table>
312
+
313
+ #### 1.3 Zoom Out
314
+ ```bash
315
+ python MotionDirector_inference.py --model /path/to/the/ZeroScope --prompt "A firefighter standing in front of a burning forest captured with a zoom out." --checkpoint_folder ./outputs/train/zoom_out/ --checkpoint_index 150 --noise_prior 0.3 --seed 4971910
316
+ ```
317
+ <table class="center">
318
+ <tr>
319
+ <td style="text-align:center;"><b>Reference Video</b></td>
320
+ <td style="text-align:center;" colspan="3"><b>Videos Generated by MotionDirector</b></td>
321
+ </tr>
322
+ <tr>
323
+ <td><img src=assets/cinematic_shots_results/zoom_out_16.gif></td>
324
+ <td><img src=assets/cinematic_shots_results/A_firefighter_standing_in_front_of_a_burning_forest_captured_with_a_zoom_out_4971910.gif></td>
325
+ <td><img src=assets/cinematic_shots_results/A_lion_sitting_on_top_of_a_cliff_captured_with_a_zoom_out_1767994.gif></td>
326
+ <td><img src=assets/cinematic_shots_results/A_Roman_soldier_standing_in_front_of_the_Colosseum_captured_with_a_zoom_out_8203639.gif></td>
327
+ </tr>
328
+ <tr>
329
+ <td width=25% style="text-align:center;color:gray;">"A cup in a lab captured with a zoom out."</td>
330
+ <td width=25% style="text-align:center;">"A firefighter standing in front of a burning forest captured with a zoom out." <br> seed: 4971910</td>
331
+ <td width=25% style="text-align:center;">"A lion sitting on top of a cliff captured with a zoom out." <br> seed: 1767994</td>
332
+ <td width=25% style="text-align:center;">"A Roman soldier standing in front of the Colosseum captured with a zoom out." <br> seed: 8203639</td>
333
+ </tr>
334
+ </table>
335
+
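The zoom-in and zoom-out commands above differ only in the prompt, the checkpoint folder, and the seed. As a minimal sketch, the invocation can be assembled in Python. The helper name `build_inference_command` and the `SHOT_FOLDERS` table are illustrative assumptions and are not part of the repository; the flags and values are copied from the commands in this section.

```python
import shlex

# Checkpoint folders exactly as they appear in the README commands above.
SHOT_FOLDERS = {
    "zoom in": "./outputs/train/zoom_in/",
    "zoom out": "./outputs/train/zoom_out/",
}


def build_inference_command(shot, subject, seed, model="/path/to/the/ZeroScope"):
    """Return the argv list for MotionDirector_inference.py (hypothetical helper)."""
    prompt = f"{subject} captured with a {shot}."
    return [
        "python", "MotionDirector_inference.py",
        "--model", model,
        "--prompt", prompt,
        "--checkpoint_folder", SHOT_FOLDERS[shot],
        "--checkpoint_index", "150",
        "--noise_prior", "0.3",
        "--seed", str(seed),
    ]


cmd = build_inference_command(
    "zoom out", "A firefighter standing in front of a burning forest", 4971910
)
# shlex.join quotes the prompt so the line can be pasted into a shell.
print(shlex.join(cmd))
```

Passing the list to `subprocess.run(cmd, check=True)` would launch the script without going through a shell, which avoids quoting problems in prompts.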
336
+ More cinematic shots to be continued...
337
+
338
+ ## More results
339
+
340
+ If you have trained a more impressive MotionDirector or generated better videos, please feel free to open an issue and share them with us. We would greatly appreciate it.
341
+ Improvements to the code are also highly welcome.
342
+
343
+ Please refer to [Project Page](https://showlab.github.io/MotionDirector) for more results.
344
+
345
+
346
+ ## Citation
347
+
348
+
349
+ ```bibtex
350
+
351
+ @article{zhao2023motiondirector,
352
+ title={MotionDirector: Motion Customization of Text-to-Video Diffusion Models},
353
+ author={Zhao, Rui and Gu, Yuchao and Wu, Jay Zhangjie and Zhang, David Junhao and Liu, Jiawei and Wu, Weijia and Keppo, Jussi and Shou, Mike Zheng},
354
+ journal={arXiv preprint arXiv:2310.08465},
355
+ year={2023}
356
+ }
357
+
358
+ ```
359
+
360
+ ## Shoutouts
361
+
362
+ - This code builds on [diffusers](https://github.com/huggingface/diffusers) and [Text-To-Video-Finetuning](https://github.com/ExponentialML/Text-To-Video-Finetuning). Thanks for open-sourcing!
363
+ - Thanks to [camenduru](https://twitter.com/camenduru) for the [colab demo](https://github.com/camenduru/MotionDirector-colab).
364
+ - Thanks to [yhyu13](https://github.com/yhyu13) for the [Huggingface Repo](https://huggingface.co/Yhyu13/MotionDirector_LoRA).
app.py ADDED
@@ -0,0 +1,100 @@
1
+ import gradio as gr
2
+ import os
3
+ from demo.motiondirector import MotionDirector
4
+
5
+ from huggingface_hub import snapshot_download
6
+
7
+ snapshot_download(repo_id="cerspense/zeroscope_v2_576w", local_dir="./zeroscope_v2_576w/")
8
+ snapshot_download(repo_id="ruizhaocv/MotionDirector", local_dir="./MotionDirector_pretrained")
9
+
10
+ is_spaces = "SPACE_ID" in os.environ
11
+ true_for_shared_ui = False # This will be true only if you are in a shared UI
12
+ if (is_spaces):
13
+ true_for_shared_ui = "ruizhaocv/MotionDirector" in os.environ['SPACE_ID']
14
+
15
+
16
+
17
+ runner = MotionDirector()
18
+
19
+
20
+ def motiondirector(model_select, text_pormpt, neg_text_pormpt, random_seed=1, steps=25, guidance_scale=7.5, baseline_select=False):
21
+ return runner(model_select, text_pormpt, neg_text_pormpt, int(random_seed) if random_seed != "" else 1, int(steps), float(guidance_scale), baseline_select)
22
+
23
+
24
+ with gr.Blocks() as demo:
25
+ gr.HTML(
26
+ """
27
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
28
+ <a href="https://github.com/showlab/MotionDirector" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
29
+ </a>
30
+ <div>
31
+ <h1 >MotionDirector: Motion Customization of Text-to-Video Diffusion Models</h1>
32
+ <h5 style="margin: 0;">More MotionDirectors are on the way. Stay tuned 🔥! Give us a star ✨ on Github for the latest update.</h5>
33
+ </br>
34
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
35
+ <a href="https://arxiv.org/abs/2310.08465"><img src="https://img.shields.io/badge/arXiv-MotionDirector-b31b1b.svg"></a>&nbsp;&nbsp;
36
+ <a href='https://showlab.github.io/MotionDirector'><img src='https://img.shields.io/badge/Project_Page-MotionDirector-green'></a>&nbsp;&nbsp;
37
+ <a href='https://github.com/showlab/MotionDirector'><img src='https://img.shields.io/badge/Github-MotionDirector-blue'></a>&nbsp;&nbsp;
38
+ </div>
39
+ </div>
40
+ </div>
41
+ """)
42
+ with gr.Row():
43
+ generated_video_baseline = gr.Video(format="mp4", label="Video Generated by base model (ZeroScope with same seed)")
44
+ generated_video = gr.Video(format="mp4", label="Video Generated by MotionDirector")
45
+
46
+ with gr.Column():
47
+ baseline_select = gr.Checkbox(label="Compare with baseline (ZeroScope with same seed)", info="Run baseline? Note: Inference time will be doubled.")
48
+ random_seed = gr.Textbox(label="Random seed", value=1, info="default: 1")
49
+ sampling_steps = gr.Textbox(label="Sampling steps", value=30, info="default: 30")
50
+ guidance_scale = gr.Textbox(label="Guidance scale", value=12, info="default: 12")
51
+
52
+ with gr.Row():
53
+ model_select = gr.Dropdown(
54
+ ["1-1: [Cinematic Shots] -- Dolly Zoom (Hitchcockian Zoom)",
55
+ "1-2: [Cinematic Shots] -- Zoom In",
56
+ "1-3: [Cinematic Shots] -- Zoom Out",
57
+ "2-1: [Object Trajectory] -- Right to Left",
58
+ "2-2: [Object Trajectory] -- Left to Right",
59
+ "3-1: [Sports Concepts] -- Riding Bicycle",
60
+ "3-2: [Sports Concepts] -- Riding Horse",
61
+ "3-3: [Sports Concepts] -- Lifting Weights",
62
+ "3-4: [Sports Concepts] -- Playing Golf"
63
+ ],
64
+ label="MotionDirector",
65
+ info="Which MotionDirector would you like to use?"
66
+ )
67
+
68
+ text_pormpt = gr.Textbox(label="Text Prompt", value='', placeholder="Input your text prompt here!")
69
+ neg_text_pormpt = gr.Textbox(label="Negative Text Prompt", value='', placeholder="default: None")
70
+
71
+ submit = gr.Button("Generate")
72
+
73
+ # when the `submit` button is clicked
74
+ submit.click(
75
+ motiondirector,
76
+ [model_select, text_pormpt, neg_text_pormpt, random_seed, sampling_steps, guidance_scale, baseline_select],
77
+ [generated_video, generated_video_baseline]
78
+ )
79
+
80
+ # Examples
81
+ gr.Markdown("## Examples")
82
+ gr.Examples(
83
+ fn=motiondirector,
84
+ examples=[
85
+ ["1-1: [Cinematic Shots] -- Dolly Zoom (Hitchcockian Zoom)", "A lion sitting on top of a cliff captured with a dolly zoom.", 1675932],
86
+ ["1-2: [Cinematic Shots] -- Zoom In", "A firefighter standing in front of a burning forest captured with a zoom in.", 1429227],
87
+ ["1-3: [Cinematic Shots] -- Zoom Out", "A lion sitting on top of a cliff captured with a zoom out.", 1767994],
88
+ ["2-1: [Object Trajectory] -- Right to Left", "A tank is running on the moon.", 8551187],
89
+ ["2-2: [Object Trajectory] -- Left to Right", "A tiger is running in the forest.", 3463673],
90
+ ["3-1: [Sports Concepts] -- Riding Bicycle", "An astronaut is riding a bicycle past the pyramids Mars 4K high quailty highly detailed.", 4422954],
91
+ ["3-2: [Sports Concepts] -- Riding Horse", "A man riding an elephant through the jungle.", 6230765],
92
+ ["3-3: [Sports Concepts] -- Lifting Weights", "A panda is lifting weights in a garden.", 1699276],
93
+ ["3-4: [Sports Concepts] -- Playing Golf", "A man is playing golf in front of the White House.", 8870450],
94
+ ],
95
+ inputs=[model_select, text_pormpt, random_seed],
96
+ outputs=generated_video,
97
+ )
98
+
99
+ demo.queue(max_size=15)
100
+ demo.launch(share=True)
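The `motiondirector` wrapper in this file receives the seed, step count, and guidance scale from `gr.Textbox` widgets and converts them before calling the runner. A slightly more defensive variant might look like the sketch below. `coerce_inputs` is a hypothetical helper, and treating a whitespace-only seed as empty is an assumption that goes beyond what the original code does.

```python
def coerce_inputs(random_seed, steps, guidance_scale):
    """Convert raw textbox values to (int seed, int steps, float scale)."""
    seed_text = str(random_seed).strip()
    # Mirror the original fallback: an empty seed field becomes 1.
    seed = int(seed_text) if seed_text != "" else 1
    return seed, int(steps), float(guidance_scale)


print(coerce_inputs("", "30", "12"))
print(coerce_inputs("1429227", 30, 7.5))
```

Note that the runner in `demo/motiondirector.py` replaces any seed of 1000 or less with a random one, so the fallback value 1 effectively means "pick a random seed".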
demo/MotionDirector_gradio.py ADDED
@@ -0,0 +1,92 @@
1
+ import argparse
2
+ import imageio
3
+ import numpy as np
4
+ import gradio as gr
5
+ from PIL import Image
6
+
7
+ from demo.motiondirector import MotionDirector
8
+
9
+ runner = MotionDirector()
10
+
11
+
12
+ def motiondirector(model_select, text_pormpt, neg_text_pormpt, random_seed=1, steps=25, guidance_scale=7.5, baseline_select=False):
13
+ return runner(model_select, text_pormpt, neg_text_pormpt, int(random_seed) if random_seed != "" else 1, int(steps), float(guidance_scale), baseline_select)
14
+
15
+
16
+ with gr.Blocks() as demo:
17
+ gr.HTML(
18
+ """
19
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
20
+ <a href="https://github.com/showlab/MotionDirector" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
21
+ </a>
22
+ <div>
23
+ <h1 >MotionDirector: Motion Customization of Text-to-Video Diffusion Models</h1>
24
+ <h5 style="margin: 0;">More MotionDirectors are on the way. Stay tuned 🔥! Give us a star ✨ on Github for the latest update.</h5>
25
+ </br>
26
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
27
+ <a href="https://arxiv.org/abs/2310.08465"><img src="https://img.shields.io/badge/arXiv-MotionDirector-b31b1b.svg"></a>&nbsp;&nbsp;
28
+ <a href='https://showlab.github.io/MotionDirector'><img src='https://img.shields.io/badge/Project_Page-MotionDirector-green'></a>&nbsp;&nbsp;
29
+ <a href='https://github.com/showlab/MotionDirector'><img src='https://img.shields.io/badge/Github-MotionDirector-blue'></a>&nbsp;&nbsp;
30
+ </div>
31
+ </div>
32
+ </div>
33
+ """)
34
+ with gr.Row():
35
+ generated_video_baseline = gr.Video(format="mp4", label="Video Generated by base model (ZeroScope with same seed)")
36
+ generated_video = gr.Video(format="mp4", label="Video Generated by MotionDirector")
37
+
38
+ with gr.Column():
39
+ baseline_select = gr.Checkbox(label="Compare with baseline (ZeroScope with same seed)", info="Run baseline? Note: Inference time will be doubled.")
40
+ random_seed = gr.Textbox(label="Random seed", value=1, info="default: 1")
41
+ sampling_steps = gr.Textbox(label="Sampling steps", value=30, info="default: 30")
42
+ guidance_scale = gr.Textbox(label="Guidance scale", value=12, info="default: 12")
43
+
44
+ with gr.Row():
45
+ model_select = gr.Dropdown(
46
+ ["1-1: [Cinematic Shots] -- Dolly Zoom (Hitchcockian Zoom)",
47
+ "1-2: [Cinematic Shots] -- Zoom In",
48
+ "1-3: [Cinematic Shots] -- Zoom Out",
49
+ "2-1: [Object Trajectory] -- Right to Left",
50
+ "2-2: [Object Trajectory] -- Left to Right",
51
+ "3-1: [Sports Concepts] -- Riding Bicycle",
52
+ "3-2: [Sports Concepts] -- Riding Horse",
53
+ "3-3: [Sports Concepts] -- Lifting Weights",
54
+ "3-4: [Sports Concepts] -- Playing Golf"
55
+ ],
56
+ label="MotionDirector",
57
+ info="Which MotionDirector would you like to use?"
58
+ )
59
+
60
+ text_pormpt = gr.Textbox(label="Text Prompt", value='', placeholder="Input your text prompt here!")
61
+ neg_text_pormpt = gr.Textbox(label="Negative Text Prompt", value='', placeholder="default: None")
62
+
63
+ submit = gr.Button("Generate")
64
+
65
+ # when the `submit` button is clicked
66
+ submit.click(
67
+ motiondirector,
68
+ [model_select, text_pormpt, neg_text_pormpt, random_seed, sampling_steps, guidance_scale, baseline_select],
69
+ [generated_video, generated_video_baseline]
70
+ )
71
+
72
+ # Examples
73
+ gr.Markdown("## Examples")
74
+ gr.Examples(
75
+ fn=motiondirector,
76
+ examples=[
77
+ ["1-1: [Cinematic Shots] -- Dolly Zoom (Hitchcockian Zoom)", "A lion sitting on top of a cliff captured with a dolly zoom.", 1675932],
78
+ ["1-2: [Cinematic Shots] -- Zoom In", "A firefighter standing in front of a burning forest captured with a zoom in.", 1429227],
79
+ ["1-3: [Cinematic Shots] -- Zoom Out", "A lion sitting on top of a cliff captured with a zoom out.", 1767994],
80
+ ["2-1: [Object Trajectory] -- Right to Left", "A tank is running on the moon.", 8551187],
81
+ ["2-2: [Object Trajectory] -- Left to Right", "A tiger is running in the forest.", 3463673],
82
+ ["3-1: [Sports Concepts] -- Riding Bicycle", "An astronaut is riding a bicycle past the pyramids Mars 4K high quailty highly detailed.", 4422954],
83
+ ["3-2: [Sports Concepts] -- Riding Horse", "A man riding an elephant through the jungle.", 6230765],
84
+ ["3-3: [Sports Concepts] -- Lifting Weights", "A panda is lifting weights in a garden.", 1699276],
85
+ ["3-4: [Sports Concepts] -- Playing Golf", "A man is playing golf in front of the White House.", 8870450],
86
+ ],
87
+ inputs=[model_select, text_pormpt, random_seed],
88
+ outputs=generated_video,
89
+ )
90
+
91
+ demo.queue(max_size=15)
92
+ demo.launch(share=True)
demo/motiondirector.py ADDED
@@ -0,0 +1,218 @@
1
+ import os
2
+ import warnings
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from diffusers import DDIMScheduler, TextToVideoSDPipeline
7
+ from einops import rearrange
8
+ from torch import Tensor
9
+ from torch.nn.functional import interpolate
10
+ from tqdm import trange
11
+ import random
12
+
13
+ from MotionDirector_train import export_to_video, handle_memory_attention, load_primary_models, unet_and_text_g_c, freeze_models
14
+ from utils.lora_handler import LoraHandler
15
+ from utils.ddim_utils import ddim_inversion
16
+ from utils.lora import extract_lora_child_module
17
+ import imageio
18
+
19
+
20
+ def initialize_pipeline(
21
+ model: str,
22
+ device: str = "cuda",
23
+ xformers: bool = True,
24
+ sdp: bool = True,
25
+ lora_path: str = "",
26
+ lora_rank: int = 32,
27
+ lora_scale: float = 1.0,
28
+ ):
29
+ with warnings.catch_warnings():
30
+ warnings.simplefilter("ignore")
31
+
32
+ scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(model)
33
+
34
+ # Freeze any necessary models
35
+ freeze_models([vae, text_encoder, unet])
36
+
37
+ # Enable xformers if available
38
+ handle_memory_attention(xformers, sdp, unet)
39
+
40
+ lora_manager_temporal = LoraHandler(
41
+ version="cloneofsimo",
42
+ use_unet_lora=True,
43
+ use_text_lora=False,
44
+ save_for_webui=False,
45
+ only_for_webui=False,
46
+ unet_replace_modules=["TransformerTemporalModel"],
47
+ text_encoder_replace_modules=None,
48
+ lora_bias=None
49
+ )
50
+
51
+ unet_lora_params, unet_negation = lora_manager_temporal.add_lora_to_model(
52
+ True, unet, lora_manager_temporal.unet_replace_modules, 0, lora_path, r=lora_rank, scale=lora_scale)
53
+
54
+ unet.eval()
55
+ text_encoder.eval()
56
+ unet_and_text_g_c(unet, text_encoder, False, False)
57
+
58
+ pipe = TextToVideoSDPipeline.from_pretrained(
59
+ pretrained_model_name_or_path=model,
60
+ scheduler=scheduler,
61
+ tokenizer=tokenizer,
62
+ text_encoder=text_encoder.to(device=device, dtype=torch.half),
63
+ vae=vae.to(device=device, dtype=torch.half),
64
+ unet=unet.to(device=device, dtype=torch.half),
65
+ )
66
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
67
+
68
+ return pipe
69
+
70
+
71
+ def inverse_video(pipe, latents, num_steps):
72
+ ddim_inv_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
73
+ ddim_inv_scheduler.set_timesteps(num_steps)
74
+
75
+ ddim_inv_latent = ddim_inversion(
76
+ pipe, ddim_inv_scheduler, video_latent=latents.to(pipe.device),
77
+ num_inv_steps=num_steps, prompt="")[-1]
78
+ return ddim_inv_latent
79
+
80
+
81
+ def prepare_input_latents(
+     pipe: TextToVideoSDPipeline,
+     batch_size: int,
+     num_frames: int,
+     height: int,
+     width: int,
+     latents_path: str,
+     noise_prior: float
+ ):
+     # initialize with random gaussian noise
+     scale = pipe.vae_scale_factor
+     shape = (batch_size, pipe.unet.config.in_channels, num_frames, height // scale, width // scale)
+     if noise_prior > 0.:
+         cached_latents = torch.load(latents_path)
+         if 'inversion_noise' not in cached_latents:
+             # no cached inversion noise: run DDIM inversion on the stored video latents
+             latents = inverse_video(pipe, cached_latents['latents'].unsqueeze(0), 50).squeeze(0)
+         else:
+             # reuse the file that is already loaded instead of reading it a second time
+             latents = cached_latents['inversion_noise'].unsqueeze(0)
+         if latents.shape[0] != batch_size:
+             latents = latents.repeat(batch_size, 1, 1, 1, 1)
+         if latents.shape != shape:
+             latents = interpolate(rearrange(latents, "b c f h w -> (b f) c h w", b=batch_size), (height // scale, width // scale), mode='bilinear')
+             latents = rearrange(latents, "(b f) c h w -> b c f h w", b=batch_size)
+         noise = torch.randn_like(latents, dtype=torch.half)
+         latents_base = noise
+         # blend the inversion latents with fresh noise using square-root weights
+         latents = (noise_prior) ** 0.5 * latents + (1 - noise_prior) ** 0.5 * noise
+     else:
+         latents = torch.randn(shape, dtype=torch.half)
+         latents_base = latents
+
+     return latents, latents_base
112
+
113
+
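`prepare_input_latents` blends the cached inversion latents with fresh Gaussian noise using the weights `noise_prior ** 0.5` and `(1 - noise_prior) ** 0.5`. If the two inputs are independent and each has unit variance, the blend also has unit variance, because `noise_prior + (1 - noise_prior) = 1`. The sketch below checks this numerically with the standard library rather than torch. The function name `blend` and the sampled data are illustrative assumptions, and real inversion latents are not guaranteed to be exactly unit-variance.

```python
import math
import random
import statistics


def blend(prior_sample, noise_sample, noise_prior):
    """Square-root-weighted mix, following the formula in prepare_input_latents."""
    return math.sqrt(noise_prior) * prior_sample + math.sqrt(1 - noise_prior) * noise_sample


rng = random.Random(0)
n = 50_000
prior = [rng.gauss(0.0, 1.0) for _ in range(n)]  # stand-in for inversion latents
noise = [rng.gauss(0.0, 1.0) for _ in range(n)]  # stand-in for torch.randn_like

# The three noise_prior values used by the demo: 0.0, 0.3 and 0.5.
for noise_prior in (0.0, 0.3, 0.5):
    mixed = [blend(p, z, noise_prior) for p, z in zip(prior, noise)]
    print(noise_prior, round(statistics.pvariance(mixed), 2))
```

With `noise_prior = 0` the result is pure noise, which matches the `else` branch of the function that skips the cached latents entirely.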
114
+ class MotionDirector():
115
+ def __init__(self):
116
+ self.version = "0.0.0"
117
+ self.foundation_model_path = "./zeroscope_v2_576w/"
118
+ self.lora_path = "./MotionDirector_pretrained/dolly_zoom_(hitchcockian_zoom)/checkpoint-default/temporal/lora"
119
+ with torch.autocast("cuda", dtype=torch.half):
120
+ self.pipe = initialize_pipeline(model=self.foundation_model_path, lora_path=self.lora_path, lora_scale=1)
121
+
122
+ def reload_lora(self, lora_path):
123
+ if lora_path != self.lora_path:
124
+ self.lora_path = lora_path
125
+ with torch.autocast("cuda", dtype=torch.half):
126
+ self.pipe = initialize_pipeline(model=self.foundation_model_path, lora_path=self.lora_path)
127
+
128
+ def __call__(self, model_select, text_pormpt, neg_text_pormpt, random_seed, steps, guidance_scale, baseline_select):
129
+ model_select = str(model_select)
130
+ out_name = "./outputs/inference/"
131
+ out_name += f"{text_pormpt}".replace(' ', '_').replace(',', '').replace('.', '')
132
+
133
+ model_select_type = model_select.split('--')[1].strip()
134
+ model_select_type = model_select_type.lower().replace(' ', '_')
135
+
136
+ lora_path = f"./MotionDirector_pretrained/{model_select_type}/checkpoint-default/temporal/lora"
137
+ self.reload_lora(lora_path)
138
+ latents_folder = f"./MotionDirector_pretrained/{model_select_type}/cached_latents"
139
+ latents_path = f"{latents_folder}/{random.choice(os.listdir(latents_folder))}"
140
+ assert os.path.exists(lora_path)
141
+
142
+ if '3-' in model_select:
143
+ noise_prior = 0.
144
+ elif '2-' in model_select:
145
+ noise_prior = 0.5
146
+ else:
147
+ noise_prior = 0.3
148
+
149
+ if random_seed > 1000:
150
+ torch.manual_seed(random_seed)
151
+ else:
152
+ random_seed = random.randint(100, 10000000)
153
+ torch.manual_seed(random_seed)
154
+ device = "cuda"
155
+ with torch.autocast(device, dtype=torch.half):
156
+ # prepare input latents
157
+ with torch.no_grad():
158
+ init_latents,init_latents_base = prepare_input_latents(
159
+ pipe=self.pipe,
160
+ batch_size=1,
161
+ num_frames=16,
162
+ height=384,
163
+ width=384,
164
+ latents_path=latents_path,
165
+ noise_prior=noise_prior
166
+ )
167
+ video_frames = self.pipe(
168
+ prompt=text_pormpt,
169
+ negative_prompt=neg_text_pormpt,
170
+ width=384,
171
+ height=384,
172
+ num_frames=16,
173
+ num_inference_steps=steps,
174
+ guidance_scale=guidance_scale,
175
+ latents=init_latents
176
+ ).frames
177
+
178
+
179
+ out_file = f"{out_name}_{random_seed}.mp4"
180
+ os.makedirs(os.path.dirname(out_file), exist_ok=True)
181
+ export_to_video(video_frames, out_file, 8)
182
+
183
+ if baseline_select:
184
+ with torch.autocast("cuda", dtype=torch.half):
185
+
186
+ loras = extract_lora_child_module(self.pipe.unet, target_replace_module=["TransformerTemporalModel"])
187
+ for lora_i in loras:
188
+ lora_i.scale = 0.
189
+
190
+ # self.pipe = initialize_pipeline(model=self.foundation_model_path, lora_path=self.lora_path,
191
+ # lora_scale=0.)
192
+ with torch.no_grad():
193
+ video_frames = self.pipe(
194
+ prompt=text_pormpt,
195
+ negative_prompt=neg_text_pormpt,
196
+ width=384,
197
+ height=384,
198
+ num_frames=16,
199
+ num_inference_steps=steps,
200
+ guidance_scale=guidance_scale,
201
+ latents=init_latents_base,
202
+ ).frames
203
+
204
+ out_file_baseline = f"{out_name}_{random_seed}_baseline.mp4"
205
+ os.makedirs(os.path.dirname(out_file_baseline), exist_ok=True)
206
+ export_to_video(video_frames, out_file_baseline, 8)
207
+ # with torch.autocast("cuda", dtype=torch.half):
208
+ # self.pipe = initialize_pipeline(model=self.foundation_model_path, lora_path=self.lora_path,
209
+ # lora_scale=1.)
210
+ loras = extract_lora_child_module(self.pipe.unet,
211
+ target_replace_module=["TransformerTemporalModel"])
212
+ for lora_i in loras:
213
+ lora_i.scale = 1.
214
+
215
+ else:
216
+ out_file_baseline = None
217
+
218
+ return [out_file, out_file_baseline]
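`MotionDirector.__call__` derives the checkpoint folder from the dropdown label by taking the text after `--`, lower-casing it, and replacing spaces with underscores. It also picks `noise_prior` from the label's group number. The sketch below isolates those two steps. The function names are hypothetical, and `pick_noise_prior` uses `startswith` where the original uses a substring test (`'3-' in model_select`), which is an assumption about the intended behavior.

```python
def lora_folder_name(model_select):
    """Derive the checkpoint folder name from a dropdown label."""
    name = str(model_select).split("--")[1].strip()
    return name.lower().replace(" ", "_")


def pick_noise_prior(model_select):
    """Same values as __call__, keyed on the leading group number."""
    if model_select.startswith("3-"):
        return 0.0  # sports concepts: start from pure noise
    if model_select.startswith("2-"):
        return 0.5  # object trajectories
    return 0.3      # cinematic shots


label = "1-1: [Cinematic Shots] -- Dolly Zoom (Hitchcockian Zoom)"
print(lora_folder_name(label))
print(pick_noise_prior(label))
```

The folder name produced for the dolly-zoom label matches the default `lora_path` set in `MotionDirector.__init__`.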
models/unet_3d_blocks.py ADDED
@@ -0,0 +1,842 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.utils.checkpoint as checkpoint
17
+ from torch import nn
18
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
19
+ from diffusers.models.transformer_2d import Transformer2DModel
20
+ from diffusers.models.transformer_temporal import TransformerTemporalModel
21
+
22
+ # Assign gradient checkpoint function to simple variable for readability.
23
+ g_c = checkpoint.checkpoint
24
+
25
+ def use_temporal(module, num_frames, x):
26
+ if num_frames == 1:
27
+ if isinstance(module, TransformerTemporalModel):
28
+ return {"sample": x}
29
+ else:
30
+ return x
31
+
32
+ def custom_checkpoint(module, mode=None):
33
+ if mode is None: raise ValueError('Mode for gradient checkpointing cannot be none.')
34
+ custom_forward = None
35
+
36
+ if mode == 'resnet':
37
+ def custom_forward(hidden_states, temb):
38
+ inputs = module(hidden_states, temb)
39
+ return inputs
40
+
41
+ if mode == 'attn':
42
+ def custom_forward(
43
+ hidden_states,
44
+ encoder_hidden_states=None,
45
+ cross_attention_kwargs=None
46
+ ):
47
+ inputs = module(
48
+ hidden_states,
49
+ encoder_hidden_states,
50
+ cross_attention_kwargs
51
+ )
52
+ return inputs
53
+
54
+ if mode == 'temp':
55
+ def custom_forward(hidden_states, num_frames=None):
56
+ inputs = use_temporal(module, num_frames, hidden_states)
57
+ if inputs is None: inputs = module(
58
+ hidden_states,
59
+ num_frames=num_frames
60
+ )
61
+ return inputs
62
+
63
+ return custom_forward
64
+
65
+ def transformer_g_c(transformer, sample, num_frames):
66
+ sample = g_c(custom_checkpoint(transformer, mode='temp'),
67
+ sample, num_frames, use_reentrant=False
68
+ )['sample']
69
+
70
+ return sample
71
+
72
+ def cross_attn_g_c(
73
+ attn,
74
+ temp_attn,
75
+ resnet,
76
+ temp_conv,
77
+ hidden_states,
78
+ encoder_hidden_states,
79
+ cross_attention_kwargs,
80
+ temb,
81
+ num_frames,
82
+ inverse_temp=False
83
+ ):
84
+
85
+ def ordered_g_c(idx):
86
+
87
+ # Self and CrossAttention
88
+ if idx == 0: return g_c(custom_checkpoint(attn, mode='attn'),
89
+ hidden_states, encoder_hidden_states,cross_attention_kwargs, use_reentrant=False
90
+ )['sample']
91
+
92
+ # Temporal Self and CrossAttention
93
+ if idx == 1: return g_c(custom_checkpoint(temp_attn, mode='temp'),
94
+ hidden_states, num_frames, use_reentrant=False)['sample']
95
+
96
+ # Resnets
97
+ if idx == 2: return g_c(custom_checkpoint(resnet, mode='resnet'),
98
+ hidden_states, temb, use_reentrant=False)
99
+
100
+ # Temporal Convolutions
101
+ if idx == 3: return g_c(custom_checkpoint(temp_conv, mode='temp'),
102
+ hidden_states, num_frames, use_reentrant=False
103
+ )
104
+
105
+ # Here we call the function depending on the order in which they are called.
106
+ # For some layers, the orders are different, so we access the appropriate one by index.
107
+
108
+ if not inverse_temp:
109
+ for idx in [0,1,2,3]: hidden_states = ordered_g_c(idx)
110
+ else:
111
+ for idx in [2,3,0,1]: hidden_states = ordered_g_c(idx)
112
+
113
+ return hidden_states
114
+
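`cross_attn_g_c` runs the same four checkpointed stages in one of two orders, selected by `inverse_temp`: attention, temporal attention, resnet, temporal convolution (`[0, 1, 2, 3]`), or resnet and temporal convolution first (`[2, 3, 0, 1]`). The toy sketch below reproduces only that index-based dispatch with plain Python callables. The stage functions are placeholders that record their names rather than real network layers, and `run_stages` and `make_stage` are hypothetical names.

```python
def run_stages(stages, hidden_states, inverse_temp=False):
    """Apply four stages in the order used by cross_attn_g_c."""
    order = [2, 3, 0, 1] if inverse_temp else [0, 1, 2, 3]
    for idx in order:
        hidden_states = stages[idx](hidden_states)
    return hidden_states


def make_stage(name):
    # Placeholder layer: appends its name so the call order is visible.
    return lambda trace: trace + [name]


stages = [make_stage(n) for n in ("attn", "temp_attn", "resnet", "temp_conv")]
print(run_stages(stages, []))
print(run_stages(stages, [], inverse_temp=True))
```

In the real function each stage is additionally wrapped in `torch.utils.checkpoint.checkpoint`, which trades extra forward computation for lower activation memory during training.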
115
+ def up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames):
116
+ hidden_states = g_c(custom_checkpoint(resnet, mode='resnet'), hidden_states, temb, use_reentrant=False)
117
+ hidden_states = g_c(custom_checkpoint(temp_conv, mode='temp'),
118
+ hidden_states, num_frames, use_reentrant=False
119
+ )
120
+ return hidden_states
121
+
122
+ def get_down_block(
123
+ down_block_type,
124
+ num_layers,
125
+ in_channels,
126
+ out_channels,
127
+ temb_channels,
128
+ add_downsample,
129
+ resnet_eps,
130
+ resnet_act_fn,
131
+ attn_num_head_channels,
132
+ resnet_groups=None,
133
+ cross_attention_dim=None,
134
+ downsample_padding=None,
135
+ dual_cross_attention=False,
136
+ use_linear_projection=True,
137
+ only_cross_attention=False,
138
+ upcast_attention=False,
139
+ resnet_time_scale_shift="default",
140
+ ):
141
+ if down_block_type == "DownBlock3D":
142
+ return DownBlock3D(
143
+ num_layers=num_layers,
144
+ in_channels=in_channels,
145
+ out_channels=out_channels,
146
+ temb_channels=temb_channels,
147
+ add_downsample=add_downsample,
148
+ resnet_eps=resnet_eps,
149
+ resnet_act_fn=resnet_act_fn,
150
+ resnet_groups=resnet_groups,
151
+ downsample_padding=downsample_padding,
152
+ resnet_time_scale_shift=resnet_time_scale_shift,
153
+ )
154
+ elif down_block_type == "CrossAttnDownBlock3D":
155
+ if cross_attention_dim is None:
156
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
157
+ return CrossAttnDownBlock3D(
158
+ num_layers=num_layers,
159
+ in_channels=in_channels,
160
+ out_channels=out_channels,
161
+ temb_channels=temb_channels,
162
+ add_downsample=add_downsample,
163
+ resnet_eps=resnet_eps,
164
+ resnet_act_fn=resnet_act_fn,
165
+ resnet_groups=resnet_groups,
166
+ downsample_padding=downsample_padding,
167
+ cross_attention_dim=cross_attention_dim,
168
+ attn_num_head_channels=attn_num_head_channels,
169
+ dual_cross_attention=dual_cross_attention,
170
+ use_linear_projection=use_linear_projection,
171
+ only_cross_attention=only_cross_attention,
172
+ upcast_attention=upcast_attention,
173
+ resnet_time_scale_shift=resnet_time_scale_shift,
174
+ )
175
+ raise ValueError(f"{down_block_type} does not exist.")
176
+
177
+
178
+ def get_up_block(
179
+ up_block_type,
180
+ num_layers,
181
+ in_channels,
182
+ out_channels,
183
+ prev_output_channel,
184
+ temb_channels,
185
+ add_upsample,
186
+ resnet_eps,
187
+ resnet_act_fn,
188
+ attn_num_head_channels,
189
+ resnet_groups=None,
190
+ cross_attention_dim=None,
191
+ dual_cross_attention=False,
192
+ use_linear_projection=True,
193
+ only_cross_attention=False,
194
+ upcast_attention=False,
195
+ resnet_time_scale_shift="default",
196
+ ):
197
+ if up_block_type == "UpBlock3D":
198
+ return UpBlock3D(
199
+ num_layers=num_layers,
200
+ in_channels=in_channels,
201
+ out_channels=out_channels,
202
+ prev_output_channel=prev_output_channel,
203
+ temb_channels=temb_channels,
204
+ add_upsample=add_upsample,
205
+ resnet_eps=resnet_eps,
206
+ resnet_act_fn=resnet_act_fn,
207
+ resnet_groups=resnet_groups,
208
+ resnet_time_scale_shift=resnet_time_scale_shift,
209
+ )
210
+ elif up_block_type == "CrossAttnUpBlock3D":
211
+ if cross_attention_dim is None:
212
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
213
+ return CrossAttnUpBlock3D(
214
+ num_layers=num_layers,
215
+ in_channels=in_channels,
216
+ out_channels=out_channels,
217
+ prev_output_channel=prev_output_channel,
218
+ temb_channels=temb_channels,
219
+ add_upsample=add_upsample,
220
+ resnet_eps=resnet_eps,
221
+ resnet_act_fn=resnet_act_fn,
222
+ resnet_groups=resnet_groups,
223
+ cross_attention_dim=cross_attention_dim,
224
+ attn_num_head_channels=attn_num_head_channels,
225
+ dual_cross_attention=dual_cross_attention,
226
+ use_linear_projection=use_linear_projection,
227
+ only_cross_attention=only_cross_attention,
228
+ upcast_attention=upcast_attention,
229
+ resnet_time_scale_shift=resnet_time_scale_shift,
230
+ )
231
+ raise ValueError(f"{up_block_type} does not exist.")
232
+
233
+
234
+ class UNetMidBlock3DCrossAttn(nn.Module):
235
+ def __init__(
236
+ self,
237
+ in_channels: int,
238
+ temb_channels: int,
239
+ dropout: float = 0.0,
240
+ num_layers: int = 1,
241
+ resnet_eps: float = 1e-6,
242
+ resnet_time_scale_shift: str = "default",
243
+ resnet_act_fn: str = "swish",
244
+ resnet_groups: int = 32,
245
+ resnet_pre_norm: bool = True,
246
+ attn_num_head_channels=1,
247
+ output_scale_factor=1.0,
248
+ cross_attention_dim=1280,
249
+ dual_cross_attention=False,
250
+ use_linear_projection=True,
251
+ upcast_attention=False,
252
+ ):
253
+ super().__init__()
254
+
255
+ self.gradient_checkpointing = False
256
+ self.has_cross_attention = True
257
+ self.attn_num_head_channels = attn_num_head_channels
258
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
259
+
260
+ # there is always at least one resnet
261
+ resnets = [
262
+ ResnetBlock2D(
263
+ in_channels=in_channels,
264
+ out_channels=in_channels,
265
+ temb_channels=temb_channels,
266
+ eps=resnet_eps,
267
+ groups=resnet_groups,
268
+ dropout=dropout,
269
+ time_embedding_norm=resnet_time_scale_shift,
270
+ non_linearity=resnet_act_fn,
271
+ output_scale_factor=output_scale_factor,
272
+ pre_norm=resnet_pre_norm,
273
+ )
274
+ ]
275
+ temp_convs = [
276
+ TemporalConvLayer(
277
+ in_channels,
278
+ in_channels,
279
+ dropout=0.1
280
+ )
281
+ ]
282
+ attentions = []
283
+ temp_attentions = []
284
+
285
+ for _ in range(num_layers):
286
+ attentions.append(
287
+ Transformer2DModel(
288
+ in_channels // attn_num_head_channels,
289
+ attn_num_head_channels,
290
+ in_channels=in_channels,
291
+ num_layers=1,
292
+ cross_attention_dim=cross_attention_dim,
293
+ norm_num_groups=resnet_groups,
294
+ use_linear_projection=use_linear_projection,
295
+ upcast_attention=upcast_attention,
296
+ )
297
+ )
298
+ temp_attentions.append(
299
+ TransformerTemporalModel(
300
+ in_channels // attn_num_head_channels,
301
+ attn_num_head_channels,
302
+ in_channels=in_channels,
303
+ num_layers=1,
304
+ cross_attention_dim=cross_attention_dim,
305
+ norm_num_groups=resnet_groups,
306
+ )
307
+ )
308
+ resnets.append(
309
+ ResnetBlock2D(
310
+ in_channels=in_channels,
311
+ out_channels=in_channels,
312
+ temb_channels=temb_channels,
313
+ eps=resnet_eps,
314
+ groups=resnet_groups,
315
+ dropout=dropout,
316
+ time_embedding_norm=resnet_time_scale_shift,
317
+ non_linearity=resnet_act_fn,
318
+ output_scale_factor=output_scale_factor,
319
+ pre_norm=resnet_pre_norm,
320
+ )
321
+ )
322
+ temp_convs.append(
323
+ TemporalConvLayer(
324
+ in_channels,
325
+ in_channels,
326
+ dropout=0.1
327
+ )
328
+ )
329
+
330
+ self.resnets = nn.ModuleList(resnets)
331
+ self.temp_convs = nn.ModuleList(temp_convs)
332
+ self.attentions = nn.ModuleList(attentions)
333
+ self.temp_attentions = nn.ModuleList(temp_attentions)
334
+
335
+ def forward(
336
+ self,
337
+ hidden_states,
338
+ temb=None,
339
+ encoder_hidden_states=None,
340
+ attention_mask=None,
341
+ num_frames=1,
342
+ cross_attention_kwargs=None,
343
+ ):
344
+ if self.gradient_checkpointing:
345
+ hidden_states = up_down_g_c(
346
+ self.resnets[0],
347
+ self.temp_convs[0],
348
+ hidden_states,
349
+ temb,
350
+ num_frames
351
+ )
352
+ else:
353
+ hidden_states = self.resnets[0](hidden_states, temb)
354
+ hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
355
+
356
+ for attn, temp_attn, resnet, temp_conv in zip(
357
+ self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
358
+ ):
359
+ if self.gradient_checkpointing:
360
+ hidden_states = cross_attn_g_c(
361
+ attn,
362
+ temp_attn,
363
+ resnet,
364
+ temp_conv,
365
+ hidden_states,
366
+ encoder_hidden_states,
367
+ cross_attention_kwargs,
368
+ temb,
369
+ num_frames
370
+ )
371
+ else:
372
+ hidden_states = attn(
373
+ hidden_states,
374
+ encoder_hidden_states=encoder_hidden_states,
375
+ cross_attention_kwargs=cross_attention_kwargs,
376
+ ).sample
377
+
378
+ if num_frames > 1:
379
+ hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
380
+
381
+ hidden_states = resnet(hidden_states, temb)
382
+
383
+ if num_frames > 1:
384
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
385
+
386
+ return hidden_states
387
+
388
+
389
+ class CrossAttnDownBlock3D(nn.Module):
390
+ def __init__(
391
+ self,
392
+ in_channels: int,
393
+ out_channels: int,
394
+ temb_channels: int,
395
+ dropout: float = 0.0,
396
+ num_layers: int = 1,
397
+ resnet_eps: float = 1e-6,
398
+ resnet_time_scale_shift: str = "default",
399
+ resnet_act_fn: str = "swish",
400
+ resnet_groups: int = 32,
401
+ resnet_pre_norm: bool = True,
402
+ attn_num_head_channels=1,
403
+ cross_attention_dim=1280,
404
+ output_scale_factor=1.0,
405
+ downsample_padding=1,
406
+ add_downsample=True,
407
+ dual_cross_attention=False,
408
+ use_linear_projection=False,
409
+ only_cross_attention=False,
410
+ upcast_attention=False,
411
+ ):
412
+ super().__init__()
413
+ resnets = []
414
+ attentions = []
415
+ temp_attentions = []
416
+ temp_convs = []
417
+
418
+ self.gradient_checkpointing = False
419
+ self.has_cross_attention = True
420
+ self.attn_num_head_channels = attn_num_head_channels
421
+
422
+ for i in range(num_layers):
423
+ in_channels = in_channels if i == 0 else out_channels
424
+ resnets.append(
425
+ ResnetBlock2D(
426
+ in_channels=in_channels,
427
+ out_channels=out_channels,
428
+ temb_channels=temb_channels,
429
+ eps=resnet_eps,
430
+ groups=resnet_groups,
431
+ dropout=dropout,
432
+ time_embedding_norm=resnet_time_scale_shift,
433
+ non_linearity=resnet_act_fn,
434
+ output_scale_factor=output_scale_factor,
435
+ pre_norm=resnet_pre_norm,
436
+ )
437
+ )
438
+ temp_convs.append(
439
+ TemporalConvLayer(
440
+ out_channels,
441
+ out_channels,
442
+ dropout=0.1
443
+ )
444
+ )
445
+ attentions.append(
446
+ Transformer2DModel(
447
+ out_channels // attn_num_head_channels,
448
+ attn_num_head_channels,
449
+ in_channels=out_channels,
450
+ num_layers=1,
451
+ cross_attention_dim=cross_attention_dim,
452
+ norm_num_groups=resnet_groups,
453
+ use_linear_projection=use_linear_projection,
454
+ only_cross_attention=only_cross_attention,
455
+ upcast_attention=upcast_attention,
456
+ )
457
+ )
458
+ temp_attentions.append(
459
+ TransformerTemporalModel(
460
+ out_channels // attn_num_head_channels,
461
+ attn_num_head_channels,
462
+ in_channels=out_channels,
463
+ num_layers=1,
464
+ cross_attention_dim=cross_attention_dim,
465
+ norm_num_groups=resnet_groups,
466
+ )
467
+ )
468
+ self.resnets = nn.ModuleList(resnets)
469
+ self.temp_convs = nn.ModuleList(temp_convs)
470
+ self.attentions = nn.ModuleList(attentions)
471
+ self.temp_attentions = nn.ModuleList(temp_attentions)
472
+
473
+ if add_downsample:
474
+ self.downsamplers = nn.ModuleList(
475
+ [
476
+ Downsample2D(
477
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
478
+ )
479
+ ]
480
+ )
481
+ else:
482
+ self.downsamplers = None
483
+
484
+ def forward(
485
+ self,
486
+ hidden_states,
487
+ temb=None,
488
+ encoder_hidden_states=None,
489
+ attention_mask=None,
490
+ num_frames=1,
491
+ cross_attention_kwargs=None,
492
+ ):
493
+ # TODO(Patrick, William) - attention mask is not used
494
+ output_states = ()
495
+
496
+ for resnet, temp_conv, attn, temp_attn in zip(
497
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
498
+ ):
499
+
500
+ if self.gradient_checkpointing:
501
+ hidden_states = cross_attn_g_c(
502
+ attn,
503
+ temp_attn,
504
+ resnet,
505
+ temp_conv,
506
+ hidden_states,
507
+ encoder_hidden_states,
508
+ cross_attention_kwargs,
509
+ temb,
510
+ num_frames,
511
+ inverse_temp=True
512
+ )
513
+ else:
514
+ hidden_states = resnet(hidden_states, temb)
515
+
516
+ if num_frames > 1:
517
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
518
+
519
+ hidden_states = attn(
520
+ hidden_states,
521
+ encoder_hidden_states=encoder_hidden_states,
522
+ cross_attention_kwargs=cross_attention_kwargs,
523
+ ).sample
524
+
525
+ if num_frames > 1:
526
+ hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
527
+
528
+ output_states += (hidden_states,)
529
+
530
+ if self.downsamplers is not None:
531
+ for downsampler in self.downsamplers:
532
+ hidden_states = downsampler(hidden_states)
533
+
534
+ output_states += (hidden_states,)
535
+
536
+ return hidden_states, output_states
537
+
538
+
539
+ class DownBlock3D(nn.Module):
540
+ def __init__(
541
+ self,
542
+ in_channels: int,
543
+ out_channels: int,
544
+ temb_channels: int,
545
+ dropout: float = 0.0,
546
+ num_layers: int = 1,
547
+ resnet_eps: float = 1e-6,
548
+ resnet_time_scale_shift: str = "default",
549
+ resnet_act_fn: str = "swish",
550
+ resnet_groups: int = 32,
551
+ resnet_pre_norm: bool = True,
552
+ output_scale_factor=1.0,
553
+ add_downsample=True,
554
+ downsample_padding=1,
555
+ ):
556
+ super().__init__()
557
+ resnets = []
558
+ temp_convs = []
559
+
560
+ self.gradient_checkpointing = False
561
+ for i in range(num_layers):
562
+ in_channels = in_channels if i == 0 else out_channels
563
+ resnets.append(
564
+ ResnetBlock2D(
565
+ in_channels=in_channels,
566
+ out_channels=out_channels,
567
+ temb_channels=temb_channels,
568
+ eps=resnet_eps,
569
+ groups=resnet_groups,
570
+ dropout=dropout,
571
+ time_embedding_norm=resnet_time_scale_shift,
572
+ non_linearity=resnet_act_fn,
573
+ output_scale_factor=output_scale_factor,
574
+ pre_norm=resnet_pre_norm,
575
+ )
576
+ )
577
+ temp_convs.append(
578
+ TemporalConvLayer(
579
+ out_channels,
580
+ out_channels,
581
+ dropout=0.1
582
+ )
583
+ )
584
+
585
+ self.resnets = nn.ModuleList(resnets)
586
+ self.temp_convs = nn.ModuleList(temp_convs)
587
+
588
+ if add_downsample:
589
+ self.downsamplers = nn.ModuleList(
590
+ [
591
+ Downsample2D(
592
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
593
+ )
594
+ ]
595
+ )
596
+ else:
597
+ self.downsamplers = None
598
+
599
+ def forward(self, hidden_states, temb=None, num_frames=1):
600
+ output_states = ()
601
+
602
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
603
+ if self.gradient_checkpointing:
604
+ hidden_states = up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames)
605
+ else:
606
+ hidden_states = resnet(hidden_states, temb)
607
+
608
+ if num_frames > 1:
609
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
610
+
611
+ output_states += (hidden_states,)
612
+
613
+ if self.downsamplers is not None:
614
+ for downsampler in self.downsamplers:
615
+ hidden_states = downsampler(hidden_states)
616
+
617
+ output_states += (hidden_states,)
618
+
619
+ return hidden_states, output_states
620
+
621
+
622
+ class CrossAttnUpBlock3D(nn.Module):
623
+ def __init__(
624
+ self,
625
+ in_channels: int,
626
+ out_channels: int,
627
+ prev_output_channel: int,
628
+ temb_channels: int,
629
+ dropout: float = 0.0,
630
+ num_layers: int = 1,
631
+ resnet_eps: float = 1e-6,
632
+ resnet_time_scale_shift: str = "default",
633
+ resnet_act_fn: str = "swish",
634
+ resnet_groups: int = 32,
635
+ resnet_pre_norm: bool = True,
636
+ attn_num_head_channels=1,
637
+ cross_attention_dim=1280,
638
+ output_scale_factor=1.0,
639
+ add_upsample=True,
640
+ dual_cross_attention=False,
641
+ use_linear_projection=False,
642
+ only_cross_attention=False,
643
+ upcast_attention=False,
644
+ ):
645
+ super().__init__()
646
+ resnets = []
647
+ temp_convs = []
648
+ attentions = []
649
+ temp_attentions = []
650
+
651
+ self.gradient_checkpointing = False
652
+ self.has_cross_attention = True
653
+ self.attn_num_head_channels = attn_num_head_channels
654
+
655
+ for i in range(num_layers):
656
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
657
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
658
+
659
+ resnets.append(
660
+ ResnetBlock2D(
661
+ in_channels=resnet_in_channels + res_skip_channels,
662
+ out_channels=out_channels,
663
+ temb_channels=temb_channels,
664
+ eps=resnet_eps,
665
+ groups=resnet_groups,
666
+ dropout=dropout,
667
+ time_embedding_norm=resnet_time_scale_shift,
668
+ non_linearity=resnet_act_fn,
669
+ output_scale_factor=output_scale_factor,
670
+ pre_norm=resnet_pre_norm,
671
+ )
672
+ )
673
+ temp_convs.append(
674
+ TemporalConvLayer(
675
+ out_channels,
676
+ out_channels,
677
+ dropout=0.1
678
+ )
679
+ )
680
+ attentions.append(
681
+ Transformer2DModel(
682
+ out_channels // attn_num_head_channels,
683
+ attn_num_head_channels,
684
+ in_channels=out_channels,
685
+ num_layers=1,
686
+ cross_attention_dim=cross_attention_dim,
687
+ norm_num_groups=resnet_groups,
688
+ use_linear_projection=use_linear_projection,
689
+ only_cross_attention=only_cross_attention,
690
+ upcast_attention=upcast_attention,
691
+ )
692
+ )
693
+ temp_attentions.append(
694
+ TransformerTemporalModel(
695
+ out_channels // attn_num_head_channels,
696
+ attn_num_head_channels,
697
+ in_channels=out_channels,
698
+ num_layers=1,
699
+ cross_attention_dim=cross_attention_dim,
700
+ norm_num_groups=resnet_groups,
701
+ )
702
+ )
703
+ self.resnets = nn.ModuleList(resnets)
704
+ self.temp_convs = nn.ModuleList(temp_convs)
705
+ self.attentions = nn.ModuleList(attentions)
706
+ self.temp_attentions = nn.ModuleList(temp_attentions)
707
+
708
+ if add_upsample:
709
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
710
+ else:
711
+ self.upsamplers = None
712
+
713
+ def forward(
714
+ self,
715
+ hidden_states,
716
+ res_hidden_states_tuple,
717
+ temb=None,
718
+ encoder_hidden_states=None,
719
+ upsample_size=None,
720
+ attention_mask=None,
721
+ num_frames=1,
722
+ cross_attention_kwargs=None,
723
+ ):
724
+ # TODO(Patrick, William) - attention mask is not used
725
+ for resnet, temp_conv, attn, temp_attn in zip(
726
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
727
+ ):
728
+ # pop res hidden states
729
+ res_hidden_states = res_hidden_states_tuple[-1]
730
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
731
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
732
+
733
+ if self.gradient_checkpointing:
734
+ hidden_states = cross_attn_g_c(
735
+ attn,
736
+ temp_attn,
737
+ resnet,
738
+ temp_conv,
739
+ hidden_states,
740
+ encoder_hidden_states,
741
+ cross_attention_kwargs,
742
+ temb,
743
+ num_frames,
744
+ inverse_temp=True
745
+ )
746
+ else:
747
+ hidden_states = resnet(hidden_states, temb)
748
+
749
+ if num_frames > 1:
750
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
751
+
752
+ hidden_states = attn(
753
+ hidden_states,
754
+ encoder_hidden_states=encoder_hidden_states,
755
+ cross_attention_kwargs=cross_attention_kwargs,
756
+ ).sample
757
+
758
+ if num_frames > 1:
759
+ hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
760
+
761
+ if self.upsamplers is not None:
762
+ for upsampler in self.upsamplers:
763
+ hidden_states = upsampler(hidden_states, upsample_size)
764
+
765
+ return hidden_states
766
+
767
+
768
+ class UpBlock3D(nn.Module):
769
+ def __init__(
770
+ self,
771
+ in_channels: int,
772
+ prev_output_channel: int,
773
+ out_channels: int,
774
+ temb_channels: int,
775
+ dropout: float = 0.0,
776
+ num_layers: int = 1,
777
+ resnet_eps: float = 1e-6,
778
+ resnet_time_scale_shift: str = "default",
779
+ resnet_act_fn: str = "swish",
780
+ resnet_groups: int = 32,
781
+ resnet_pre_norm: bool = True,
782
+ output_scale_factor=1.0,
783
+ add_upsample=True,
784
+ ):
785
+ super().__init__()
786
+ resnets = []
787
+ temp_convs = []
788
+ self.gradient_checkpointing = False
789
+ for i in range(num_layers):
790
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
791
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
792
+
793
+ resnets.append(
794
+ ResnetBlock2D(
795
+ in_channels=resnet_in_channels + res_skip_channels,
796
+ out_channels=out_channels,
797
+ temb_channels=temb_channels,
798
+ eps=resnet_eps,
799
+ groups=resnet_groups,
800
+ dropout=dropout,
801
+ time_embedding_norm=resnet_time_scale_shift,
802
+ non_linearity=resnet_act_fn,
803
+ output_scale_factor=output_scale_factor,
804
+ pre_norm=resnet_pre_norm,
805
+ )
806
+ )
807
+ temp_convs.append(
808
+ TemporalConvLayer(
809
+ out_channels,
810
+ out_channels,
811
+ dropout=0.1
812
+ )
813
+ )
814
+
815
+ self.resnets = nn.ModuleList(resnets)
816
+ self.temp_convs = nn.ModuleList(temp_convs)
817
+
818
+ if add_upsample:
819
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
820
+ else:
821
+ self.upsamplers = None
822
+
823
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
824
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
825
+ # pop res hidden states
826
+ res_hidden_states = res_hidden_states_tuple[-1]
827
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
828
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
829
+
830
+ if self.gradient_checkpointing:
831
+ hidden_states = up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames)
832
+ else:
833
+ hidden_states = resnet(hidden_states, temb)
834
+
835
+ if num_frames > 1:
836
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
837
+
838
+ if self.upsamplers is not None:
839
+ for upsampler in self.upsamplers:
840
+ hidden_states = upsampler(hidden_states, upsample_size)
841
+
842
+ return hidden_states
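Review note: the channel bookkeeping in `CrossAttnUpBlock3D` and `UpBlock3D` above is easy to misread, so here is a minimal plain-Python sketch of it. The helper name `up_block_resnet_in_channels` is hypothetical and not part of this commit. It only repeats the `res_skip_channels` / `resnet_in_channels` arithmetic from the two `__init__` loops, and the sample values assume the default `block_out_channels=(320, 640, 1280, 1280)` and `layers_per_block=2` from `models/unet_3d_condition.py`.

```python
from typing import List


def up_block_resnet_in_channels(
    in_channels: int,
    out_channels: int,
    prev_output_channel: int,
    num_layers: int,
) -> List[int]:
    """Return the input width of each ResnetBlock2D in one up block.

    Hypothetical helper: it mirrors the arithmetic of the __init__ loops
    in CrossAttnUpBlock3D / UpBlock3D and builds no modules.
    """
    widths = []
    for i in range(num_layers):
        # Width of the skip connection popped from the down path for this layer.
        res_skip_channels = in_channels if i == num_layers - 1 else out_channels
        # Width of the features arriving from the previous up layer (or block).
        resnet_in_channels = prev_output_channel if i == 0 else out_channels
        # forward() concatenates both along dim=1 before calling the resnet.
        widths.append(resnet_in_channels + res_skip_channels)
    return widths


# Assumed default config: the model passes num_layers = layers_per_block + 1 = 3.
# First up block: in_channels=1280, out_channels=1280, prev_output_channel=1280.
first_up = up_block_resnet_in_channels(1280, 1280, 1280, 3)
# Third up block: in_channels=320, out_channels=640, prev_output_channel=1280.
third_up = up_block_resnet_in_channels(320, 640, 1280, 3)
# Last up block: in_channels=320, out_channels=320, prev_output_channel=640.
last_up = up_block_resnet_in_channels(320, 320, 640, 3)

print(first_up, third_up, last_up)
```

Only the first layer of a block sees `prev_output_channel`, and only the last layer sees a skip tensor of width `in_channels`. This is why each up block is built with one more layer than the matching down block: it has to consume the extra residual that the down path emits after its downsampler.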
models/unet_3d_condition.py ADDED
@@ -0,0 +1,500 @@
1
+ # Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
2
+ # Copyright 2023 The ModelScope Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
25
+ from diffusers.models.modeling_utils import ModelMixin
26
+ from diffusers.models.transformer_temporal import TransformerTemporalModel
27
+ from .unet_3d_blocks import (
28
+ CrossAttnDownBlock3D,
29
+ CrossAttnUpBlock3D,
30
+ DownBlock3D,
31
+ UNetMidBlock3DCrossAttn,
32
+ UpBlock3D,
33
+ get_down_block,
34
+ get_up_block,
35
+ transformer_g_c
36
+ )
37
+
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+
42
+ @dataclass
43
+ class UNet3DConditionOutput(BaseOutput):
44
+ """
45
+ Args:
46
+ sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
47
+ Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
48
+ """
49
+
50
+ sample: torch.FloatTensor
51
+
52
+
53
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
54
+ r"""
55
+ UNet3DConditionModel is a conditional 3D UNet model that takes in a noisy sample, conditional state, and a timestep
56
+ and returns sample shaped output.
57
+
58
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
59
+ implements for all the models (such as downloading or saving, etc.)
60
+
61
+ Parameters:
62
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
63
+ Height and width of input/output sample.
64
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
65
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
66
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
67
+ The tuple of downsample blocks to use.
68
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
69
+ The tuple of upsample blocks to use.
70
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
71
+ The tuple of output channels for each block.
72
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
73
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
74
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
75
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
76
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
77
+ If `None`, it will skip the normalization and activation layers in post-processing
78
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
79
+ cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
80
+ attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 64): The dimension of the attention heads.
81
+ """
82
+
83
+ _supports_gradient_checkpointing = True
84
+
85
+ @register_to_config
86
+ def __init__(
87
+ self,
88
+ sample_size: Optional[int] = None,
89
+ in_channels: int = 4,
90
+ out_channels: int = 4,
91
+ down_block_types: Tuple[str] = (
92
+ "CrossAttnDownBlock3D",
93
+ "CrossAttnDownBlock3D",
94
+ "CrossAttnDownBlock3D",
95
+ "DownBlock3D",
96
+ ),
97
+ up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
98
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
99
+ layers_per_block: int = 2,
100
+ downsample_padding: int = 1,
101
+ mid_block_scale_factor: float = 1,
102
+ act_fn: str = "silu",
103
+ norm_num_groups: Optional[int] = 32,
104
+ norm_eps: float = 1e-5,
105
+ cross_attention_dim: int = 1024,
106
+ attention_head_dim: Union[int, Tuple[int]] = 64,
107
+ ):
108
+ super().__init__()
109
+
110
+ self.sample_size = sample_size
111
+ self.gradient_checkpointing = False
112
+ # Check inputs
113
+ if len(down_block_types) != len(up_block_types):
114
+ raise ValueError(
115
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
116
+ )
117
+
118
+ if len(block_out_channels) != len(down_block_types):
119
+ raise ValueError(
120
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
121
+ )
122
+
123
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
124
+ raise ValueError(
125
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
126
+ )
127
+
128
+ # input
129
+ conv_in_kernel = 3
130
+ conv_out_kernel = 3
131
+ conv_in_padding = (conv_in_kernel - 1) // 2
132
+ self.conv_in = nn.Conv2d(
133
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
134
+ )
135
+
136
+ # time
137
+ time_embed_dim = block_out_channels[0] * 4
138
+ self.time_proj = Timesteps(block_out_channels[0], True, 0)
139
+ timestep_input_dim = block_out_channels[0]
140
+
141
+ self.time_embedding = TimestepEmbedding(
142
+ timestep_input_dim,
143
+ time_embed_dim,
144
+ act_fn=act_fn,
145
+ )
146
+
147
+ self.transformer_in = TransformerTemporalModel(
148
+ num_attention_heads=8,
149
+ attention_head_dim=attention_head_dim,
150
+ in_channels=block_out_channels[0],
151
+ num_layers=1,
152
+ )
153
+
154
+ # containers for the down and up blocks
155
+ self.down_blocks = nn.ModuleList([])
156
+ self.up_blocks = nn.ModuleList([])
157
+
158
+ if isinstance(attention_head_dim, int):
159
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
160
+
161
+ # down
162
+ output_channel = block_out_channels[0]
163
+ for i, down_block_type in enumerate(down_block_types):
164
+ input_channel = output_channel
165
+ output_channel = block_out_channels[i]
166
+ is_final_block = i == len(block_out_channels) - 1
167
+
168
+ down_block = get_down_block(
169
+ down_block_type,
170
+ num_layers=layers_per_block,
171
+ in_channels=input_channel,
172
+ out_channels=output_channel,
173
+ temb_channels=time_embed_dim,
174
+ add_downsample=not is_final_block,
175
+ resnet_eps=norm_eps,
176
+ resnet_act_fn=act_fn,
177
+ resnet_groups=norm_num_groups,
178
+ cross_attention_dim=cross_attention_dim,
179
+ attn_num_head_channels=attention_head_dim[i],
180
+ downsample_padding=downsample_padding,
181
+ dual_cross_attention=False,
182
+ )
183
+ self.down_blocks.append(down_block)
184
+
185
+ # mid
186
+ self.mid_block = UNetMidBlock3DCrossAttn(
187
+ in_channels=block_out_channels[-1],
188
+ temb_channels=time_embed_dim,
189
+ resnet_eps=norm_eps,
190
+ resnet_act_fn=act_fn,
191
+ output_scale_factor=mid_block_scale_factor,
192
+ cross_attention_dim=cross_attention_dim,
193
+ attn_num_head_channels=attention_head_dim[-1],
194
+ resnet_groups=norm_num_groups,
195
+ dual_cross_attention=False,
196
+ )
197
+
198
+ # count how many layers upsample the images
199
+ self.num_upsamplers = 0
200
+
201
+ # up
202
+ reversed_block_out_channels = list(reversed(block_out_channels))
203
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
204
+
205
+ output_channel = reversed_block_out_channels[0]
206
+ for i, up_block_type in enumerate(up_block_types):
207
+ is_final_block = i == len(block_out_channels) - 1
208
+
209
+ prev_output_channel = output_channel
210
+ output_channel = reversed_block_out_channels[i]
211
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
212
+
213
+ # add upsample block for all BUT final layer
214
+ if not is_final_block:
215
+ add_upsample = True
216
+ self.num_upsamplers += 1
217
+ else:
218
+ add_upsample = False
219
+
220
+ up_block = get_up_block(
221
+ up_block_type,
222
+ num_layers=layers_per_block + 1,
223
+ in_channels=input_channel,
224
+ out_channels=output_channel,
225
+ prev_output_channel=prev_output_channel,
226
+ temb_channels=time_embed_dim,
227
+ add_upsample=add_upsample,
228
+ resnet_eps=norm_eps,
229
+ resnet_act_fn=act_fn,
230
+ resnet_groups=norm_num_groups,
231
+ cross_attention_dim=cross_attention_dim,
232
+ attn_num_head_channels=reversed_attention_head_dim[i],
233
+ dual_cross_attention=False,
234
+ )
235
+ self.up_blocks.append(up_block)
236
+ prev_output_channel = output_channel
237
+
238
+ # out
239
+ if norm_num_groups is not None:
240
+ self.conv_norm_out = nn.GroupNorm(
241
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
242
+ )
243
+ self.conv_act = nn.SiLU()
244
+ else:
245
+ self.conv_norm_out = None
246
+ self.conv_act = None
247
+
248
+ conv_out_padding = (conv_out_kernel - 1) // 2
249
+ self.conv_out = nn.Conv2d(
250
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
251
+ )
252
+
253
+ def set_attention_slice(self, slice_size):
254
+ r"""
255
+ Enable sliced attention computation.
256
+
257
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
258
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
259
+
260
+ Args:
261
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
262
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
263
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
264
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
265
+ must be a multiple of `slice_size`.
266
+ """
267
+ sliceable_head_dims = []
268
+
269
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
270
+ if hasattr(module, "set_attention_slice"):
271
+ sliceable_head_dims.append(module.sliceable_head_dim)
272
+
273
+ for child in module.children():
274
+ fn_recursive_retrieve_slicable_dims(child)
275
+
276
+ # retrieve number of attention layers
277
+ for module in self.children():
278
+ fn_recursive_retrieve_slicable_dims(module)
279
+
280
+ num_slicable_layers = len(sliceable_head_dims)
281
+
282
+ if slice_size == "auto":
283
+ # half the attention head size is usually a good trade-off between
284
+ # speed and memory
285
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
286
+ elif slice_size == "max":
287
+ # make smallest slice possible
288
+ slice_size = num_slicable_layers * [1]
289
+
290
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
291
+
292
+ if len(slice_size) != len(sliceable_head_dims):
293
+ raise ValueError(
294
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
295
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
296
+ )
297
+
298
+ for i in range(len(slice_size)):
299
+ size = slice_size[i]
300
+ dim = sliceable_head_dims[i]
301
+ if size is not None and size > dim:
302
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
303
+
304
+ # Recursively walk through all the children.
305
+ # Any children which exposes the set_attention_slice method
306
+ # gets the message
307
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
308
+ if hasattr(module, "set_attention_slice"):
309
+ module.set_attention_slice(slice_size.pop())
310
+
311
+ for child in module.children():
312
+ fn_recursive_set_attention_slice(child, slice_size)
313
+
314
+ reversed_slice_size = list(reversed(slice_size))
315
+ for module in self.children():
316
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
317
+
318
+ def _set_gradient_checkpointing(self, value=False):
319
+ self.gradient_checkpointing = value
320
+ self.mid_block.gradient_checkpointing = value
321
+ for module in self.down_blocks + self.up_blocks:
322
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
323
+ module.gradient_checkpointing = value
324
+
325
+ def forward(
326
+ self,
327
+ sample: torch.FloatTensor,
328
+ timestep: Union[torch.Tensor, float, int],
329
+ encoder_hidden_states: torch.Tensor,
330
+ class_labels: Optional[torch.Tensor] = None,
331
+ timestep_cond: Optional[torch.Tensor] = None,
332
+ attention_mask: Optional[torch.Tensor] = None,
333
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
334
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
335
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
336
+ return_dict: bool = True,
337
+ ) -> Union[UNet3DConditionOutput, Tuple]:
338
+ r"""
339
+ Args:
340
+ sample (`torch.FloatTensor`): (batch, num_frames, channel, height, width) noisy inputs tensor
341
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
342
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
343
+ return_dict (`bool`, *optional*, defaults to `True`):
344
+ Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.
345
+ cross_attention_kwargs (`dict`, *optional*):
346
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
347
+ `self.processor` in
348
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
349
+
350
+ Returns:
351
+ [`UNet3DConditionOutput`] or `tuple`:
352
+ [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
353
+ returning a tuple, the first element is the sample tensor.
354
+ """
355
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
356
+ # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
357
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
358
+ # on the fly if necessary.
359
+ default_overall_up_factor = 2**self.num_upsamplers
360
+
361
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
362
+ forward_upsample_size = False
363
+ upsample_size = None
364
+
365
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
366
+ logger.info("Forward upsample size to force interpolation output size.")
367
+ forward_upsample_size = True
368
+
369
+ # prepare attention_mask
370
+ if attention_mask is not None:
371
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
372
+ attention_mask = attention_mask.unsqueeze(1)
373
+
374
+ # 1. time
375
+ timesteps = timestep
376
+ if not torch.is_tensor(timesteps):
377
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
378
+ # This would be a good case for the `match` statement (Python 3.10+)
379
+ is_mps = sample.device.type == "mps"
380
+ if isinstance(timestep, float):
381
+ dtype = torch.float32 if is_mps else torch.float64
382
+ else:
383
+ dtype = torch.int32 if is_mps else torch.int64
384
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
385
+ elif len(timesteps.shape) == 0:
386
+ timesteps = timesteps[None].to(sample.device)
387
+
388
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
389
+ num_frames = sample.shape[2]
390
+ timesteps = timesteps.expand(sample.shape[0])
391
+
392
+ t_emb = self.time_proj(timesteps)
393
+
394
+ # time_proj does not contain any weights and will always return f32 tensors
395
+ # but time_embedding might actually be running in fp16. so we need to cast here.
396
+ # there might be better ways to encapsulate this.
397
+ t_emb = t_emb.to(dtype=self.dtype)
398
+
399
+ emb = self.time_embedding(t_emb, timestep_cond)
400
+ emb = emb.repeat_interleave(repeats=num_frames, dim=0)
401
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
402
+
403
+ # 2. pre-process
404
+ sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
405
+ sample = self.conv_in(sample)
406
+
407
+ if num_frames > 1:
408
+ if self.gradient_checkpointing:
409
+ sample = transformer_g_c(self.transformer_in, sample, num_frames)
410
+ else:
411
+ sample = self.transformer_in(sample, num_frames=num_frames).sample
412
+
413
+ # 3. down
414
+ down_block_res_samples = (sample,)
415
+ for downsample_block in self.down_blocks:
416
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
417
+ sample, res_samples = downsample_block(
418
+ hidden_states=sample,
419
+ temb=emb,
420
+ encoder_hidden_states=encoder_hidden_states,
421
+ attention_mask=attention_mask,
422
+ num_frames=num_frames,
423
+ cross_attention_kwargs=cross_attention_kwargs,
424
+ )
425
+ else:
426
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
427
+
428
+ down_block_res_samples += res_samples
429
+
430
+ if down_block_additional_residuals is not None:
431
+ new_down_block_res_samples = ()
432
+
433
+ for down_block_res_sample, down_block_additional_residual in zip(
434
+ down_block_res_samples, down_block_additional_residuals
435
+ ):
436
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
437
+ new_down_block_res_samples += (down_block_res_sample,)
438
+
439
+ down_block_res_samples = new_down_block_res_samples
440
+
441
+ # 4. mid
442
+ if self.mid_block is not None:
443
+ sample = self.mid_block(
444
+ sample,
445
+ emb,
446
+ encoder_hidden_states=encoder_hidden_states,
447
+ attention_mask=attention_mask,
448
+ num_frames=num_frames,
449
+ cross_attention_kwargs=cross_attention_kwargs,
450
+ )
451
+
452
+ if mid_block_additional_residual is not None:
453
+ sample = sample + mid_block_additional_residual
454
+
455
+ # 5. up
456
+ for i, upsample_block in enumerate(self.up_blocks):
457
+ is_final_block = i == len(self.up_blocks) - 1
458
+
459
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
460
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
461
+
462
+ # if we have not reached the final block and need to forward the
463
+ # upsample size, we do it here
464
+ if not is_final_block and forward_upsample_size:
465
+ upsample_size = down_block_res_samples[-1].shape[2:]
466
+
467
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
468
+ sample = upsample_block(
469
+ hidden_states=sample,
470
+ temb=emb,
471
+ res_hidden_states_tuple=res_samples,
472
+ encoder_hidden_states=encoder_hidden_states,
473
+ upsample_size=upsample_size,
474
+ attention_mask=attention_mask,
475
+ num_frames=num_frames,
476
+ cross_attention_kwargs=cross_attention_kwargs,
477
+ )
478
+ else:
479
+ sample = upsample_block(
480
+ hidden_states=sample,
481
+ temb=emb,
482
+ res_hidden_states_tuple=res_samples,
483
+ upsample_size=upsample_size,
484
+ num_frames=num_frames,
485
+ )
486
+
487
+ # 6. post-process
488
+ if self.conv_norm_out:
489
+ sample = self.conv_norm_out(sample)
490
+ sample = self.conv_act(sample)
491
+
492
+ sample = self.conv_out(sample)
493
+
494
+ # reshape to (batch, channel, num_frames, height, width)
495
+ sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)
496
+
497
+ if not return_dict:
498
+ return (sample,)
499
+
500
+ return UNet3DConditionOutput(sample=sample)
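The forward pass above folds the frame axis into the batch axis before the 2D layers run, repeats the time embedding and text states once per frame, and unfolds the frames again at the end. The sketch below mimics that bookkeeping with plain Python lists instead of tensors, so it runs without PyTorch. The helper names (`fold_frames`, `repeat_interleave`, `unfold_frames`) are hypothetical and exist only for this illustration.

```python
# Pure-Python sketch: lists stand in for tensors, helper names are hypothetical.

def fold_frames(videos):
    """[batch][frame] -> flat [batch * frame], batch-major like permute + reshape."""
    return [frame for video in videos for frame in video]


def repeat_interleave(items, repeats):
    """List analogue of repeat_interleave along dim 0: a, a, ..., b, b, ..."""
    return [item for item in items for _ in range(repeats)]


def unfold_frames(flat, num_frames):
    """Inverse of fold_frames: regroup consecutive frames into videos."""
    return [flat[i:i + num_frames] for i in range(0, len(flat), num_frames)]


videos = [["a0", "a1", "a2"], ["b0", "b1", "b2"]]  # batch of 2, 3 frames each
num_frames = len(videos[0])

flat = fold_frames(videos)                               # frames of video a, then video b
emb = repeat_interleave(["emb_a", "emb_b"], num_frames)  # one embedding per folded frame
pairs = list(zip(flat, emb))                             # each frame meets its own video's embedding
restored = unfold_frames(flat, num_frames)

print(pairs)
print(restored == videos)
```

Because both the fold and the repeat are batch-major, entry `n` of the folded batch always lines up with the embedding of video `n // num_frames`, which is the property the real code appears to rely on.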
requirements.txt ADDED
@@ -0,0 +1,19 @@
1
+ accelerate==0.18.0
2
+ decord==0.6.0
3
+ deepspeed==0.10.0
4
+ diffusers==0.18.0
5
+ huggingface-hub==0.16.4
6
+ lora-diffusion @ git+https://github.com/cloneofsimo/lora.git@bdd51b04c49fa90a88919a19850ec3b4cf3c5ecd
7
+ loralib==0.1.0
8
+ numpy==1.24.0
9
+ omegaconf==2.3.0
10
+ opencv-python==4.8.0.74
11
+ torch==2.0.1
12
+ torchaudio==2.0.2
13
+ torchvision==0.15.2
14
+ tqdm==4.65.0
15
+ transformers==4.27.4
16
+ einops==0.7.0
17
+ imageio==2.33.0
18
+ imageio-ffmpeg==0.4.9
19
+ gradio==3.26.0
utils/bucketing.py ADDED
@@ -0,0 +1,32 @@
1
+ from PIL import Image
2
+
3
+ def min_res(size, min_size): return min_size if size < min_size else size
4
+
5
+ def up_down_bucket(m_size, in_size, direction):
6
+ if direction == 'down': return abs(int(m_size - in_size))
7
+ if direction == 'up': return abs(int(m_size + in_size))
8
+
9
+ def get_bucket_sizes(size, direction, min_size):
10
+ multipliers = [64, 128]
11
+ for i, m in enumerate(multipliers):
12
+ res = up_down_bucket(m, size, direction)
13
+ multipliers[i] = min_res(res, min_size=min_size)
14
+ return multipliers
15
+
16
+ def closest_bucket(m_size, size, direction, min_size):
17
+ lst = get_bucket_sizes(m_size, direction, min_size)
18
+ return lst[min(range(len(lst)), key=lambda i: abs(lst[i]-size))]
19
+
20
+ def resolve_bucket(i,h,w): return (i / (h / w))
21
+
22
+ def sensible_buckets(m_width, m_height, w, h, min_size=192):
23
+ if h > w:
24
+ w = resolve_bucket(m_width, h, w)
25
+ w = closest_bucket(m_width, w, 'down', min_size=min_size)
26
+ return w, m_height
27
+ if h < w:
28
+ h = resolve_bucket(m_height, w, h)
29
+ h = closest_bucket(m_height, h, 'down', min_size=min_size)
30
+ return m_width, h
31
+
32
+ return m_width, m_height
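To make the bucketing behaviour concrete, here is a condensed restatement of the functions above together with a few sample calls. It is a sketch under assumptions rather than the module itself: it only covers the 'down' direction, assumes a maximum size of at least 128, and the names `candidate_buckets`, `closest`, and `sensible_buckets_sketch` are made up for this example.

```python
# Condensed sketch of the 'down' bucketing path; names are hypothetical.

def candidate_buckets(max_size, min_size=192):
    """Two candidates: max_size - 64 and max_size - 128, floored at min_size."""
    return [max(max_size - step, min_size) for step in (64, 128)]


def closest(values, target):
    """Pick the candidate nearest to target (first one wins on ties)."""
    return min(values, key=lambda v: abs(v - target))


def sensible_buckets_sketch(max_w, max_h, w, h, min_size=192):
    if h > w:  # portrait: keep max height, shrink width
        scaled_w = max_w / (h / w)
        return closest(candidate_buckets(max_w, min_size), scaled_w), max_h
    if h < w:  # landscape: keep max width, shrink height
        scaled_h = max_h / (w / h)
        return max_w, closest(candidate_buckets(max_h, min_size), scaled_h)
    return max_w, max_h  # square input keeps the maximum size


landscape = sensible_buckets_sketch(512, 512, 1280, 720)
portrait = sensible_buckets_sketch(512, 512, 720, 1280)
square = sensible_buckets_sketch(512, 512, 640, 640)
print(landscape, portrait, square)
```

Note that only two candidate sizes exist per axis, so a 16:9 clip lands on 512x384 rather than on its exact scaled height of 288.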
utils/convert_diffusers_to_original_ms_text_to_video.py ADDED
@@ -0,0 +1,465 @@
1
+ # Script for converting a HF Diffusers saved pipeline to a ModelScope text-to-video checkpoint.
2
+ # *Only* converts the UNet and the text encoder.
3
+ # Does not convert optimizer state or anything else.
4
+
5
+ import argparse
6
+ import os.path as osp
7
+ import re
8
+
9
+ import torch
10
+ from safetensors.torch import load_file, save_file
11
+
12
+ # =================#
13
+ # UNet Conversion #
14
+ # =================#
15
+
16
+ print ('Initializing the conversion map')
17
+
18
+ unet_conversion_map = [
19
+ # (ModelScope, HF Diffusers)
20
+
21
+ # from Vanilla ModelScope/StableDiffusion
22
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
23
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
24
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
25
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
26
+
27
+
28
+ # from Vanilla ModelScope/StableDiffusion
29
+ ("input_blocks.0.0.weight", "conv_in.weight"),
30
+ ("input_blocks.0.0.bias", "conv_in.bias"),
31
+
32
+
33
+ # from Vanilla ModelScope/StableDiffusion
34
+ ("out.0.weight", "conv_norm_out.weight"),
35
+ ("out.0.bias", "conv_norm_out.bias"),
36
+ ("out.2.weight", "conv_out.weight"),
37
+ ("out.2.bias", "conv_out.bias"),
38
+ ]
39
+
40
+ unet_conversion_map_resnet = [
41
+ # (ModelScope, HF Diffusers)
42
+
43
+ # SD
44
+ ("in_layers.0", "norm1"),
45
+ ("in_layers.2", "conv1"),
46
+ ("out_layers.0", "norm2"),
47
+ ("out_layers.3", "conv2"),
48
+ ("emb_layers.1", "time_emb_proj"),
49
+ ("skip_connection", "conv_shortcut"),
50
+
51
+ # MS
52
+ #("temopral_conv", "temp_convs"), # ROFL, they have a typo here --kabachuha
53
+ ]
54
+
55
+ unet_conversion_map_layer = []
56
+
57
+ # Convert input TemporalTransformer
58
+ unet_conversion_map_layer.append(('input_blocks.0.1', 'transformer_in'))
59
+
60
+ # Reference for the default settings
61
+
62
+ # "model_cfg": {
63
+ # "unet_in_dim": 4,
64
+ # "unet_dim": 320,
65
+ # "unet_y_dim": 768,
66
+ # "unet_context_dim": 1024,
67
+ # "unet_out_dim": 4,
68
+ # "unet_dim_mult": [1, 2, 4, 4],
69
+ # "unet_num_heads": 8,
70
+ # "unet_head_dim": 64,
71
+ # "unet_res_blocks": 2,
72
+ # "unet_attn_scales": [1, 0.5, 0.25],
73
+ # "unet_dropout": 0.1,
74
+ # "temporal_attention": "True",
75
+ # "num_timesteps": 1000,
76
+ # "mean_type": "eps",
77
+ # "var_type": "fixed_small",
78
+ # "loss_type": "mse"
79
+ # }
80
+
81
+ # hardcoded number of downblocks and resnets/attentions...
82
+ # would need smarter logic for other networks.
83
+ for i in range(4):
84
+ # loop over downblocks/upblocks
85
+
86
+ for j in range(2):
87
+ # loop over resnets/attentions for downblocks
88
+
89
+ # Spacial SD stuff
90
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
91
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
92
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
93
+
94
+ if i < 3:
95
+ # no attention layers in down_blocks.3
96
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
97
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
98
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
99
+
100
+ # Temporal MS stuff
101
+ hf_down_res_prefix = f"down_blocks.{i}.temp_convs.{j}."
102
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0.temopral_conv."
103
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
104
+
105
+ if i < 3:
106
+ # no attention layers in down_blocks.3
107
+ hf_down_atn_prefix = f"down_blocks.{i}.temp_attentions.{j}."
108
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.2."
109
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
110
+
111
+ for j in range(3):
112
+ # loop over resnets/attentions for upblocks
113
+
114
+ # Spacial SD stuff
115
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
116
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
117
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
118
+
119
+ if i > 0:
120
+ # no attention layers in up_blocks.0
121
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
122
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
123
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
124
+
125
+ # Temporal MS stuff
126
+ hf_up_res_prefix = f"up_blocks.{i}.temp_convs.{j}."
127
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0.temopral_conv."
128
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
129
+
130
+ if i > 0:
131
+ # no attention layers in up_blocks.0
132
+ hf_up_atn_prefix = f"up_blocks.{i}.temp_attentions.{j}."
133
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.2."
134
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
135
+
136
+ # Up/Downsamplers are 2D, so don't need to touch them
137
+ if i < 3:
138
+ # no downsample in down_blocks.3
139
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
140
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.op."
141
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
142
+
143
+ # no upsample in up_blocks.3
144
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
145
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 3}."
146
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
147
+
148
+
149
+ # Handle the middle block
150
+
151
+ # Spacial
152
+ hf_mid_atn_prefix = "mid_block.attentions.0."
153
+ sd_mid_atn_prefix = "middle_block.1."
154
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
155
+
156
+ for j in range(2):
157
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
158
+ sd_mid_res_prefix = f"middle_block.{3*j}."
159
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
160
+
161
+ # Temporal
162
+ hf_mid_atn_prefix = "mid_block.temp_attentions.0."
163
+ sd_mid_atn_prefix = "middle_block.2."
164
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
165
+
166
+ for j in range(2):
167
+ hf_mid_res_prefix = f"mid_block.temp_convs.{j}."
168
+ sd_mid_res_prefix = f"middle_block.{3*j}.temopral_conv."
169
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
170
+
171
+ # The pipeline
172
+ def convert_unet_state_dict(unet_state_dict, strict_mapping=False):
173
+ print ('Converting the UNET')
174
+ # buyer beware: this is a *brittle* function,
175
+ # and correct output requires that all of these pieces interact in
176
+ # the exact order in which I have arranged them.
177
+ mapping = {k: k for k in unet_state_dict.keys()}
178
+
179
+ for sd_name, hf_name in unet_conversion_map:
180
+ if strict_mapping:
181
+ if hf_name in mapping:
182
+ mapping[hf_name] = sd_name
183
+ else:
184
+ mapping[hf_name] = sd_name
185
+ for k, v in mapping.items():
186
+ if "resnets" in k:
187
+ for sd_part, hf_part in unet_conversion_map_resnet:
188
+ v = v.replace(hf_part, sd_part)
189
+ mapping[k] = v
190
+ # elif "temp_convs" in k:
191
+ # for sd_part, hf_part in unet_conversion_map_resnet:
192
+ # v = v.replace(hf_part, sd_part)
193
+ # mapping[k] = v
194
+ for k, v in mapping.items():
195
+ for sd_part, hf_part in unet_conversion_map_layer:
196
+ v = v.replace(hf_part, sd_part)
197
+ mapping[k] = v
198
+
199
+
200
+ # there must be a pattern, but I don't want to bother atm
201
+ do_not_unsqueeze = [f'output_blocks.{i}.1.proj_out.weight' for i in range(3, 12)] + [f'output_blocks.{i}.1.proj_in.weight' for i in range(3, 12)] + ['middle_block.1.proj_in.weight', 'middle_block.1.proj_out.weight'] + [f'input_blocks.{i}.1.proj_out.weight' for i in [1, 2, 4, 5, 7, 8]] + [f'input_blocks.{i}.1.proj_in.weight' for i in [1, 2, 4, 5, 7, 8]]
202
+ print (do_not_unsqueeze)
203
+
204
+ new_state_dict = {v: (unet_state_dict[k].unsqueeze(-1) if ('proj_' in k and ('bias' not in k) and (k not in do_not_unsqueeze)) else unet_state_dict[k]) for k, v in mapping.items()}
205
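+ # HACK: the comprehension above tests the Diffusers key `k`, but `do_not_unsqueeze` lists ModelScope names, so the exclusion never matches there; squeeze those entries back here instead.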
206
+ for k, v in new_state_dict.items():
207
+ has_k = False
208
+ for n in do_not_unsqueeze:
209
+ if k == n:
210
+ has_k = True
211
+
212
+ if has_k:
213
+ v = v.squeeze(-1)
214
+ new_state_dict[k] = v
215
+
216
+ return new_state_dict
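The renaming in `convert_unet_state_dict` is plain substring replacement driven by (ModelScope, HF Diffusers) pairs: resnet sub-layer names are swapped first, then the block prefixes. The sketch below reproduces that idea for down-block resnets only, using a reduced table; `down_resnet_prefixes` and `rename_key` are hypothetical helpers, not functions from the script.

```python
# Sketch only: reduced renaming tables covering down-block resnets.
RESNET_PARTS = [
    # (ModelScope, HF Diffusers)
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),
]


def down_resnet_prefixes(num_blocks=4, resnets_per_block=2):
    """Build (ModelScope prefix, Diffusers prefix) pairs for down-block resnets."""
    pairs = []
    for i in range(num_blocks):
        for j in range(resnets_per_block):
            pairs.append((f"input_blocks.{3 * i + j + 1}.0.", f"down_blocks.{i}.resnets.{j}."))
    return pairs


def rename_key(hf_key, prefixes):
    """Translate one Diffusers key: sub-layer names first, then the block prefix."""
    if "resnets" in hf_key:
        for ms_part, hf_part in RESNET_PARTS:
            hf_key = hf_key.replace(hf_part, ms_part)
    for ms_prefix, hf_prefix in prefixes:
        hf_key = hf_key.replace(hf_prefix, ms_prefix)
    return hf_key


prefixes = down_resnet_prefixes()
renamed = rename_key("down_blocks.1.resnets.0.norm1.weight", prefixes)
print(renamed)
```

The `3 * i + j + 1` index reflects the layout assumed by the script: each down block contributes two resnet slots plus one downsampler slot, after the initial `input_blocks.0` convolution.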
217
+
218
+ # TODO: VAE conversion. We don't train it in most cases, but it may be handy in the future --kabachuha
219
+
220
+ # =========================#
221
+ # Text Encoder Conversion #
222
+ # =========================#
223
+
224
+ # IT IS THE SAME CLIP ENCODER, SO JUST COPYPASTING IT --kabachuha
225
+
229
+
230
+
231
+ textenc_conversion_lst = [
232
+ # (stable-diffusion, HF Diffusers)
233
+ ("resblocks.", "text_model.encoder.layers."),
234
+ ("ln_1", "layer_norm1"),
235
+ ("ln_2", "layer_norm2"),
236
+ (".c_fc.", ".fc1."),
237
+ (".c_proj.", ".fc2."),
238
+ (".attn", ".self_attn"),
239
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
240
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
241
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
242
+ ]
243
+ protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
244
+ textenc_pattern = re.compile("|".join(protected.keys()))
245
+
246
+ # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
247
+ code2idx = {"q": 0, "k": 1, "v": 2}
248
+
249
+
250
+ def convert_text_enc_state_dict_v20(text_enc_dict):
251
+ #print ('Converting the text encoder')
252
+ new_state_dict = {}
253
+ capture_qkv_weight = {}
254
+ capture_qkv_bias = {}
255
+ for k, v in text_enc_dict.items():
256
+ if (
257
+ k.endswith(".self_attn.q_proj.weight")
258
+ or k.endswith(".self_attn.k_proj.weight")
259
+ or k.endswith(".self_attn.v_proj.weight")
260
+ ):
261
+ k_pre = k[: -len(".q_proj.weight")]
262
+ k_code = k[-len("q_proj.weight")]
263
+ if k_pre not in capture_qkv_weight:
264
+ capture_qkv_weight[k_pre] = [None, None, None]
265
+ capture_qkv_weight[k_pre][code2idx[k_code]] = v
266
+ continue
267
+
268
+ if (
269
+ k.endswith(".self_attn.q_proj.bias")
270
+ or k.endswith(".self_attn.k_proj.bias")
271
+ or k.endswith(".self_attn.v_proj.bias")
272
+ ):
273
+ k_pre = k[: -len(".q_proj.bias")]
274
+ k_code = k[-len("q_proj.bias")]
275
+ if k_pre not in capture_qkv_bias:
276
+ capture_qkv_bias[k_pre] = [None, None, None]
277
+ capture_qkv_bias[k_pre][code2idx[k_code]] = v
278
+ continue
279
+
280
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
281
+ new_state_dict[relabelled_key] = v
282
+
283
+ for k_pre, tensors in capture_qkv_weight.items():
284
+ if None in tensors:
285
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
286
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
287
+ new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
288
+
289
+ for k_pre, tensors in capture_qkv_bias.items():
290
+ if None in tensors:
291
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
292
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
293
+ new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
294
+
295
+ return new_state_dict
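The function above collects the separate q, k and v projections of each attention layer and concatenates them, in that order, into a single `in_proj` entry. A minimal sketch of the grouping step follows, with nested lists standing in for tensors and list concatenation standing in for `torch.cat`. The name `fuse_qkv` is hypothetical, and the regex-based key relabelling of the real function is deliberately left out.

```python
# Sketch only: lists stand in for tensors; key relabelling is omitted.
CODE2IDX = {"q": 0, "k": 1, "v": 2}


def fuse_qkv(state, suffix="weight"):
    """Group q/k/v projection entries per layer and concatenate them in q, k, v order."""
    tail = f"q_proj.{suffix}"  # same length as the k_proj/v_proj tails
    endings = tuple(f".self_attn.{c}_proj.{suffix}" for c in "qkv")
    captured, passthrough = {}, {}
    for key, value in state.items():
        if key.endswith(endings):
            prefix = key[: -len("." + tail)]  # e.g. "layers.0.self_attn"
            code = key[-len(tail)]            # the single letter "q", "k" or "v"
            captured.setdefault(prefix, [None, None, None])[CODE2IDX[code]] = value
        else:
            passthrough[key] = value
    fused = {}
    for prefix, parts in captured.items():
        if any(part is None for part in parts):
            raise ValueError(f"missing q, k or v entry for {prefix}")
        fused[f"{prefix}.in_proj_{suffix}"] = [row for part in parts for row in part]
    return fused, passthrough


state = {
    "layers.0.self_attn.v_proj.weight": [[3.0]],
    "layers.0.self_attn.q_proj.weight": [[1.0]],
    "layers.0.self_attn.k_proj.weight": [[2.0]],
    "layers.0.layer_norm1.weight": [1.0],
}
fused, passthrough = fuse_qkv(state)
print(fused)
```

The input order of the three entries does not matter; the `CODE2IDX` lookup always places them as q, k, v before concatenation.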
296
+
297
+
298
+ def convert_text_enc_state_dict(text_enc_dict):
299
+ return text_enc_dict
300
+
369
+
370
+ if __name__ == "__main__":
371
+ parser = argparse.ArgumentParser()
372
+
373
+ parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
374
+ parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
375
+ parser.add_argument("--clip_checkpoint_path", default=None, type=str, help="Path to the output CLIP model.")
376
+ parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
377
+ parser.add_argument(
378
+ "--use_safetensors", action="store_true", help="Save weights using safetensors; the default is ckpt."
379
+ )
380
+
381
+ args = parser.parse_args()
382
+
383
+ assert args.model_path is not None, "Must provide a model path!"
384
+
385
+ assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
386
+
387
+ assert args.clip_checkpoint_path is not None, "Must provide a CLIP checkpoint path!"
388
+
389
+ # Path for safetensors
390
+ unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
391
+ #vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
392
+ text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")
393
+
394
+ # Load models from safetensors if the file exists; otherwise fall back to the PyTorch .bin file
395
+ if osp.exists(unet_path):
396
+ unet_state_dict = load_file(unet_path, device="cpu")
397
+ else:
398
+ unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
399
+ unet_state_dict = torch.load(unet_path, map_location="cpu")
400
+
401
+ # if osp.exists(vae_path):
402
+ # vae_state_dict = load_file(vae_path, device="cpu")
403
+ # else:
404
+ # vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
405
+ # vae_state_dict = torch.load(vae_path, map_location="cpu")
406
+
407
+ if osp.exists(text_enc_path):
408
+ text_enc_dict = load_file(text_enc_path, device="cpu")
409
+ else:
410
+ text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
411
+ text_enc_dict = torch.load(text_enc_path, map_location="cpu")
412
+
413
+ # Convert the UNet model
414
+ unet_state_dict = convert_unet_state_dict(unet_state_dict)
415
+ #unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
416
+
417
+ # Convert the VAE model
418
+ # vae_state_dict = convert_vae_state_dict(vae_state_dict)
419
+ # vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
420
+
421
+ # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
422
+ is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
423
+
424
+ if is_v20_model:
425
+
426
+ # MODELSCOPE always uses the 2.X encoder, btw --kabachuha
427
+
428
+ # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
429
+ text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
430
+ text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
431
+ #text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
432
+ else:
433
+ text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
434
+ #text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
435
+
436
+ # DON'T MERGE THEM INTO ONE CHECKPOINT, AS MODELSCOPE USES THEM IN SPLIT FORM --kabachuha
437
+ # Save CLIP and the Diffusion model to their own files
438
+
439
+ #state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
440
+ print ('Saving UNET')
441
+ state_dict = {**unet_state_dict}
442
+
443
+ if args.half:
444
+ state_dict = {k: v.half() for k, v in state_dict.items()}
445
+
446
+ if args.use_safetensors:
447
+ save_file(state_dict, args.checkpoint_path)
448
+ else:
449
+ #state_dict = {"state_dict": state_dict}
450
+ torch.save(state_dict, args.checkpoint_path)
451
+
452
+ # TODO: CLIP conversion doesn't work atm
453
+ # print ('Saving CLIP')
454
+ # state_dict = {**text_enc_dict}
455
+
456
+ # if args.half:
457
+ # state_dict = {k: v.half() for k, v in state_dict.items()}
458
+
459
+ # if args.use_safetensors:
460
+ # save_file(state_dict, args.clip_checkpoint_path)
461
+ # else:
462
+ # #state_dict = {"state_dict": state_dict}
463
+ # torch.save(state_dict, args.clip_checkpoint_path)
464
+
465
+ print('Operation successful')
utils/dataset.py ADDED
@@ -0,0 +1,578 @@
1
+ import os
2
+ import decord
3
+ import numpy as np
4
+ import random
5
+ import json
6
+ import torchvision
7
+ import torchvision.transforms as T
8
+ import torch
9
+
10
+ from glob import glob
11
+ from PIL import Image
12
+ from itertools import islice
13
+ from pathlib import Path
14
+ from .bucketing import sensible_buckets
15
+
16
+ decord.bridge.set_bridge('torch')
17
+
18
+ from torch.utils.data import Dataset
19
+ from einops import rearrange, repeat
20
+
21
+
22
+ def get_prompt_ids(prompt, tokenizer):
23
+ prompt_ids = tokenizer(
24
+ prompt,
25
+ truncation=True,
26
+ padding="max_length",
27
+ max_length=tokenizer.model_max_length,
28
+ return_tensors="pt",
29
+ ).input_ids
30
+
31
+ return prompt_ids
32
+
33
+
34
+ def read_caption_file(caption_file):
35
+ with open(caption_file, 'r', encoding="utf8") as t:
36
+ return t.read()
37
+
38
+
39
+ def get_text_prompt(
40
+ text_prompt: str = '',
41
+ fallback_prompt: str = '',
42
+ file_path: str = '',
43
+ ext_types=('.mp4',),
44
+ use_caption=False
45
+ ):
46
+ try:
47
+ if use_caption:
48
+ if len(text_prompt) > 1: return text_prompt
49
+ caption_file = ''
50
+ # Use caption on per-video basis (One caption PER video)
51
+ for ext in ext_types:
52
+ maybe_file = file_path.replace(ext, '.txt')
53
+ if maybe_file.endswith(tuple(ext_types)): continue
54
+ if os.path.exists(maybe_file):
55
+ caption_file = maybe_file
56
+ break
57
+
58
+ if os.path.exists(caption_file):
59
+ return read_caption_file(caption_file)
60
+
61
+ # Return fallback prompt if no conditions are met.
62
+ return fallback_prompt
63
+
64
+ return text_prompt
65
+ except Exception:
66
+ print(f"Couldn't read prompt caption for {file_path}. Using fallback.")
67
+ return fallback_prompt
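The per-video caption lookup above amounts to: swap the video extension for `.txt`, read that file if it exists, and otherwise use the fallback prompt. The sketch below shows one way to express this, using `os.path.splitext` instead of the `str.replace` call in the original, which is a deliberate substitution. `find_caption_file` and `prompt_for` are hypothetical names used only here.

```python
import os
import tempfile


def find_caption_file(video_path, video_exts=(".mp4",)):
    """Return the sibling .txt path for a video if it exists, else None."""
    root, ext = os.path.splitext(video_path)
    if ext.lower() not in video_exts:
        return None
    candidate = root + ".txt"
    return candidate if os.path.exists(candidate) else None


def prompt_for(video_path, fallback_prompt="", video_exts=(".mp4",)):
    """Read the caption next to the video, or fall back to the given prompt."""
    caption_file = find_caption_file(video_path, video_exts)
    if caption_file is None:
        return fallback_prompt
    with open(caption_file, "r", encoding="utf8") as handle:
        return handle.read()


with tempfile.TemporaryDirectory() as tmp:
    video = os.path.join(tmp, "clip.mp4")  # the video itself need not exist for the lookup
    with open(os.path.join(tmp, "clip.txt"), "w", encoding="utf8") as handle:
        handle.write("a cat playing piano")
    with_caption = prompt_for(video, fallback_prompt="a video")
    without_caption = prompt_for(os.path.join(tmp, "other.mp4"), fallback_prompt="a video")

print(with_caption)
print(without_caption)
```

Splitting on the extension avoids accidentally rewriting a `.mp4` substring that appears earlier in the path, which a plain string replace could do.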
68
+
69
+
70
+ def get_video_frames(vr, start_idx, sample_rate=1, max_frames=24):
71
+ max_range = len(vr)
72
+ frame_number = sorted((0, start_idx, max_range))[1]
73
+
74
+ frame_range = range(frame_number, max_range, sample_rate)
75
+ frame_range_indices = list(frame_range)[:max_frames]
76
+
77
+ return frame_range_indices
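`get_video_frames` clamps the start index into the valid range with the `sorted((low, value, high))[1]` idiom, then steps through the video and truncates to `max_frames`. The sketch below restates that logic with an integer frame count in place of a `decord.VideoReader`; `clamp` and `frame_indices` are hypothetical names for this illustration.

```python
def clamp(value, low, high):
    """Middle element of the sorted triple, i.e. value limited to [low, high]."""
    return sorted((low, value, high))[1]


def frame_indices(num_frames_in_video, start_idx, sample_rate=1, max_frames=24):
    """Indices to sample: clamp the start, step by sample_rate, keep at most max_frames."""
    first = clamp(start_idx, 0, num_frames_in_video)
    return list(range(first, num_frames_in_video, sample_rate))[:max_frames]


normal = frame_indices(10, start_idx=3, sample_rate=2, max_frames=3)
negative_start = frame_indices(10, start_idx=-5, sample_rate=2, max_frames=3)
past_end = frame_indices(10, start_idx=50)
print(normal, negative_start, past_end)
```

A start index past the end of the video yields an empty list, so callers may want to check for that case before decoding frames.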
78
+
79
+
80
+ def process_video(vid_path, use_bucketing, w, h, get_frame_buckets, get_frame_batch):
81
+ if use_bucketing:
82
+ vr = decord.VideoReader(vid_path)
83
+ resize = get_frame_buckets(vr)
84
+ video = get_frame_batch(vr, resize=resize)
85
+
86
+ else:
87
+ vr = decord.VideoReader(vid_path, width=w, height=h)
88
+ video = get_frame_batch(vr)
89
+
90
+ return video, vr
91
+
92
+
93
+ # https://github.com/ExponentialML/Video-BLIP2-Preprocessor
94
+ class VideoJsonDataset(Dataset):
95
+ def __init__(
96
+ self,
97
+ tokenizer = None,
98
+ width: int = 256,
99
+ height: int = 256,
100
+ n_sample_frames: int = 4,
101
+ sample_start_idx: int = 1,
102
+ frame_step: int = 1,
103
+ json_path: str ="",
104
+ json_data = None,
105
+ vid_data_key: str = "video_path",
106
+ preprocessed: bool = False,
107
+ use_bucketing: bool = False,
108
+ **kwargs
109
+ ):
110
+ self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg")
111
+ self.use_bucketing = use_bucketing
112
+ self.tokenizer = tokenizer
113
+ self.preprocessed = preprocessed
114
+
115
+ self.vid_data_key = vid_data_key
116
+ self.train_data = self.load_from_json(json_path, json_data)
117
+
118
+ self.width = width
119
+ self.height = height
120
+
121
+ self.n_sample_frames = n_sample_frames
122
+ self.sample_start_idx = sample_start_idx
123
+ self.frame_step = frame_step
124
+
125
+ def build_json(self, json_data):
126
+ extended_data = []
127
+ for data in json_data['data']:
128
+ for nested_data in data['data']:
129
+ self.build_json_dict(
130
+ data,
131
+ nested_data,
132
+ extended_data
133
+ )
134
+ json_data = extended_data
135
+ return json_data
136
+
137
+ def build_json_dict(self, data, nested_data, extended_data):
138
+ clip_path = nested_data['clip_path'] if 'clip_path' in nested_data else None
139
+
140
+ extended_data.append({
141
+ self.vid_data_key: data[self.vid_data_key],
142
+ 'frame_index': nested_data['frame_index'],
143
+ 'prompt': nested_data['prompt'],
144
+ 'clip_path': clip_path
145
+ })
146
+
147
+ def load_from_json(self, path, json_data):
148
+ try:
149
+ with open(path) as jpath:
150
+ print(f"Loading JSON from {path}")
151
+ json_data = json.load(jpath)
152
+
153
+ return self.build_json(json_data)
154
+
155
+ except:
156
+ self.train_data = []
157
+ print("Non-existent JSON path. Skipping.")
158
+
159
+ def validate_json(self, base_path, path):
160
+ return os.path.exists(f"{base_path}/{path}")
161
+
162
+ def get_frame_range(self, vr):
163
+ return get_video_frames(
164
+ vr,
165
+ self.sample_start_idx,
166
+ self.frame_step,
167
+ self.n_sample_frames
168
+ )
169
+
170
+ def get_vid_idx(self, vr, vid_data=None):
171
+ frames = self.n_sample_frames
172
+
173
+ if vid_data is not None:
174
+ idx = vid_data['frame_index']
175
+ else:
176
+ idx = self.sample_start_idx
177
+
178
+ return idx
179
+
180
+ def get_frame_buckets(self, vr):
181
+ _, h, w = vr[0].shape
182
+ width, height = sensible_buckets(self.width, self.height, h, w)
183
+ # width, height = self.width, self.height
184
+ resize = T.transforms.Resize((height, width), antialias=True)
185
+
186
+ return resize
187
+
188
+ def get_frame_batch(self, vr, resize=None):
189
+ frame_range = self.get_frame_range(vr)
190
+ frames = vr.get_batch(frame_range)
191
+ video = rearrange(frames, "f h w c -> f c h w")
192
+
193
+ if resize is not None: video = resize(video)
194
+ return video
195
+
196
+ def process_video_wrapper(self, vid_path):
197
+ video, vr = process_video(
198
+ vid_path,
199
+ self.use_bucketing,
200
+ self.width,
201
+ self.height,
202
+ self.get_frame_buckets,
203
+ self.get_frame_batch
204
+ )
205
+
206
+ return video, vr
207
+
208
+ def train_data_batch(self, index):
209
+
210
+ # If we are training on individual clips.
211
+ if 'clip_path' in self.train_data[index] and \
212
+ self.train_data[index]['clip_path'] is not None:
213
+
214
+ vid_data = self.train_data[index]
215
+
216
+ clip_path = vid_data['clip_path']
217
+
218
+ # Get video prompt
219
+ prompt = vid_data['prompt']
220
+
221
+ video, _ = self.process_video_wrapper(clip_path)
222
+
223
+ prompt_ids = get_prompt_ids(prompt, self.tokenizer)
224
+
225
+ return video, prompt, prompt_ids
226
+
227
+ # Assign train data
228
+ train_data = self.train_data[index]
229
+
230
+ # Get the frame of the current index.
231
+ self.sample_start_idx = train_data['frame_index']
232
+
233
+ # Initialize resize
234
+ resize = None
235
+
236
+ video, vr = self.process_video_wrapper(train_data[self.vid_data_key])
237
+
238
+ # Get video prompt
239
+ prompt = train_data['prompt']
240
+ vr.seek(0)
241
+
242
+ prompt_ids = get_prompt_ids(prompt, self.tokenizer)
243
+
244
+ return video, prompt, prompt_ids
245
+
246
+ @staticmethod
247
+ def __getname__(): return 'json'
248
+
249
+ def __len__(self):
250
+ if self.train_data is not None:
251
+ return len(self.train_data)
252
+ else:
253
+ return 0
254
+
255
+ def __getitem__(self, index):
256
+
257
+ # Initialize variables
258
+ video = None
259
+ prompt = None
260
+ prompt_ids = None
261
+
262
+ # Use default JSON training
263
+ if self.train_data is not None:
264
+ video, prompt, prompt_ids = self.train_data_batch(index)
265
+
266
+ example = {
267
+ "pixel_values": (video / 127.5 - 1.0),
268
+ "prompt_ids": prompt_ids[0],
269
+ "text_prompt": prompt,
270
+ 'dataset': self.__getname__()
271
+ }
272
+
273
+ return example
274
+
275
+
276
+ class SingleVideoDataset(Dataset):
277
+ def __init__(
278
+ self,
279
+ tokenizer = None,
280
+ width: int = 256,
281
+ height: int = 256,
282
+ n_sample_frames: int = 4,
283
+ frame_step: int = 1,
284
+ single_video_path: str = "",
285
+ single_video_prompt: str = "",
286
+ use_caption: bool = False,
287
+ use_bucketing: bool = False,
288
+ **kwargs
289
+ ):
290
+ self.tokenizer = tokenizer
291
+ self.use_bucketing = use_bucketing
292
+ self.frames = []
293
+ self.index = 1
294
+
295
+ self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg")
296
+ self.n_sample_frames = n_sample_frames
297
+ self.frame_step = frame_step
298
+
299
+ self.single_video_path = single_video_path
300
+ self.single_video_prompt = single_video_prompt
301
+
302
+ self.width = width
303
+ self.height = height
304
+ def create_video_chunks(self):
305
+ vr = decord.VideoReader(self.single_video_path)
306
+ vr_range = range(0, len(vr), self.frame_step)
307
+
308
+ self.frames = list(self.chunk(vr_range, self.n_sample_frames))
309
+ return self.frames
310
+
311
+ def chunk(self, it, size):
312
+ it = iter(it)
313
+ return iter(lambda: tuple(islice(it, size)), ())
314
+
315
+ def get_frame_batch(self, vr, resize=None):
316
+ index = self.index
317
+ frames = vr.get_batch(self.frames[self.index])
318
+ video = rearrange(frames, "f h w c -> f c h w")
319
+
320
+ if resize is not None: video = resize(video)
321
+ return video
322
+
323
+ def get_frame_buckets(self, vr):
324
+ _, h, w = vr[0].shape
325
+ # width, height = sensible_buckets(self.width, self.height, h, w)
326
+ width, height = self.width, self.height
327
+ resize = T.transforms.Resize((height, width), antialias=True)
328
+
329
+ return resize
330
+
331
+ def process_video_wrapper(self, vid_path):
332
+ video, vr = process_video(
333
+ vid_path,
334
+ self.use_bucketing,
335
+ self.width,
336
+ self.height,
337
+ self.get_frame_buckets,
338
+ self.get_frame_batch
339
+ )
340
+
341
+ return video, vr
342
+
343
+ def single_video_batch(self, index):
344
+ train_data = self.single_video_path
345
+ self.index = index
346
+
347
+ if train_data.endswith(self.vid_types):
348
+ video, _ = self.process_video_wrapper(train_data)
349
+
350
+ prompt = self.single_video_prompt
351
+ prompt_ids = get_prompt_ids(prompt, self.tokenizer)
352
+
353
+ return video, prompt, prompt_ids
354
+ else:
355
+ raise ValueError(f"Single video is not a video type. Types: {self.vid_types}")
356
+
357
+ @staticmethod
358
+ def __getname__(): return 'single_video'
359
+
360
+ def __len__(self):
361
+
362
+ return len(self.create_video_chunks())
363
+
364
+ def __getitem__(self, index):
365
+
366
+ video, prompt, prompt_ids = self.single_video_batch(index)
367
+
368
+ example = {
369
+ "pixel_values": (video / 127.5 - 1.0),
370
+ "prompt_ids": prompt_ids[0],
371
+ "text_prompt": prompt,
372
+ 'dataset': self.__getname__()
373
+ }
374
+
375
+ return example
376
+
377
+
378
+ class ImageDataset(Dataset):
379
+
380
+ def __init__(
381
+ self,
382
+ tokenizer = None,
383
+ width: int = 256,
384
+ height: int = 256,
385
+ base_width: int = 256,
386
+ base_height: int = 256,
387
+ use_caption: bool = False,
388
+ image_dir: str = '',
389
+ single_img_prompt: str = '',
390
+ use_bucketing: bool = False,
391
+ fallback_prompt: str = '',
392
+ **kwargs
393
+ ):
394
+ self.tokenizer = tokenizer
395
+ self.img_types = (".png", ".jpg", ".jpeg", '.bmp')
396
+ self.use_bucketing = use_bucketing
397
+
398
+ self.image_dir = self.get_images_list(image_dir)
399
+ self.fallback_prompt = fallback_prompt
400
+
401
+ self.use_caption = use_caption
402
+ self.single_img_prompt = single_img_prompt
403
+
404
+ self.width = width
405
+ self.height = height
406
+
407
+ def get_images_list(self, image_dir):
408
+ if os.path.exists(image_dir):
409
+ imgs = [x for x in os.listdir(image_dir) if x.endswith(self.img_types)]
410
+ full_img_dir = []
411
+
412
+ for img in imgs:
413
+ full_img_dir.append(f"{image_dir}/{img}")
414
+
415
+ return sorted(full_img_dir)
416
+
417
+ return ['']
418
+
419
+ def image_batch(self, index):
420
+ train_data = self.image_dir[index]
421
+ img = train_data
422
+
423
+ try:
424
+ img = torchvision.io.read_image(img, mode=torchvision.io.ImageReadMode.RGB)
425
+ except:
426
+ img = T.transforms.PILToTensor()(Image.open(img).convert("RGB"))
427
+
428
+ width = self.width
429
+ height = self.height
430
+
431
+ if self.use_bucketing:
432
+ _, h, w = img.shape
433
+ width, height = sensible_buckets(width, height, w, h)
434
+
435
+ resize = T.transforms.Resize((height, width), antialias=True)
436
+
437
+ img = resize(img)
438
+ img = repeat(img, 'c h w -> f c h w', f=1)
439
+
440
+ prompt = get_text_prompt(
441
+ file_path=train_data,
442
+ text_prompt=self.single_img_prompt,
443
+ fallback_prompt=self.fallback_prompt,
444
+ ext_types=self.img_types,
445
+ use_caption=True
446
+ )
447
+ prompt_ids = get_prompt_ids(prompt, self.tokenizer)
448
+
449
+ return img, prompt, prompt_ids
450
+
451
+ @staticmethod
452
+ def __getname__(): return 'image'
453
+
454
+ def __len__(self):
455
+ # Image directory
456
+ if os.path.exists(self.image_dir[0]):
457
+ return len(self.image_dir)
458
+ else:
459
+ return 0
460
+
461
+ def __getitem__(self, index):
462
+ img, prompt, prompt_ids = self.image_batch(index)
463
+ example = {
464
+ "pixel_values": (img / 127.5 - 1.0),
465
+ "prompt_ids": prompt_ids[0],
466
+ "text_prompt": prompt,
467
+ 'dataset': self.__getname__()
468
+ }
469
+
470
+ return example
471
+
472
+
473
+ class VideoFolderDataset(Dataset):
474
+ def __init__(
475
+ self,
476
+ tokenizer=None,
477
+ width: int = 256,
478
+ height: int = 256,
479
+ n_sample_frames: int = 16,
480
+ fps: int = 8,
481
+ path: str = "./data",
482
+ fallback_prompt: str = "",
483
+ use_bucketing: bool = False,
484
+ **kwargs
485
+ ):
486
+ self.tokenizer = tokenizer
487
+ self.use_bucketing = use_bucketing
488
+
489
+ self.fallback_prompt = fallback_prompt
490
+
491
+ self.video_files = glob(f"{path}/*.mp4")
492
+
493
+ self.width = width
494
+ self.height = height
495
+
496
+ self.n_sample_frames = n_sample_frames
497
+ self.fps = fps
498
+
499
+ def get_frame_buckets(self, vr):
500
+ _, h, w = vr[0].shape
501
+ width, height = sensible_buckets(self.width, self.height, h, w)
502
+ # width, height = self.width, self.height
503
+ resize = T.transforms.Resize((height, width), antialias=True)
504
+
505
+ return resize
506
+
507
+ def get_frame_batch(self, vr, resize=None):
508
+ n_sample_frames = self.n_sample_frames
509
+ native_fps = vr.get_avg_fps()
510
+
511
+ every_nth_frame = max(1, round(native_fps / self.fps))
512
+ every_nth_frame = min(len(vr), every_nth_frame)
513
+
514
+ effective_length = len(vr) // every_nth_frame
515
+ if effective_length < n_sample_frames:
516
+ n_sample_frames = effective_length
517
+
518
+ effective_idx = random.randint(0, (effective_length - n_sample_frames))
519
+ idxs = every_nth_frame * np.arange(effective_idx, effective_idx + n_sample_frames)
520
+
521
+ video = vr.get_batch(idxs)
522
+ video = rearrange(video, "f h w c -> f c h w")
523
+
524
+ if resize is not None: video = resize(video)
525
+ return video, vr
526
+
527
+ def process_video_wrapper(self, vid_path):
528
+ video, vr = process_video(
529
+ vid_path,
530
+ self.use_bucketing,
531
+ self.width,
532
+ self.height,
533
+ self.get_frame_buckets,
534
+ self.get_frame_batch
535
+ )
536
+ return video, vr
537
+
538
+ def get_prompt_ids(self, prompt):
539
+ return self.tokenizer(
540
+ prompt,
541
+ truncation=True,
542
+ padding="max_length",
543
+ max_length=self.tokenizer.model_max_length,
544
+ return_tensors="pt",
545
+ ).input_ids
546
+
547
+ @staticmethod
548
+ def __getname__(): return 'folder'
549
+
550
+ def __len__(self):
551
+ return len(self.video_files)
552
+
553
+ def __getitem__(self, index):
554
+
555
+ video, _ = self.process_video_wrapper(self.video_files[index])
556
+
557
+ prompt = self.fallback_prompt
558
+
559
+ prompt_ids = self.get_prompt_ids(prompt)
560
+
561
+ return {"pixel_values": (video[0] / 127.5 - 1.0), "prompt_ids": prompt_ids[0], "text_prompt": prompt, 'dataset': self.__getname__()}
562
+
563
+
564
+ class CachedDataset(Dataset):
565
+ def __init__(self, cache_dir: str = ''):
566
+ self.cache_dir = cache_dir
567
+ self.cached_data_list = self.get_files_list()
568
+
569
+ def get_files_list(self):
570
+ tensors_list = [f"{self.cache_dir}/{x}" for x in os.listdir(self.cache_dir) if x.endswith('.pt')]
571
+ return sorted(tensors_list)
572
+
573
+ def __len__(self):
574
+ return len(self.cached_data_list)
575
+
576
+ def __getitem__(self, index):
577
+ cached_latent = torch.load(self.cached_data_list[index], map_location='cuda:0')
578
+ return cached_latent
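The frame sampling in `VideoFolderDataset.get_frame_batch` and the pixel scaling used by every `__getitem__` above can be sketched without decord or torch. This is a minimal illustration, not the file's code: `sample_frame_indices` and `normalize_pixels` are hypothetical names, and the video is reduced to a frame count and an average FPS.

```python
import random

import numpy as np


def sample_frame_indices(num_frames, native_fps, target_fps=8, n_sample_frames=16):
    """Pick evenly strided frame indices from a random window of the video."""
    # Stride that approximates target_fps, clamped to the range [1, num_frames].
    every_nth_frame = max(1, round(native_fps / target_fps))
    every_nth_frame = min(num_frames, every_nth_frame)

    # Number of frames left once the stride is applied.
    effective_length = num_frames // every_nth_frame
    n = min(n_sample_frames, effective_length)

    # Random start so repeated calls can see different windows.
    start = random.randint(0, effective_length - n)
    return every_nth_frame * np.arange(start, start + n)


def normalize_pixels(frames):
    """Map uint8 pixel values in [0, 255] to floats in [-1, 1]."""
    return frames.astype(np.float32) / 127.5 - 1.0
```

With a 240-frame clip at 30 FPS and a target of 8 FPS, the stride rounds to 4, so 16 indices spaced 4 frames apart are drawn from a random window. A clip that is too short simply yields fewer indices.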
utils/ddim_utils.py ADDED
@@ -0,0 +1,65 @@
1
+ import numpy as np
2
+ from typing import Union
3
+
4
+ import torch
5
+
6
+ from tqdm import tqdm
7
+
8
+
9
+ # DDIM Inversion
10
+ @torch.no_grad()
11
+ def init_prompt(prompt, pipeline):
12
+ uncond_input = pipeline.tokenizer(
13
+ [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
14
+ return_tensors="pt"
15
+ )
16
+ uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
17
+ text_input = pipeline.tokenizer(
18
+ [prompt],
19
+ padding="max_length",
20
+ max_length=pipeline.tokenizer.model_max_length,
21
+ truncation=True,
22
+ return_tensors="pt",
23
+ )
24
+ text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
25
+ context = torch.cat([uncond_embeddings, text_embeddings])
26
+
27
+ return context
28
+
29
+
30
+ def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
31
+ sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
32
+ timestep, next_timestep = min(
33
+ timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
34
+ alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
35
+ alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
36
+ beta_prod_t = 1 - alpha_prod_t
37
+ next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
38
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
39
+ next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
40
+ return next_sample
41
+
42
+
43
+ def get_noise_pred_single(latents, t, context, unet):
44
+ noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
45
+ return noise_pred
46
+
47
+
48
+ @torch.no_grad()
49
+ def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
50
+ context = init_prompt(prompt, pipeline)
51
+ uncond_embeddings, cond_embeddings = context.chunk(2)
52
+ all_latent = [latent]
53
+ latent = latent.clone().detach()
54
+ for i in tqdm(range(num_inv_steps)):
55
+ t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
56
+ noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
57
+ latent = next_step(noise_pred, t, latent, ddim_scheduler)
58
+ all_latent.append(latent)
59
+ return all_latent
60
+
61
+
62
+ @torch.no_grad()
63
+ def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
64
+ ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
65
+ return ddim_latents
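The update in `next_step` above can be checked in isolation. The sketch below restates it in NumPy under simplifying assumptions: `ddim_inversion_step` is a hypothetical name, and the two cumulative alpha products are passed in directly instead of being looked up on a scheduler.

```python
import numpy as np


def ddim_inversion_step(sample, noise_pred, alpha_prod_t, alpha_prod_t_next):
    """Move a latent one step toward higher noise, reusing the predicted noise."""
    # Estimate the clean sample implied by the current latent and noise prediction.
    pred_x0 = (sample - (1.0 - alpha_prod_t) ** 0.5 * noise_pred) / alpha_prod_t ** 0.5
    # Re-noise that estimate to the next level along the same noise direction.
    direction = (1.0 - alpha_prod_t_next) ** 0.5 * noise_pred
    return alpha_prod_t_next ** 0.5 * pred_x0 + direction
```

If the noise prediction were exact, the step would reproduce the forward-process latent at the next noise level. That makes a convenient sanity check: build `x_t` from a known `x0` and `eps`, apply the step, and compare against `sqrt(a_next) * x0 + sqrt(1 - a_next) * eps`.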
utils/lora.py ADDED
@@ -0,0 +1,1483 @@
1
+ import json
2
+ import math
3
+ from itertools import groupby
4
+ import os
5
+ from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
6
+
7
+ import numpy as np
8
+ import PIL
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ try:
14
+ from safetensors.torch import safe_open
15
+ from safetensors.torch import save_file as safe_save
16
+
17
+ safetensors_available = True
18
+ except ImportError:
19
+ from .safe_open import safe_open
20
+
21
+ def safe_save(
22
+ tensors: Dict[str, torch.Tensor],
23
+ filename: str,
24
+ metadata: Optional[Dict[str, str]] = None,
25
+ ) -> None:
26
+ raise EnvironmentError(
27
+ "Saving safetensors requires the safetensors library. Please install with pip or similar."
28
+ )
29
+
30
+ safetensors_available = False
31
+
32
+
33
+ class LoraInjectedLinear(nn.Module):
34
+ def __init__(
35
+ self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
36
+ ):
37
+ super().__init__()
38
+
39
+ if r > min(in_features, out_features):
40
+ #raise ValueError(
41
+ # f"LoRA rank {r} must be less than or equal to {min(in_features, out_features)}"
42
+ #)
43
+ print(f"LoRA rank {r} is too large. setting to: {min(in_features, out_features)}")
44
+ r = min(in_features, out_features)
45
+
46
+ self.r = r
47
+ self.linear = nn.Linear(in_features, out_features, bias)
48
+ self.lora_down = nn.Linear(in_features, r, bias=False)
49
+ self.dropout = nn.Dropout(dropout_p)
50
+ self.lora_up = nn.Linear(r, out_features, bias=False)
51
+ self.scale = scale
52
+ self.selector = nn.Identity()
53
+
54
+ nn.init.normal_(self.lora_down.weight, std=1 / r)
55
+ nn.init.zeros_(self.lora_up.weight)
56
+
57
+ def forward(self, input):
58
+ return (
59
+ self.linear(input)
60
+ + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
61
+ * self.scale
62
+ )
63
+
64
+ def realize_as_lora(self):
65
+ return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
66
+
67
+ def set_selector_from_diag(self, diag: torch.Tensor):
68
+ # diag is a 1D tensor of size (r,)
69
+ assert diag.shape == (self.r,)
70
+ self.selector = nn.Linear(self.r, self.r, bias=False)
71
+ self.selector.weight.data = torch.diag(diag)
72
+ self.selector.weight.data = self.selector.weight.data.to(
73
+ self.lora_up.weight.device
74
+ ).to(self.lora_up.weight.dtype)
75
+
76
+
77
+ class MultiLoraInjectedLinear(nn.Module):
78
+ def __init__(
79
+ self, in_features, out_features, bias=False, r=4, dropout_p=0.1, lora_num=1, scales=[1.0]
80
+ ):
81
+ super().__init__()
82
+
83
+ if r > min(in_features, out_features):
84
+ #raise ValueError(
85
+ # f"LoRA rank {r} must be less than or equal to {min(in_features, out_features)}"
86
+ #)
87
+ print(f"LoRA rank {r} is too large. setting to: {min(in_features, out_features)}")
88
+ r = min(in_features, out_features)
89
+
90
+ self.r = r
91
+ self.linear = nn.Linear(in_features, out_features, bias)
92
+
93
+ for i in range(lora_num):
94
+ if i==0:
95
+ self.lora_down =[nn.Linear(in_features, r, bias=False)]
96
+ self.dropout = [nn.Dropout(dropout_p)]
97
+ self.lora_up = [nn.Linear(r, out_features, bias=False)]
98
+ self.scale = scales[i]
99
+ self.selector = [nn.Identity()]
100
+ else:
101
+ self.lora_down.append(nn.Linear(in_features, r, bias=False))
102
+ self.dropout.append( nn.Dropout(dropout_p))
103
+ self.lora_up.append( nn.Linear(r, out_features, bias=False))
104
+ self.scale.append(scales[i])
105
+
106
+ nn.init.normal_(self.lora_down.weight, std=1 / r)
107
+ nn.init.zeros_(self.lora_up.weight)
108
+
109
+ def forward(self, input):
110
+ return (
111
+ self.linear(input)
112
+ + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
113
+ * self.scale
114
+ )
115
+
116
+ def realize_as_lora(self):
117
+ return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
118
+
119
+ def set_selector_from_diag(self, diag: torch.Tensor):
120
+ # diag is a 1D tensor of size (r,)
121
+ assert diag.shape == (self.r,)
122
+ self.selector = nn.Linear(self.r, self.r, bias=False)
123
+ self.selector.weight.data = torch.diag(diag)
124
+ self.selector.weight.data = self.selector.weight.data.to(
125
+ self.lora_up.weight.device
126
+ ).to(self.lora_up.weight.dtype)
127
+
128
+
129
+ class LoraInjectedConv2d(nn.Module):
130
+ def __init__(
131
+ self,
132
+ in_channels: int,
133
+ out_channels: int,
134
+ kernel_size,
135
+ stride=1,
136
+ padding=0,
137
+ dilation=1,
138
+ groups: int = 1,
139
+ bias: bool = True,
140
+ r: int = 4,
141
+ dropout_p: float = 0.1,
142
+ scale: float = 1.0,
143
+ ):
144
+ super().__init__()
145
+ if r > min(in_channels, out_channels):
146
+ print(f"LoRA rank {r} is too large. setting to: {min(in_channels, out_channels)}")
147
+ r = min(in_channels, out_channels)
148
+
149
+ self.r = r
150
+ self.conv = nn.Conv2d(
151
+ in_channels=in_channels,
152
+ out_channels=out_channels,
153
+ kernel_size=kernel_size,
154
+ stride=stride,
155
+ padding=padding,
156
+ dilation=dilation,
157
+ groups=groups,
158
+ bias=bias,
159
+ )
160
+
161
+ self.lora_down = nn.Conv2d(
162
+ in_channels=in_channels,
163
+ out_channels=r,
164
+ kernel_size=kernel_size,
165
+ stride=stride,
166
+ padding=padding,
167
+ dilation=dilation,
168
+ groups=groups,
169
+ bias=False,
170
+ )
171
+ self.dropout = nn.Dropout(dropout_p)
172
+ self.lora_up = nn.Conv2d(
173
+ in_channels=r,
174
+ out_channels=out_channels,
175
+ kernel_size=1,
176
+ stride=1,
177
+ padding=0,
178
+ bias=False,
179
+ )
180
+ self.selector = nn.Identity()
181
+ self.scale = scale
182
+
183
+ nn.init.normal_(self.lora_down.weight, std=1 / r)
184
+ nn.init.zeros_(self.lora_up.weight)
185
+
186
+ def forward(self, input):
187
+ return (
188
+ self.conv(input)
189
+ + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
190
+ * self.scale
191
+ )
192
+
193
+ def realize_as_lora(self):
194
+ return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
195
+
196
+ def set_selector_from_diag(self, diag: torch.Tensor):
197
+ # diag is a 1D tensor of size (r,)
198
+ assert diag.shape == (self.r,)
199
+ self.selector = nn.Conv2d(
200
+ in_channels=self.r,
201
+ out_channels=self.r,
202
+ kernel_size=1,
203
+ stride=1,
204
+ padding=0,
205
+ bias=False,
206
+ )
207
+ self.selector.weight.data = torch.diag(diag).view(self.r, self.r, 1, 1)
208
+
209
+ # same device + dtype as lora_up
210
+ self.selector.weight.data = self.selector.weight.data.to(
211
+ self.lora_up.weight.device
212
+ ).to(self.lora_up.weight.dtype)
213
+
214
+ class LoraInjectedConv3d(nn.Module):
215
+ def __init__(
216
+ self,
217
+ in_channels: int,
218
+ out_channels: int,
219
+ kernel_size: Tuple[int, int, int] = (3, 1, 1),
220
+ padding: Tuple[int, int, int] = (1, 0, 0),
221
+ bias: bool = False,
222
+ r: int = 4,
223
+ dropout_p: float = 0,
224
+ scale: float = 1.0,
225
+ ):
226
+ super().__init__()
227
+ if r > min(in_channels, out_channels):
228
+ print(f"LoRA rank {r} is too large. setting to: {min(in_channels, out_channels)}")
229
+ r = min(in_channels, out_channels)
230
+
231
+ self.r = r
232
+ self.kernel_size = kernel_size
233
+ self.padding = padding
234
+ self.conv = nn.Conv3d(
235
+ in_channels=in_channels,
236
+ out_channels=out_channels,
237
+ kernel_size=kernel_size,
238
+ padding=padding,
239
+ )
240
+
241
+ self.lora_down = nn.Conv3d(
242
+ in_channels=in_channels,
243
+ out_channels=r,
244
+ kernel_size=kernel_size,
245
+ bias=False,
246
+ padding=padding
247
+ )
248
+ self.dropout = nn.Dropout(dropout_p)
249
+ self.lora_up = nn.Conv3d(
250
+ in_channels=r,
251
+ out_channels=out_channels,
252
+ kernel_size=1,
253
+ stride=1,
254
+ padding=0,
255
+ bias=False,
256
+ )
257
+ self.selector = nn.Identity()
258
+ self.scale = scale
259
+
260
+ nn.init.normal_(self.lora_down.weight, std=1 / r)
261
+ nn.init.zeros_(self.lora_up.weight)
262
+
263
+ def forward(self, input):
264
+ return (
265
+ self.conv(input)
266
+ + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
267
+ * self.scale
268
+ )
269
+
270
+ def realize_as_lora(self):
271
+ return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
272
+
273
+ def set_selector_from_diag(self, diag: torch.Tensor):
274
+ # diag is a 1D tensor of size (r,)
275
+ assert diag.shape == (self.r,)
276
+ self.selector = nn.Conv3d(
277
+ in_channels=self.r,
278
+ out_channels=self.r,
279
+ kernel_size=1,
280
+ stride=1,
281
+ padding=0,
282
+ bias=False,
283
+ )
284
+ self.selector.weight.data = torch.diag(diag).view(self.r, self.r, 1, 1, 1)
285
+
286
+ # same device + dtype as lora_up
287
+ self.selector.weight.data = self.selector.weight.data.to(
288
+ self.lora_up.weight.device
289
+ ).to(self.lora_up.weight.dtype)
290
+
291
+ UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
292
+
293
+ UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"}
294
+
295
+ TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"}
296
+
297
+ TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"}
298
+
299
+ DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE
300
+
301
+ EMBED_FLAG = "<embed>"
302
+
303
+
304
+ def _find_children(
305
+ model,
306
+ search_class: List[Type[nn.Module]] = [nn.Linear],
307
+ ):
308
+ """
309
+ Find all modules of a certain class (or union of classes).
310
+
311
+ Returns all matching modules, along with the parent of those modules and the
312
+ names they are referenced by.
313
+ """
314
+ # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
315
+ for parent in model.modules():
316
+ for name, module in parent.named_children():
317
+ if any([isinstance(module, _class) for _class in search_class]):
318
+ yield parent, name, module
319
+
320
+
321
+ def _find_modules_v2(
322
+ model,
323
+ ancestor_class: Optional[Set[str]] = None,
324
+ search_class: List[Type[nn.Module]] = [nn.Linear],
325
+ exclude_children_of: Optional[List[Type[nn.Module]]] = None,
326
+ # [
327
+ # LoraInjectedLinear,
328
+ # LoraInjectedConv2d,
329
+ # LoraInjectedConv3d
330
+ # ],
331
+ ):
332
+ """
333
+ Find all modules of a certain class (or union of classes) that are direct or
334
+ indirect descendants of other modules of a certain class (or union of classes).
335
+
336
+ Returns all matching modules, along with the parent of those modules and the
337
+ names they are referenced by.
338
+ """
339
+
340
+ # Get the targets we should replace all linears under
341
+ if ancestor_class is not None:
342
+ ancestors = (
343
+ module
344
+ for name, module in model.named_modules()
345
+ if module.__class__.__name__ in ancestor_class # and ('transformer_in' not in name)
346
+ )
347
+ else:
348
+ # This is in case you want to naively iterate over all modules.
349
+ ancestors = [module for module in model.modules()]
350
+
351
+ # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
352
+ for ancestor in ancestors:
353
+ for fullname, module in ancestor.named_modules():
354
+ if any([isinstance(module, _class) for _class in search_class]):
355
+ continue_flag = True
356
+ if 'Transformer2DModel' in ancestor_class and ('attn1' in fullname or 'ff' in fullname):
357
+ continue_flag = False
358
+ if 'TransformerTemporalModel' in ancestor_class and ('attn1' in fullname or 'attn2' in fullname or 'ff' in fullname):
359
+ continue_flag = False
360
+ if continue_flag:
361
+ continue
362
+ # Find the direct parent if this is a descendant, not a child, of target
363
+ *path, name = fullname.split(".")
364
+ parent = ancestor
365
+ while path:
366
+ parent = parent.get_submodule(path.pop(0))
367
+ # Skip this linear if it's a child of a LoraInjectedLinear
368
+ if exclude_children_of and any(
369
+ [isinstance(parent, _class) for _class in exclude_children_of]
370
+ ):
371
+ continue
372
+ if name in ['lora_up', 'dropout', 'lora_down']:
373
+ continue
374
+ # Otherwise, yield it
375
+ yield parent, name, module
376
+
377
+
378
+ def _find_modules_old(
379
+ model,
380
+ ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE,
381
+ search_class: List[Type[nn.Module]] = [nn.Linear],
382
+ exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear],
383
+ ):
384
+ ret = []
385
+ for _module in model.modules():
386
+ if _module.__class__.__name__ in ancestor_class:
387
+
388
+ for name, _child_module in _module.named_modules():
389
+ if _child_module.__class__ in search_class:
390
+ ret.append((_module, name, _child_module))
391
+ print(ret)
392
+ return ret
393
+
394
+
395
+ _find_modules = _find_modules_v2
396
+
397
+
398
+ def inject_trainable_lora(
399
+ model: nn.Module,
400
+ target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
401
+ r: int = 4,
402
+ loras=None, # path to lora .pt
403
+ verbose: bool = False,
404
+ dropout_p: float = 0.0,
405
+ scale: float = 1.0,
406
+ ):
407
+ """
408
+ Inject LoRA into the model and return the LoRA parameter groups.
409
+ """
410
+
411
+ require_grad_params = []
412
+ names = []
413
+
414
+ if loras is not None:
415
+ loras = torch.load(loras)
416
+
417
+ for _module, name, _child_module in _find_modules(
418
+ model, target_replace_module, search_class=[nn.Linear]
419
+ ):
420
+ weight = _child_module.weight
421
+ bias = _child_module.bias
422
+ if verbose:
423
+ print("LoRA Injection : injecting lora into ", name)
424
+ print("LoRA Injection : weight shape", weight.shape)
425
+ _tmp = LoraInjectedLinear(
426
+ _child_module.in_features,
427
+ _child_module.out_features,
428
+ _child_module.bias is not None,
429
+ r=r,
430
+ dropout_p=dropout_p,
431
+ scale=scale,
432
+ )
433
+ _tmp.linear.weight = weight
434
+ if bias is not None:
435
+ _tmp.linear.bias = bias
436
+
437
+ # switch the module
438
+ _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
439
+ _module._modules[name] = _tmp
440
+
441
+ require_grad_params.append(_module._modules[name].lora_up.parameters())
442
+ require_grad_params.append(_module._modules[name].lora_down.parameters())
443
+
444
+ if loras is not None:
445
+ _module._modules[name].lora_up.weight = loras.pop(0)
446
+ _module._modules[name].lora_down.weight = loras.pop(0)
447
+
448
+ _module._modules[name].lora_up.weight.requires_grad = True
449
+ _module._modules[name].lora_down.weight.requires_grad = True
450
+ names.append(name)
451
+
452
+ return require_grad_params, names
453
+
454
+
455
+ def inject_trainable_lora_extended(
456
+ model: nn.Module,
457
+ target_replace_module: Set[str] = UNET_EXTENDED_TARGET_REPLACE,
458
+ r: int = 4,
459
+ loras=None, # path to lora .pt
460
+ dropout_p: float = 0.0,
461
+ scale: float = 1.0,
462
+ ):
463
+ """
464
+ Inject LoRA into the model and return the LoRA parameter groups.
465
+ """
466
+
467
+ require_grad_params = []
468
+ names = []
469
+
470
+ if loras is not None:
471
+ loras = torch.load(loras)
472
+ if True:
473
+ for target_replace_module_i in target_replace_module:
474
+ for _module, name, _child_module in _find_modules(
475
+ model, [target_replace_module_i], search_class=[nn.Linear, nn.Conv2d, nn.Conv3d]
476
+ ):
477
+ # if name == 'to_q':
478
+ # continue
479
+ if _child_module.__class__ == nn.Linear:
480
+ weight = _child_module.weight
481
+ bias = _child_module.bias
482
+ _tmp = LoraInjectedLinear(
483
+ _child_module.in_features,
484
+ _child_module.out_features,
485
+ _child_module.bias is not None,
486
+ r=r,
487
+ dropout_p=dropout_p,
488
+ scale=scale,
489
+ )
490
+ _tmp.linear.weight = weight
491
+ if bias is not None:
492
+ _tmp.linear.bias = bias
493
+ elif _child_module.__class__ == nn.Conv2d:
494
+ weight = _child_module.weight
495
+ bias = _child_module.bias
496
+ _tmp = LoraInjectedConv2d(
497
+ _child_module.in_channels,
498
+ _child_module.out_channels,
499
+ _child_module.kernel_size,
500
+ _child_module.stride,
501
+ _child_module.padding,
502
+ _child_module.dilation,
503
+ _child_module.groups,
504
+ _child_module.bias is not None,
505
+ r=r,
506
+ dropout_p=dropout_p,
507
+ scale=scale,
508
+ )
509
+
510
+ _tmp.conv.weight = weight
511
+ if bias is not None:
512
+ _tmp.conv.bias = bias
513
+
514
+ elif _child_module.__class__ == nn.Conv3d:
515
+ weight = _child_module.weight
516
+ bias = _child_module.bias
517
+ _tmp = LoraInjectedConv3d(
518
+ _child_module.in_channels,
519
+ _child_module.out_channels,
520
+ bias=_child_module.bias is not None,
521
+ kernel_size=_child_module.kernel_size,
522
+ padding=_child_module.padding,
523
+ r=r,
524
+ dropout_p=dropout_p,
525
+ scale=scale,
526
+ )
527
+
528
+ _tmp.conv.weight = weight
529
+ if bias is not None:
530
+ _tmp.conv.bias = bias
531
+ # switch the module
532
+ _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
533
+ if bias is not None:
534
+ _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype)
535
+
536
+ _module._modules[name] = _tmp
537
+ require_grad_params.append(_module._modules[name].lora_up.parameters())
538
+ require_grad_params.append(_module._modules[name].lora_down.parameters())
539
+
540
+ if loras is not None:
541
+ _module._modules[name].lora_up.weight = loras.pop(0)
542
+ _module._modules[name].lora_down.weight = loras.pop(0)
543
+
544
+ _module._modules[name].lora_up.weight.requires_grad = True
545
+ _module._modules[name].lora_down.weight.requires_grad = True
546
+ names.append(name)
547
+ else:
548
+ for _module, name, _child_module in _find_modules(
549
+ model, target_replace_module, search_class=[nn.Linear, nn.Conv2d, nn.Conv3d]
550
+ ):
551
+ if _child_module.__class__ == nn.Linear:
552
+ weight = _child_module.weight
553
+ bias = _child_module.bias
554
+ _tmp = LoraInjectedLinear(
555
+ _child_module.in_features,
556
+ _child_module.out_features,
557
+ _child_module.bias is not None,
558
+ r=r,
559
+ dropout_p=dropout_p,
560
+ scale=scale,
561
+ )
562
+ _tmp.linear.weight = weight
563
+ if bias is not None:
564
+ _tmp.linear.bias = bias
565
+ elif _child_module.__class__ == nn.Conv2d:
566
+ weight = _child_module.weight
567
+ bias = _child_module.bias
568
+ _tmp = LoraInjectedConv2d(
569
+ _child_module.in_channels,
570
+ _child_module.out_channels,
571
+ _child_module.kernel_size,
572
+ _child_module.stride,
573
+ _child_module.padding,
574
+ _child_module.dilation,
575
+ _child_module.groups,
576
+ _child_module.bias is not None,
577
+ r=r,
578
+ dropout_p=dropout_p,
579
+ scale=scale,
580
+ )
581
+
582
+ _tmp.conv.weight = weight
583
+ if bias is not None:
584
+ _tmp.conv.bias = bias
585
+
586
+ elif _child_module.__class__ == nn.Conv3d:
587
+ weight = _child_module.weight
588
+ bias = _child_module.bias
589
+ _tmp = LoraInjectedConv3d(
590
+ _child_module.in_channels,
591
+ _child_module.out_channels,
592
+ bias=_child_module.bias is not None,
593
+ kernel_size=_child_module.kernel_size,
594
+ padding=_child_module.padding,
595
+ r=r,
596
+ dropout_p=dropout_p,
597
+ scale=scale,
598
+ )
599
+
600
+ _tmp.conv.weight = weight
601
+ if bias is not None:
602
+ _tmp.conv.bias = bias
603
+ # switch the module
604
+ _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
605
+ if bias is not None:
606
+ _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype)
607
+
608
+ _module._modules[name] = _tmp
609
+ require_grad_params.append(_module._modules[name].lora_up.parameters())
610
+ require_grad_params.append(_module._modules[name].lora_down.parameters())
611
+
612
+ if loras is not None:
613
+ _module._modules[name].lora_up.weight = loras.pop(0)
614
+ _module._modules[name].lora_down.weight = loras.pop(0)
615
+
616
+ _module._modules[name].lora_up.weight.requires_grad = True
617
+ _module._modules[name].lora_down.weight.requires_grad = True
618
+ names.append(name)
619
+
620
+ return require_grad_params, names
621
+
622
+
623
+ def inject_inferable_lora(
624
+ model,
625
+ lora_path='',
626
+ unet_replace_modules=["UNet3DConditionModel"],
627
+ text_encoder_replace_modules=["CLIPEncoderLayer"],
628
+ is_extended=False,
629
+ r=16
630
+ ):
631
+ from transformers.models.clip import CLIPTextModel
632
+ from diffusers import UNet3DConditionModel
633
+
634
+ def is_text_model(f): return 'text_encoder' in f and isinstance(model.text_encoder, CLIPTextModel)
635
+ def is_unet(f): return 'unet' in f and model.unet.__class__.__name__ == "UNet3DConditionModel"
636
+
637
+ if os.path.exists(lora_path):
638
+ try:
639
+ for f in os.listdir(lora_path):
640
+ if f.endswith('.pt'):
641
+ lora_file = os.path.join(lora_path, f)
642
+
643
+ if is_text_model(f):
644
+ monkeypatch_or_replace_lora(
645
+ model.text_encoder,
646
+ torch.load(lora_file),
647
+ target_replace_module=text_encoder_replace_modules,
648
+ r=r
649
+ )
650
+ print("Successfully loaded Text Encoder LoRa.")
651
+ continue
652
+
653
+ if is_unet(f):
654
+ monkeypatch_or_replace_lora_extended(
655
+ model.unet,
656
+ torch.load(lora_file),
657
+ target_replace_module=unet_replace_modules,
658
+ r=r
659
+ )
660
+ print("Successfully loaded UNET LoRa.")
661
+ continue
662
+
663
+ print("Found a .pt file, but it doesn't have the correct name format (unet.pt, text_encoder.pt).")
664
+
665
+ except Exception as e:
666
+ print(e)
667
+ print("Couldn't inject LoRAs due to an error.")
668
+
669
+ def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
670
+
671
+ loras = []
672
+
673
+ for target_replace_module_i in target_replace_module:
674
+
675
+ for _m, _n, _child_module in _find_modules(
676
+ model,
677
+ [target_replace_module_i],
678
+ search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
679
+ ):
680
+ loras.append((_child_module.lora_up, _child_module.lora_down))
681
+
682
+ if len(loras) == 0:
683
+ raise ValueError("No lora injected.")
684
+
685
+ return loras
686
+
687
+
688
+ def extract_lora_child_module(model, target_replace_module=DEFAULT_TARGET_REPLACE):
689
+
690
+ loras = []
691
+
692
+ for target_replace_module_i in target_replace_module:
693
+
694
+ for _m, _n, _child_module in _find_modules(
695
+ model,
696
+ [target_replace_module_i],
697
+ search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
698
+ ):
699
+ loras.append(_child_module)
700
+
701
+ if len(loras) == 0:
702
+ raise ValueError("No lora injected.")
703
+
704
+ return loras
705
+
706
+ def extract_lora_as_tensor(
707
+ model, target_replace_module=DEFAULT_TARGET_REPLACE, as_fp16=True
708
+ ):
709
+
710
+ loras = []
711
+
712
+ for _m, _n, _child_module in _find_modules(
713
+ model,
714
+ target_replace_module,
715
+ search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
716
+ ):
717
+ up, down = _child_module.realize_as_lora()
718
+ if as_fp16:
719
+ up = up.to(torch.float16)
720
+ down = down.to(torch.float16)
721
+
722
+ loras.append((up, down))
723
+
724
+ if len(loras) == 0:
725
+ raise ValueError("No lora injected.")
726
+
727
+ return loras
728
+
729
+
730
+ def save_lora_weight(
731
+ model,
732
+ path="./lora.pt",
733
+ target_replace_module=DEFAULT_TARGET_REPLACE,
734
+ flag=None
735
+ ):
736
+ weights = []
737
+ for _up, _down in extract_lora_ups_down(
738
+ model, target_replace_module=target_replace_module
739
+ ):
740
+ weights.append(_up.weight.to("cpu").to(torch.float32))
741
+ weights.append(_down.weight.to("cpu").to(torch.float32))
742
+ if not flag:
743
+ torch.save(weights, path)
744
+ else:
745
+ weights_new=[]
746
+ for i in range(0, len(weights), 4):
747
+ subset = weights[i+(flag-1)*2:i+(flag-1)*2+2]
748
+ weights_new.extend(subset)
749
+ torch.save(weights_new, path)
750
+
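The `flag` branch of `save_lora_weight` slices the flat `[up, down, up, down, ...]` list in groups of four tensors and keeps one `(up, down)` pair per group. The sketch below reproduces only that indexing, assuming strings in place of tensors; `select_by_flag` is a hypothetical helper, and what each group of four represents is not stated in the source.

```python
def select_by_flag(weights, flag=None):
    """Keep the flag-th (up, down) pair out of every group of four entries."""
    if not flag:
        # No flag: everything is kept, as in the plain torch.save branch.
        return list(weights)
    selected = []
    for i in range(0, len(weights), 4):
        # flag=1 keeps entries i, i+1; flag=2 keeps entries i+2, i+3.
        start = i + (flag - 1) * 2
        selected.extend(weights[start:start + 2])
    return selected


# Placeholder names standing in for LoRA weight tensors.
weights = ["up0", "down0", "up1", "down1", "up2", "down2", "up3", "down3"]

first = select_by_flag(weights, flag=1)
second = select_by_flag(weights, flag=2)
```

With this layout, `flag=1` and `flag=2` together partition the list, so saving once with each flag splits the weights into two files with no overlap.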
751
+ def save_lora_as_json(model, path="./lora.json"):
752
+ weights = []
753
+ for _up, _down in extract_lora_ups_down(model):
754
+ weights.append(_up.weight.detach().cpu().numpy().tolist())
755
+ weights.append(_down.weight.detach().cpu().numpy().tolist())
756
+
757
+ import json
758
+
759
+ with open(path, "w") as f:
760
+ json.dump(weights, f)
761
+
762
+
763
+ def save_safeloras_with_embeds(
764
+ modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
765
+ embeds: Dict[str, torch.Tensor] = {},
766
+ outpath="./lora.safetensors",
767
+ ):
768
+ """
769
+ Saves the Lora from multiple modules in a single safetensor file.
770
+
771
+ modelmap is a dictionary of {
772
+ "module name": (module, target_replace_module)
773
+ }
774
+ """
775
+ weights = {}
776
+ metadata = {}
777
+
778
+ for name, (model, target_replace_module) in modelmap.items():
779
+ metadata[name] = json.dumps(list(target_replace_module))
780
+
781
+ for i, (_up, _down) in enumerate(
782
+ extract_lora_as_tensor(model, target_replace_module)
783
+ ):
784
+ rank = _down.shape[0]
785
+
786
+ metadata[f"{name}:{i}:rank"] = str(rank)
787
+ weights[f"{name}:{i}:up"] = _up
788
+ weights[f"{name}:{i}:down"] = _down
789
+
790
+ for token, tensor in embeds.items():
791
+ metadata[token] = EMBED_FLAG
792
+ weights[token] = tensor
793
+
794
+ print(f"Saving weights to {outpath}")
795
+ safe_save(weights, outpath, metadata)
796
+
797
+
798
+ def save_safeloras(
799
+ modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
800
+ outpath="./lora.safetensors",
801
+ ):
802
+ return save_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
803
+
804
+
805
+ def convert_loras_to_safeloras_with_embeds(
806
+ modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
807
+ embeds: Dict[str, torch.Tensor] = {},
808
+ outpath="./lora.safetensors",
809
+ ):
810
+ """
811
+ Converts the Lora from multiple pytorch .pt files into a single safetensor file.
812
+
813
+ modelmap is a dictionary of {
814
+ "module name": (pytorch_model_path, target_replace_module, rank)
815
+ }
816
+ """
817
+
818
+ weights = {}
819
+ metadata = {}
820
+
821
+ for name, (path, target_replace_module, r) in modelmap.items():
822
+ metadata[name] = json.dumps(list(target_replace_module))
823
+
824
+ lora = torch.load(path)
825
+ for i, weight in enumerate(lora):
826
+ is_up = i % 2 == 0
827
+ i = i // 2
828
+
829
+ if is_up:
830
+ metadata[f"{name}:{i}:rank"] = str(r)
831
+ weights[f"{name}:{i}:up"] = weight
832
+ else:
833
+ weights[f"{name}:{i}:down"] = weight
834
+
835
+ for token, tensor in embeds.items():
836
+ metadata[token] = EMBED_FLAG
837
+ weights[token] = tensor
838
+
839
+ print(f"Saving weights to {outpath}")
840
+ safe_save(weights, outpath, metadata)
841
+
842
+
843
+ def convert_loras_to_safeloras(
844
+ modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
845
+ outpath="./lora.safetensors",
846
+ ):
847
+ convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
848
+
849
+
850
+ def parse_safeloras(
851
+ safeloras,
852
+ ) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]:
853
+ """
854
+ Converts a loaded safetensor file that contains a set of module Loras
855
+ into Parameters and other information
856
+
857
+ Output is a dictionary of {
858
+ "module name": (
859
+ [list of weights],
860
+ [list of ranks],
861
+ target_replacement_modules
862
+ )
863
+ }
864
+ """
865
+ loras = {}
866
+ metadata = safeloras.metadata()
867
+
868
+ get_name = lambda k: k.split(":")[0]
869
+
870
+ keys = list(safeloras.keys())
871
+ keys.sort(key=get_name)
872
+
873
+ for name, module_keys in groupby(keys, get_name):
874
+ info = metadata.get(name)
875
+
876
+ if not info:
877
+ raise ValueError(
878
+ f"Tensor {name} has no metadata - is this a Lora safetensor?"
879
+ )
880
+
881
+ # Skip Textual Inversion embeds
882
+ if info == EMBED_FLAG:
883
+ continue
884
+
885
+ # Handle Loras
886
+ # Extract the targets
887
+ target = json.loads(info)
888
+
889
+ # Build the result lists - Python needs us to preallocate lists to insert into them
890
+ module_keys = list(module_keys)
891
+ ranks = [4] * (len(module_keys) // 2)
892
+ weights = [None] * len(module_keys)
893
+
894
+ for key in module_keys:
895
+ # Split the model name and index out of the key
896
+ _, idx, direction = key.split(":")
897
+ idx = int(idx)
898
+
899
+ # Add the rank
900
+ ranks[idx] = int(metadata[f"{name}:{idx}:rank"])
901
+
902
+ # Insert the weight into the list
903
+ idx = idx * 2 + (1 if direction == "down" else 0)
904
+ weights[idx] = nn.parameter.Parameter(safeloras.get_tensor(key))
905
+
906
+ loras[name] = (weights, ranks, target)
907
+
908
+ return loras
909
+
910
+
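The key scheme that `parse_safeloras` relies on is `"{name}:{index}:{up|down}"`, with each module's tensors placed at flat position `index * 2` for `up` and `index * 2 + 1` for `down`. The sketch below shows just that ordering step on key strings, assuming no Textual Inversion embed keys are present; `order_keys` is a hypothetical helper and no safetensors file is opened.

```python
from itertools import groupby


def order_keys(keys):
    """Group 'name:idx:direction' keys by name and order them up/down by index."""
    get_name = lambda k: k.split(":")[0]
    ordered = {}
    # groupby only merges adjacent items, so the keys are sorted by name first.
    for name, module_keys in groupby(sorted(keys, key=get_name), get_name):
        module_keys = list(module_keys)
        slots = [None] * len(module_keys)
        for key in module_keys:
            _, idx, direction = key.split(":")
            # Even slots hold "up" weights, odd slots hold "down" weights.
            slots[int(idx) * 2 + (1 if direction == "down" else 0)] = key
        ordered[name] = slots
    return ordered


keys = [
    "unet:1:down",
    "text_encoder:0:up",
    "unet:0:up",
    "unet:1:up",
    "unet:0:down",
    "text_encoder:0:down",
]
layout = order_keys(keys)
```

The resulting order matches what `monkeypatch_or_replace_lora_extended` consumes with successive `loras.pop(0)` calls: an `up` weight followed by its `down` weight.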
911
+ def parse_safeloras_embeds(
912
+ safeloras,
913
+ ) -> Dict[str, torch.Tensor]:
914
+ """
915
+ Converts a loaded safetensor file that contains Textual Inversion embeds into
916
+ a dictionary of embed_token: Tensor
917
+ """
918
+ embeds = {}
919
+ metadata = safeloras.metadata()
920
+
921
+ for key in safeloras.keys():
922
+ # Only handle Textual Inversion embeds
923
+ meta = metadata.get(key)
924
+ if not meta or meta != EMBED_FLAG:
925
+ continue
926
+
927
+ embeds[key] = safeloras.get_tensor(key)
928
+
929
+ return embeds
930
+
931
+
932
+ def load_safeloras(path, device="cpu"):
933
+ safeloras = safe_open(path, framework="pt", device=device)
934
+ return parse_safeloras(safeloras)
935
+
936
+
937
+ def load_safeloras_embeds(path, device="cpu"):
938
+ safeloras = safe_open(path, framework="pt", device=device)
939
+ return parse_safeloras_embeds(safeloras)
940
+
941
+
942
+ def load_safeloras_both(path, device="cpu"):
943
+ safeloras = safe_open(path, framework="pt", device=device)
944
+ return parse_safeloras(safeloras), parse_safeloras_embeds(safeloras)
945
+
946
+
947
+ def collapse_lora(model, alpha=1.0):
948
+
949
+ for _module, name, _child_module in _find_modules(
950
+ model,
951
+ UNET_EXTENDED_TARGET_REPLACE | TEXT_ENCODER_EXTENDED_TARGET_REPLACE,
952
+ search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
953
+ ):
954
+
955
+ if isinstance(_child_module, LoraInjectedLinear):
956
+ print("Collapsing Lin Lora in", name)
957
+
958
+ _child_module.linear.weight = nn.Parameter(
959
+ _child_module.linear.weight.data
960
+ + alpha
961
+ * (
962
+ _child_module.lora_up.weight.data
963
+ @ _child_module.lora_down.weight.data
964
+ )
965
+ .type(_child_module.linear.weight.dtype)
966
+ .to(_child_module.linear.weight.device)
967
+ )
968
+
969
+ else:
970
+ print("Collapsing Conv Lora in", name)
971
+ _child_module.conv.weight = nn.Parameter(
972
+ _child_module.conv.weight.data
973
+ + alpha
974
+ * (
975
+ _child_module.lora_up.weight.data.flatten(start_dim=1)
976
+ @ _child_module.lora_down.weight.data.flatten(start_dim=1)
977
+ )
978
+ .reshape(_child_module.conv.weight.data.shape)
979
+ .type(_child_module.conv.weight.dtype)
980
+ .to(_child_module.conv.weight.device)
981
+ )
982
+
983
+
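For the linear case, `collapse_lora` folds the low-rank update into the base weight as `W + alpha * (up @ down)`. The sketch below works through that arithmetic on nested lists instead of tensors. It assumes `up` has shape `(out, r)` and `down` has shape `(r, in)`, which is what the matrix product in the diff requires; `matmul` and `collapse` are hypothetical helpers.

```python
def matmul(a, b):
    """Plain-Python matrix product of two nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def collapse(weight, up, down, alpha=1.0):
    """Return weight + alpha * (up @ down), elementwise over nested lists."""
    delta = matmul(up, down)
    return [
        [w + alpha * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


weight = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # base weight, shape (out=2, in=3)
up = [[1.0], [2.0]]                          # assumed shape (out=2, r=1)
down = [[0.5, 0.0, 1.0]]                     # assumed shape (r=1, in=3)

merged = collapse(weight, up, down, alpha=2.0)
```

The convolution branch in the diff follows the same pattern after flattening both LoRA kernels from dimension 1 and reshaping the product back to the convolution weight's shape.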
984
+ def monkeypatch_or_replace_lora(
985
+ model,
986
+ loras,
987
+ target_replace_module=DEFAULT_TARGET_REPLACE,
988
+ r: Union[int, List[int]] = 4,
989
+ ):
990
+ for _module, name, _child_module in _find_modules(
991
+ model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear]
992
+ ):
993
+ _source = (
994
+ _child_module.linear
995
+ if isinstance(_child_module, LoraInjectedLinear)
996
+ else _child_module
997
+ )
998
+
999
+ weight = _source.weight
1000
+ bias = _source.bias
1001
+ _tmp = LoraInjectedLinear(
1002
+ _source.in_features,
1003
+ _source.out_features,
1004
+ _source.bias is not None,
1005
+ r=r.pop(0) if isinstance(r, list) else r,
1006
+ )
1007
+ _tmp.linear.weight = weight
1008
+
1009
+ if bias is not None:
1010
+ _tmp.linear.bias = bias
1011
+
1012
+ # switch the module
1013
+ _module._modules[name] = _tmp
1014
+
1015
+ up_weight = loras.pop(0)
1016
+ down_weight = loras.pop(0)
1017
+
1018
+ _module._modules[name].lora_up.weight = nn.Parameter(
1019
+ up_weight.type(weight.dtype)
1020
+ )
1021
+ _module._modules[name].lora_down.weight = nn.Parameter(
1022
+ down_weight.type(weight.dtype)
1023
+ )
1024
+
1025
+ _module._modules[name].to(weight.device)
1026
+
1027
+
1028
+ def monkeypatch_or_replace_lora_extended(
1029
+ model,
1030
+ loras,
1031
+ target_replace_module=DEFAULT_TARGET_REPLACE,
1032
+ r: Union[int, List[int]] = 4,
1033
+ ):
1034
+ for _module, name, _child_module in _find_modules(
1035
+ model,
1036
+ target_replace_module,
1037
+ search_class=[
1038
+ nn.Linear,
1039
+ nn.Conv2d,
1040
+ nn.Conv3d,
1041
+ LoraInjectedLinear,
1042
+ LoraInjectedConv2d,
1043
+ LoraInjectedConv3d,
1044
+ ],
1045
+ ):
1046
+
1047
+ if (_child_module.__class__ == nn.Linear) or (
1048
+ _child_module.__class__ == LoraInjectedLinear
1049
+ ):
1050
+ if len(loras[0].shape) != 2:
1051
+ continue
1052
+
1053
+ _source = (
1054
+ _child_module.linear
1055
+ if isinstance(_child_module, LoraInjectedLinear)
1056
+ else _child_module
1057
+ )
1058
+
1059
+ weight = _source.weight
1060
+ bias = _source.bias
1061
+ _tmp = LoraInjectedLinear(
1062
+ _source.in_features,
1063
+ _source.out_features,
1064
+ _source.bias is not None,
1065
+ r=r.pop(0) if isinstance(r, list) else r,
1066
+ )
1067
+ _tmp.linear.weight = weight
1068
+
1069
+ if bias is not None:
1070
+ _tmp.linear.bias = bias
1071
+
1072
+ elif (_child_module.__class__ == nn.Conv2d) or (
1073
+ _child_module.__class__ == LoraInjectedConv2d
1074
+ ):
1075
+ if len(loras[0].shape) != 4:
1076
+ continue
1077
+ _source = (
1078
+ _child_module.conv
1079
+ if isinstance(_child_module, LoraInjectedConv2d)
1080
+ else _child_module
1081
+ )
1082
+
1083
+ weight = _source.weight
1084
+ bias = _source.bias
1085
+ _tmp = LoraInjectedConv2d(
1086
+ _source.in_channels,
1087
+ _source.out_channels,
1088
+ _source.kernel_size,
1089
+ _source.stride,
1090
+ _source.padding,
1091
+ _source.dilation,
1092
+ _source.groups,
1093
+ _source.bias is not None,
1094
+ r=r.pop(0) if isinstance(r, list) else r,
1095
+ )
1096
+
1097
+ _tmp.conv.weight = weight
1098
+
1099
+ if bias is not None:
1100
+ _tmp.conv.bias = bias
1101
+
1102
+ elif _child_module.__class__ == nn.Conv3d or(
1103
+ _child_module.__class__ == LoraInjectedConv3d
1104
+ ):
1105
+
1106
+ if len(loras[0].shape) != 5:
1107
+ continue
1108
+
1109
+ _source = (
1110
+ _child_module.conv
1111
+ if isinstance(_child_module, LoraInjectedConv3d)
1112
+ else _child_module
1113
+ )
1114
+
1115
+ weight = _source.weight
1116
+ bias = _source.bias
1117
+ _tmp = LoraInjectedConv3d(
1118
+ _source.in_channels,
1119
+ _source.out_channels,
1120
+ bias=_source.bias is not None,
1121
+ kernel_size=_source.kernel_size,
1122
+ padding=_source.padding,
1123
+ r=r.pop(0) if isinstance(r, list) else r,
1124
+ )
1125
+
1126
+ _tmp.conv.weight = weight
1127
+
1128
+ if bias is not None:
1129
+ _tmp.conv.bias = bias
1130
+
1131
+ # switch the module
1132
+ _module._modules[name] = _tmp
1133
+
1134
+ up_weight = loras.pop(0)
1135
+ down_weight = loras.pop(0)
1136
+
1137
+ _module._modules[name].lora_up.weight = nn.Parameter(
1138
+ up_weight.type(weight.dtype)
1139
+ )
1140
+ _module._modules[name].lora_down.weight = nn.Parameter(
1141
+ down_weight.type(weight.dtype)
1142
+ )
1143
+
1144
+ _module._modules[name].to(weight.device)
1145
+
1146
+
1147
+ def monkeypatch_or_replace_safeloras(models, safeloras):
1148
+ loras = parse_safeloras(safeloras)
1149
+
1150
+ for name, (lora, ranks, target) in loras.items():
1151
+ model = getattr(models, name, None)
1152
+
1153
+ if model is None:
1154
+ print(f"No model provided for {name}, contained in Lora")
1155
+ continue
1156
+
1157
+ monkeypatch_or_replace_lora_extended(model, lora, target, ranks)
1158
+
1159
+
1160
+ def monkeypatch_remove_lora(model):
1161
+ for _module, name, _child_module in _find_modules(
1162
+ model, search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d]
1163
+ ):
1164
+ if isinstance(_child_module, LoraInjectedLinear):
1165
+ _source = _child_module.linear
1166
+ weight, bias = _source.weight, _source.bias
1167
+
1168
+ _tmp = nn.Linear(
1169
+ _source.in_features, _source.out_features, bias is not None
1170
+ )
1171
+
1172
+ _tmp.weight = weight
1173
+ if bias is not None:
1174
+ _tmp.bias = bias
1175
+
1176
+ else:
1177
+ _source = _child_module.conv
1178
+ weight, bias = _source.weight, _source.bias
1179
+
1180
+ if isinstance(_source, nn.Conv2d):
1181
+ _tmp = nn.Conv2d(
1182
+ in_channels=_source.in_channels,
1183
+ out_channels=_source.out_channels,
1184
+ kernel_size=_source.kernel_size,
1185
+ stride=_source.stride,
1186
+ padding=_source.padding,
1187
+ dilation=_source.dilation,
1188
+ groups=_source.groups,
1189
+ bias=bias is not None,
1190
+ )
1191
+
1192
+ _tmp.weight = weight
1193
+ if bias is not None:
1194
+ _tmp.bias = bias
1195
+
1196
+ if isinstance(_source, nn.Conv3d):
1197
+ _tmp = nn.Conv3d(
1198
+ _source.in_channels,
1199
+ _source.out_channels,
1200
+ bias=_source.bias is not None,
1201
+ kernel_size=_source.kernel_size,
1202
+ padding=_source.padding,
1203
+ )
1204
+
1205
+ _tmp.weight = weight
1206
+ if bias is not None:
1207
+ _tmp.bias = bias
1208
+
1209
+ _module._modules[name] = _tmp
1210
+
1211
+
1212
+ def monkeypatch_add_lora(
1213
+ model,
1214
+ loras,
1215
+ target_replace_module=DEFAULT_TARGET_REPLACE,
1216
+ alpha: float = 1.0,
1217
+ beta: float = 1.0,
1218
+ ):
1219
+ for _module, name, _child_module in _find_modules(
1220
+ model, target_replace_module, search_class=[LoraInjectedLinear]
1221
+ ):
1222
+ weight = _child_module.linear.weight
1223
+
1224
+ up_weight = loras.pop(0)
1225
+ down_weight = loras.pop(0)
1226
+
1227
+ _module._modules[name].lora_up.weight = nn.Parameter(
1228
+ up_weight.type(weight.dtype).to(weight.device) * alpha
1229
+ + _module._modules[name].lora_up.weight.to(weight.device) * beta
1230
+ )
1231
+ _module._modules[name].lora_down.weight = nn.Parameter(
1232
+ down_weight.type(weight.dtype).to(weight.device) * alpha
1233
+ + _module._modules[name].lora_down.weight.to(weight.device) * beta
1234
+ )
1235
+
1236
+ _module._modules[name].to(weight.device)
1237
+
1238
+
1239
+ def tune_lora_scale(model, alpha: float = 1.0):
1240
+ for _module in model.modules():
1241
+ if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d", "LoraInjectedConv3d"]:
1242
+ _module.scale = alpha
1243
+
1244
+
1245
+ def set_lora_diag(model, diag: torch.Tensor):
1246
+ for _module in model.modules():
1247
+ if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d", "LoraInjectedConv3d"]:
1248
+ _module.set_selector_from_diag(diag)
1249
+
1250
+
1251
+ def _text_lora_path(path: str) -> str:
1252
+ assert path.endswith(".pt"), "Only .pt files are supported"
1253
+ return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
1254
+
1255
+
1256
+ def _ti_lora_path(path: str) -> str:
1257
+ assert path.endswith(".pt"), "Only .pt files are supported"
1258
+ return ".".join(path.split(".")[:-1] + ["ti", "pt"])
1259
+
1260
+
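The two path helpers derive sibling file names from the UNet LoRA path, and `patch_pipe` performs the reverse step when it is handed one of the siblings. A minimal sketch of that round trip, using the same string operations as the diff; the function names without leading underscores and `base_unet_path` are hypothetical, chosen for illustration.

```python
def text_lora_path(path):
    """'x.pt' -> 'x.text_encoder.pt', same join/split logic as the diff."""
    assert path.endswith(".pt"), "Only .pt files are supported"
    return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])


def ti_lora_path(path):
    """'x.pt' -> 'x.ti.pt'."""
    assert path.endswith(".pt"), "Only .pt files are supported"
    return ".".join(path.split(".")[:-1] + ["ti", "pt"])


def base_unet_path(maybe_unet_path):
    """Strip a '.ti' or '.text_encoder' infix to recover the UNet path."""
    if maybe_unet_path.endswith(".ti.pt"):
        return maybe_unet_path[: -len(".ti.pt")] + ".pt"
    if maybe_unet_path.endswith(".text_encoder.pt"):
        return maybe_unet_path[: -len(".text_encoder.pt")] + ".pt"
    return maybe_unet_path


unet_path = "out/lora.pt"
text_path = text_lora_path(unet_path)
ti_path = ti_lora_path(unet_path)
```

The diff hard-codes the suffix lengths as `6` and `16`; the sketch uses `len(...)` on the suffix instead, which yields the same slices.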
1261
+ def apply_learned_embed_in_clip(
1262
+ learned_embeds,
1263
+ text_encoder,
1264
+ tokenizer,
1265
+ token: Optional[Union[str, List[str]]] = None,
1266
+ idempotent=False,
1267
+ ):
1268
+ if isinstance(token, str):
1269
+ trained_tokens = [token]
1270
+ elif isinstance(token, list):
1271
+ assert len(learned_embeds.keys()) == len(
1272
+ token
1273
+ ), "The number of tokens and the number of embeds should be the same"
1274
+ trained_tokens = token
1275
+ else:
1276
+ trained_tokens = list(learned_embeds.keys())
1277
+
1278
+ for token in trained_tokens:
1279
+ print(token)
1280
+ embeds = learned_embeds[token]
1281
+
1282
+ # cast to dtype of text_encoder
1283
+ dtype = text_encoder.get_input_embeddings().weight.dtype
1284
+ num_added_tokens = tokenizer.add_tokens(token)
1285
+
1286
+ i = 1
1287
+ if not idempotent:
1288
+ while num_added_tokens == 0:
1289
+ print(f"The tokenizer already contains the token {token}.")
1290
+ token = f"{token[:-1]}-{i}>"
1291
+ print(f"Attempting to add the token {token}.")
1292
+ num_added_tokens = tokenizer.add_tokens(token)
1293
+ i += 1
1294
+ elif num_added_tokens == 0 and idempotent:
1295
+ print(f"The tokenizer already contains the token {token}.")
1296
+ print(f"Replacing {token} embedding.")
1297
+
1298
+ # resize the token embeddings
1299
+ text_encoder.resize_token_embeddings(len(tokenizer))
1300
+
1301
+ # get the id for the token and assign the embeds
1302
+ token_id = tokenizer.convert_tokens_to_ids(token)
1303
+ text_encoder.get_input_embeddings().weight.data[token_id] = embeds
1304
+ return token
1305
+
1306
+
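When `idempotent` is false, `apply_learned_embed_in_clip` renames a colliding token by replacing its closing `>` with `-{i}>` and retrying. The sketch below mimics that loop, assuming a plain `set` stands in for the tokenizer vocabulary and set membership stands in for `tokenizer.add_tokens` returning `0`; `add_unique_token` is a hypothetical helper.

```python
def add_unique_token(vocab, token):
    """Add token to vocab, renaming it on collision as the loop in the diff does."""
    i = 1
    while token in vocab:
        # The suffix is appended to the *current* candidate, so repeated
        # collisions accumulate suffixes rather than replacing them.
        token = f"{token[:-1]}-{i}>"
        i += 1
    vocab.add(token)
    return token


vocab = {"<s1>"}
first = add_unique_token(vocab, "<s1>")
second = add_unique_token(vocab, "<s1>")
```

Because the candidate is rebuilt from the already-renamed token, a second collision produces `<s1-1-2>` rather than `<s1-2>`, which is worth keeping in mind when prompts reference the returned token name.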
1307
+ def load_learned_embed_in_clip(
1308
+ learned_embeds_path,
1309
+ text_encoder,
1310
+ tokenizer,
1311
+ token: Optional[Union[str, List[str]]] = None,
1312
+ idempotent=False,
1313
+ ):
1314
+ learned_embeds = torch.load(learned_embeds_path)
1315
+ return apply_learned_embed_in_clip(
1316
+ learned_embeds, text_encoder, tokenizer, token, idempotent
1317
+ )
1318
+
1319
+
1320
+ def patch_pipe(
1321
+ pipe,
1322
+ maybe_unet_path,
1323
+ token: Optional[str] = None,
1324
+ r: int = 4,
1325
+ patch_unet=True,
1326
+ patch_text=True,
1327
+ patch_ti=True,
1328
+ idempotent_token=True,
1329
+ unet_target_replace_module=DEFAULT_TARGET_REPLACE,
1330
+ text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
1331
+ ):
1332
+ if maybe_unet_path.endswith(".pt"):
1333
+ # torch format
1334
+
1335
+ if maybe_unet_path.endswith(".ti.pt"):
1336
+ unet_path = maybe_unet_path[:-6] + ".pt"
1337
+ elif maybe_unet_path.endswith(".text_encoder.pt"):
1338
+ unet_path = maybe_unet_path[:-16] + ".pt"
1339
+ else:
1340
+ unet_path = maybe_unet_path
1341
+
1342
+ ti_path = _ti_lora_path(unet_path)
1343
+ text_path = _text_lora_path(unet_path)
1344
+
1345
+ if patch_unet:
1346
+ print("LoRA : Patching Unet")
1347
+ monkeypatch_or_replace_lora(
1348
+ pipe.unet,
1349
+ torch.load(unet_path),
1350
+ r=r,
1351
+ target_replace_module=unet_target_replace_module,
1352
+ )
1353
+
1354
+ if patch_text:
1355
+ print("LoRA : Patching text encoder")
1356
+ monkeypatch_or_replace_lora(
1357
+ pipe.text_encoder,
1358
+ torch.load(text_path),
1359
+ target_replace_module=text_target_replace_module,
1360
+ r=r,
1361
+ )
1362
+ if patch_ti:
1363
+ print("LoRA : Patching token input")
1364
+ token = load_learned_embed_in_clip(
1365
+ ti_path,
1366
+ pipe.text_encoder,
1367
+ pipe.tokenizer,
1368
+ token=token,
1369
+ idempotent=idempotent_token,
1370
+ )
1371
+
1372
+ elif maybe_unet_path.endswith(".safetensors"):
1373
+ safeloras = safe_open(maybe_unet_path, framework="pt", device="cpu")
1374
+ monkeypatch_or_replace_safeloras(pipe, safeloras)
1375
+ tok_dict = parse_safeloras_embeds(safeloras)
1376
+ if patch_ti:
1377
+ apply_learned_embed_in_clip(
1378
+ tok_dict,
1379
+ pipe.text_encoder,
1380
+ pipe.tokenizer,
1381
+ token=token,
1382
+ idempotent=idempotent_token,
1383
+ )
1384
+ return tok_dict
1385
+
1386
+
1387
+ def train_patch_pipe(pipe, patch_unet, patch_text):
1388
+ if patch_unet:
1389
+ print("LoRA : Patching Unet")
1390
+ collapse_lora(pipe.unet)
1391
+ monkeypatch_remove_lora(pipe.unet)
1392
+
1393
+ if patch_text:
1394
+ print("LoRA : Patching text encoder")
1395
+
1396
+ collapse_lora(pipe.text_encoder)
1397
+ monkeypatch_remove_lora(pipe.text_encoder)
1398
+
1399
+ @torch.no_grad()
1400
+ def inspect_lora(model):
1401
+ moved = {}
1402
+
1403
+ for name, _module in model.named_modules():
1404
+ if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d", "LoraInjectedConv3d"]:
1405
+ ups = _module.lora_up.weight.data.clone()
1406
+ downs = _module.lora_down.weight.data.clone()
1407
+
1408
+ wght: torch.Tensor = ups.flatten(1) @ downs.flatten(1)
1409
+
1410
+ dist = wght.flatten().abs().mean().item()
1411
+ if name in moved:
1412
+ moved[name].append(dist)
1413
+ else:
1414
+ moved[name] = [dist]
1415
+
1416
+ return moved
1417
+
1418
+
1419
+ def save_all(
1420
+ unet,
1421
+ text_encoder,
1422
+ save_path,
1423
+ placeholder_token_ids=None,
1424
+ placeholder_tokens=None,
1425
+ save_lora=True,
1426
+ save_ti=True,
1427
+ target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
1428
+ target_replace_module_unet=DEFAULT_TARGET_REPLACE,
1429
+ safe_form=True,
1430
+ ):
1431
+ if not safe_form:
1432
+ # save ti
1433
+ if save_ti:
1434
+ ti_path = _ti_lora_path(save_path)
1435
+ learned_embeds_dict = {}
1436
+ for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
1437
+ learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
1438
+ print(
1439
+ f"Current learned embeddings for {tok}, id {tok_id}: ",
1440
+ learned_embeds[:4],
1441
+ )
1442
+ learned_embeds_dict[tok] = learned_embeds.detach().cpu()
1443
+
1444
+ torch.save(learned_embeds_dict, ti_path)
1445
+ print("Ti saved to ", ti_path)
1446
+
1447
+ # save unet and text encoder LoRA weights
1448
+ if save_lora:
1449
+ save_lora_weight(
1450
+ unet, save_path, target_replace_module=target_replace_module_unet
1451
+ )
1452
+ print("Unet saved to ", save_path)
1453
+
1454
+ save_lora_weight(
1455
+ text_encoder,
1456
+ _text_lora_path(save_path),
1457
+ target_replace_module=target_replace_module_text,
1458
+ )
1459
+ print("Text Encoder saved to ", _text_lora_path(save_path))
1460
+
1461
+ else:
1462
+ assert save_path.endswith(
1463
+ ".safetensors"
1464
+ ), f"Save path : {save_path} should end with .safetensors"
1465
+
1466
+ loras = {}
1467
+ embeds = {}
1468
+
1469
+ if save_lora:
1470
+
1471
+ loras["unet"] = (unet, target_replace_module_unet)
1472
+ loras["text_encoder"] = (text_encoder, target_replace_module_text)
1473
+
1474
+ if save_ti:
1475
+ for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
1476
+ learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
1477
+ print(
1478
+ f"Current Learned Embeddings for {tok}:, id {tok_id} ",
1479
+ learned_embeds[:4],
1480
+ )
1481
+ embeds[tok] = learned_embeds.detach().cpu()
1482
+
1483
+ save_safeloras_with_embeds(loras, embeds, save_path)
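The number that `inspect_lora` records for each module is the mean absolute entry of the product of the up and down matrices. As a rough illustration only, the sketch below reproduces that arithmetic in plain Python on toy matrices, without torch. The helpers `matmul` and `mean_abs` are hypothetical names introduced here and are not part of the commit.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    inner = len(b)
    assert all(len(row) == inner for row in a), "shape mismatch"
    cols = len(b[0])
    return [
        [sum(row[k] * b[k][j] for k in range(inner)) for j in range(cols)]
        for row in a
    ]


def mean_abs(matrix):
    """Mean of the absolute values of all entries."""
    flat = [abs(x) for row in matrix for x in row]
    return sum(flat) / len(flat)


# Toy rank-1 pair: up has shape (out=2, r=1), down has shape (r=1, in=3).
up = [[1.0], [-2.0]]
down = [[0.5, 0.0, -1.0]]

delta = matmul(up, down)   # the low-rank weight update, shape (2, 3)
dist = mean_abs(delta)     # the per-module value inspect_lora would store

# If the up matrix is all zeros, the update is zero and so is the metric.
zero_dist = mean_abs(matmul([[0.0], [0.0]], down))

print(delta, dist, zero_dist)
```

A value near zero therefore suggests that the module's LoRA pair has barely moved the effective weight, while larger values indicate a larger update.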
utils/lora_handler.py ADDED
@@ -0,0 +1,269 @@
+ import os
+ import warnings
+ import torch
+ from typing import Union
+ from types import SimpleNamespace
+ from models.unet_3d_condition import UNet3DConditionModel
+ from transformers import CLIPTextModel
+ from utils.convert_diffusers_to_original_ms_text_to_video import convert_unet_state_dict, convert_text_enc_state_dict_v20
+
+ from .lora import (
+     extract_lora_ups_down,
+     inject_trainable_lora_extended,
+     save_lora_weight,
+     train_patch_pipe,
+     monkeypatch_or_replace_lora,
+     monkeypatch_or_replace_lora_extended
+ )
+
+
+ FILE_BASENAMES = ['unet', 'text_encoder']
+ LORA_FILE_TYPES = ['.pt', '.safetensors']
+ CLONE_OF_SIMO_KEYS = ['model', 'loras', 'target_replace_module', 'r']
+ STABLE_LORA_KEYS = ['model', 'target_module', 'search_class', 'r', 'dropout', 'lora_bias']
+
+ lora_versions = dict(
+     stable_lora = "stable_lora",
+     cloneofsimo = "cloneofsimo"
+ )
+
+ lora_func_types = dict(
+     loader = "loader",
+     injector = "injector"
+ )
+
+ # Default arguments shared by the LoRA loader and injector functions.
+ lora_args = dict(
+     model = None,
+     loras = None,
+     target_replace_module = [],
+     target_module = [],
+     r = 4,
+     search_class = [torch.nn.Linear],
+     dropout = 0,
+     lora_bias = 'none'
+ )
+
+ LoraVersions = SimpleNamespace(**lora_versions)
+ LoraFuncTypes = SimpleNamespace(**lora_func_types)
+
+ LORA_VERSIONS = [LoraVersions.stable_lora, LoraVersions.cloneofsimo]
+ LORA_FUNC_TYPES = [LoraFuncTypes.loader, LoraFuncTypes.injector]
+
+ def filter_dict(_dict, keys=[]):
+     # A bare `assert "message"` always passes, so assert the condition itself.
+     assert len(keys) > 0, "Keys cannot be empty when filtering the return dict."
+
+     for k in keys:
+         assert k in lora_args, f"{k} does not exist in available LoRA arguments"
+
+     return {k: v for k, v in _dict.items() if k in keys}
+
+ class LoraHandler(object):
+     def __init__(
+         self,
+         version: LORA_VERSIONS = LoraVersions.cloneofsimo,
+         use_unet_lora: bool = False,
+         use_text_lora: bool = False,
+         save_for_webui: bool = False,
+         only_for_webui: bool = False,
+         lora_bias: str = 'none',
+         unet_replace_modules: list = None,
+         text_encoder_replace_modules: list = None
+     ):
+         self.version = version
+         self.lora_loader = self.get_lora_func(func_type=LoraFuncTypes.loader)
+         self.lora_injector = self.get_lora_func(func_type=LoraFuncTypes.injector)
+         self.lora_bias = lora_bias
+         self.use_unet_lora = use_unet_lora
+         self.use_text_lora = use_text_lora
+         self.save_for_webui = save_for_webui
+         self.only_for_webui = only_for_webui
+         self.unet_replace_modules = unet_replace_modules
+         self.text_encoder_replace_modules = text_encoder_replace_modules
+         self.use_lora = any([use_text_lora, use_unet_lora])
+
+     def is_cloneofsimo_lora(self):
+         return self.version == LoraVersions.cloneofsimo
+
+
+     def get_lora_func(self, func_type: LORA_FUNC_TYPES = LoraFuncTypes.loader):
+
+         if self.is_cloneofsimo_lora():
+
+             if func_type == LoraFuncTypes.loader:
+                 return monkeypatch_or_replace_lora_extended
+
+             if func_type == LoraFuncTypes.injector:
+                 return inject_trainable_lora_extended
+
+         # The original `assert "..."` never fired; fail loudly on an unsupported combination.
+         raise ValueError(f"LoRA version '{self.version}' with function type '{func_type}' does not exist.")
+
+     def check_lora_ext(self, lora_file: str):
+         return lora_file.endswith(tuple(LORA_FILE_TYPES))
+
+     def get_lora_file_path(
+         self,
+         lora_path: str,
+         model: Union[UNet3DConditionModel, CLIPTextModel]
+     ):
+         # lora_path must be a directory; os.listdir would raise on a plain file.
+         if os.path.isdir(lora_path):
+             lora_filenames = os.listdir(lora_path)
+
+             is_unet = isinstance(model, UNet3DConditionModel)
+             idx = 0 if is_unet else 1
+
+             base_name = FILE_BASENAMES[idx]
+
+             # Return the first file with a LoRA extension whose name contains the base name.
+             for lora_filename in lora_filenames:
+                 is_lora = self.check_lora_ext(lora_filename)
+                 if not is_lora:
+                     continue
+
+                 if base_name in lora_filename:
+                     return os.path.join(lora_path, lora_filename)
+
+         return None
+
+     def handle_lora_load(self, file_name:str, lora_loader_args: dict = None):
+         self.lora_loader(**lora_loader_args)
+         print(f"Successfully loaded LoRA from: {file_name}")
+
+     def load_lora(self, model, lora_path: str = '', lora_loader_args: dict = None,):
+         try:
+             lora_file = self.get_lora_file_path(lora_path, model)
+
+             if lora_file is not None:
+                 lora_loader_args.update({"lora_path": lora_file})
+                 self.handle_lora_load(lora_file, lora_loader_args)
+
+             else:
+                 print(f"Could not load LoRAs for {model.__class__.__name__}. Injecting new ones instead...")
+
+         except Exception as e:
+             print(f"An error occurred while loading a LoRA file: {e}")
+
+     def get_lora_func_args(self, lora_path, use_lora, model, replace_modules, r, dropout, lora_bias, scale):
+         return_dict = lora_args.copy()
+
+         if self.is_cloneofsimo_lora():
+             return_dict = filter_dict(return_dict, keys=CLONE_OF_SIMO_KEYS)
+             return_dict.update({
+                 "model": model,
+                 "loras": self.get_lora_file_path(lora_path, model),
+                 "target_replace_module": replace_modules,
+                 "r": r,
+                 "scale": scale,
+                 "dropout_p": dropout,
+             })
+
+         return return_dict
+
+     def do_lora_injection(
+         self,
+         model,
+         replace_modules,
+         bias='none',
+         dropout=0,
+         r=4,
+         lora_loader_args=None,
+     ):
+         REPLACE_MODULES = replace_modules
+
+         params = None
+         negation = None
+         is_injection_hybrid = False
+
+         if self.is_cloneofsimo_lora():
+             is_injection_hybrid = True
+             injector_args = lora_loader_args
+
+             params, negation = self.lora_injector(**injector_args) # inject_trainable_lora_extended
+             for _up, _down in extract_lora_ups_down(
+                 model,
+                 target_replace_module=REPLACE_MODULES):
+
+                 if all(x is not None for x in [_up, _down]):
+                     print(f"Lora successfully injected into {model.__class__.__name__}.")
+
+                 break
+
+             return params, negation, is_injection_hybrid
+
+         return params, negation, is_injection_hybrid
+
+     def add_lora_to_model(self, use_lora, model, replace_modules, dropout=0.0, lora_path='', r=16, scale=1.0):
+
+         params = None
+         negation = None
+
+         lora_loader_args = self.get_lora_func_args(
+             lora_path,
+             use_lora,
+             model,
+             replace_modules,
+             r,
+             dropout,
+             self.lora_bias,
+             scale
+         )
+
+         if use_lora:
+             params, negation, is_injection_hybrid = self.do_lora_injection(
+                 model,
+                 replace_modules,
+                 bias=self.lora_bias,
+                 lora_loader_args=lora_loader_args,
+                 dropout=dropout,
+                 r=r
+             )
+
+             if not is_injection_hybrid:
+                 self.load_lora(model, lora_path=lora_path, lora_loader_args=lora_loader_args)
+
+         params = model if params is None else params
+         return params, negation
+
+     def save_cloneofsimo_lora(self, model, save_path, step, flag):
+
+         def save_lora(model, name, condition, replace_modules, step, save_path, flag=None):
+             if condition and replace_modules is not None:
+                 save_path = f"{save_path}/{step}_{name}.pt"
+                 save_lora_weight(model, save_path, replace_modules, flag)
+
+         save_lora(
+             model.unet,
+             FILE_BASENAMES[0],
+             self.use_unet_lora,
+             self.unet_replace_modules,
+             step,
+             save_path,
+             flag
+         )
+         save_lora(
+             model.text_encoder,
+             FILE_BASENAMES[1],
+             self.use_text_lora,
+             self.text_encoder_replace_modules,
+             step,
+             save_path,
+             flag
+         )
+
+         # train_patch_pipe(model, self.use_unet_lora, self.use_text_lora)
+
+     def save_lora_weights(self, model: None, save_path: str = '', step: str = '', flag=None):
+         save_path = f"{save_path}/lora"
+         os.makedirs(save_path, exist_ok=True)
+
+         if self.is_cloneofsimo_lora():
+             if any([self.save_for_webui, self.only_for_webui]):
+                 warnings.warn(
+                     """
+                     You have 'save_for_webui' enabled, but are using cloneofsimo's LoRA implementation.
+                     Only 'stable_lora' is supported for saving to a compatible webui file.
+                     """
+                 )
+             self.save_cloneofsimo_lora(model, save_path, step, flag)
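The save and load paths in `LoraHandler` meet through file names: `save_cloneofsimo_lora` writes `{step}_{name}.pt` under a `lora` subdirectory, and `get_lora_file_path` later picks the first file with a LoRA extension whose name contains `unet` or `text_encoder`. The sketch below imitates that round trip using only the standard library and empty placeholder files. The function `find_lora_file` is a hypothetical stand-in, not the handler's API, and the `sorted` call is an addition made here so the result is deterministic.

```python
import os
import tempfile

FILE_BASENAMES = ["unet", "text_encoder"]
LORA_FILE_TYPES = (".pt", ".safetensors")


def find_lora_file(lora_dir, base_name):
    """Return the first LoRA-typed file in lora_dir whose name contains base_name."""
    if not os.path.isdir(lora_dir):
        return None
    for filename in sorted(os.listdir(lora_dir)):
        if filename.endswith(LORA_FILE_TYPES) and base_name in filename:
            return os.path.join(lora_dir, filename)
    return None


with tempfile.TemporaryDirectory() as root:
    # Mirror the save layout: <save_path>/lora/<step>_<name>.pt
    lora_dir = os.path.join(root, "lora")
    os.makedirs(lora_dir, exist_ok=True)
    step = 500
    for name in FILE_BASENAMES:
        open(os.path.join(lora_dir, f"{step}_{name}.pt"), "wb").close()
    # A file without a LoRA extension should be ignored by the lookup.
    open(os.path.join(lora_dir, "notes.txt"), "w").close()

    unet_file = os.path.basename(find_lora_file(lora_dir, "unet"))
    text_file = os.path.basename(find_lora_file(lora_dir, "text_encoder"))
    missing = find_lora_file(os.path.join(root, "absent"), "unet")

print(unet_file, text_file, missing)
```

Because the lookup returns on the first match, a directory holding checkpoints from several steps yields whichever matching file is listed first, so keeping one step per directory avoids surprises.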