GrayShine commited on
Commit
2e5e07d
·
verified ·
1 Parent(s): 83d94d1

Upload 60 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +5 -0
  2. LICENSE +201 -0
  3. configs/vlog_read_script_sample.yaml +39 -0
  4. configs/vlog_write_script.yaml +3 -0
  5. configs/with_mask_ref_sample.yaml +36 -0
  6. configs/with_mask_sample.yaml +33 -0
  7. datasets/__pycache__/video_transforms.cpython-310.pyc +0 -0
  8. datasets/video_transforms.py +382 -0
  9. diffusion/__init__.py +47 -0
  10. diffusion/__pycache__/__init__.cpython-310.pyc +0 -0
  11. diffusion/__pycache__/diffusion_utils.cpython-310.pyc +0 -0
  12. diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc +0 -0
  13. diffusion/__pycache__/respace.cpython-310.pyc +0 -0
  14. diffusion/diffusion_utils.py +88 -0
  15. diffusion/gaussian_diffusion.py +931 -0
  16. diffusion/respace.py +130 -0
  17. diffusion/timestep_sampler.py +150 -0
  18. input/i2v/A_big_drop_of_water_falls_on_a_rose_petal.png +3 -0
  19. input/i2v/A_fish_swims_past_an_oriental_woman.png +3 -0
  20. input/i2v/Cinematic_photograph_View_of_piloting_aaero.png +3 -0
  21. input/i2v/Planet_hits_earth.png +3 -0
  22. input/i2v/Underwater_environment_cosmetic_bottles.png +3 -0
  23. models/__init__.py +33 -0
  24. models/__pycache__/__init__.cpython-310.pyc +0 -0
  25. models/__pycache__/attention.cpython-310.pyc +0 -0
  26. models/__pycache__/clip.cpython-310.pyc +0 -0
  27. models/__pycache__/resnet.cpython-310.pyc +0 -0
  28. models/__pycache__/unet.cpython-310.pyc +0 -0
  29. models/__pycache__/unet_blocks.cpython-310.pyc +0 -0
  30. models/attention.py +966 -0
  31. models/clip.py +123 -0
  32. models/resnet.py +212 -0
  33. models/unet.py +699 -0
  34. models/unet_blocks.py +650 -0
  35. models/utils.py +215 -0
  36. requirements.txt +25 -0
  37. results/mask_no_ref/Planet_hits_earth..mp4 +0 -0
  38. results/mask_ref/Planet_hits_earth..mp4 +0 -0
  39. results/vlog/teddy_travel/ref_img/teddy.jpg +0 -0
  40. results/vlog/teddy_travel/script/protagonist_place_reference.txt +0 -0
  41. results/vlog/teddy_travel/script/protagonists_places.txt +22 -0
  42. results/vlog/teddy_travel/script/time_scripts.txt +94 -0
  43. results/vlog/teddy_travel/script/video_prompts.txt +0 -0
  44. results/vlog/teddy_travel/script/zh_video_prompts.txt +95 -0
  45. results/vlog/teddy_travel/story.txt +1 -0
  46. results/vlog/teddy_travel_/story.txt +1 -0
  47. sample_scripts/vlog_read_script_sample.py +303 -0
  48. sample_scripts/vlog_write_script.py +91 -0
  49. sample_scripts/with_mask_ref_sample.py +275 -0
  50. sample_scripts/with_mask_sample.py +249 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ input/i2v/A_big_drop_of_water_falls_on_a_rose_petal.png filter=lfs diff=lfs merge=lfs -text
37
+ input/i2v/A_fish_swims_past_an_oriental_woman.png filter=lfs diff=lfs merge=lfs -text
38
+ input/i2v/Cinematic_photograph_View_of_piloting_aaero.png filter=lfs diff=lfs merge=lfs -text
39
+ input/i2v/Planet_hits_earth.png filter=lfs diff=lfs merge=lfs -text
40
+ input/i2v/Underwater_environment_cosmetic_bottles.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
configs/vlog_read_script_sample.yaml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # path:
2
+ ckpt: "pretrained/ShowMaker.pt"
3
+ pretrained_model_path: "pretrained/stable-diffusion-v1-4/"
4
+ image_encoder_path: "pretrained/OpenCLIP-ViT-H-14"
5
+ save_path: "results/vlog/teddy_travel/video"
6
+
7
+ # script path
8
+ reference_image_path: ["results/vlog/teddy_travel/ref_img/teddy.jpg"]
9
+ script_file_path: "results/vlog/teddy_travel/script/video_prompts.txt"
10
+ zh_script_file_path: "results/vlog/teddy_travel/script/zh_video_prompts.txt"
11
+ protagonist_file_path: "results/vlog/teddy_travel/script/protagonists_places.txt"
12
+ reference_file_path: "results/vlog/teddy_travel/script/protagonist_place_reference.txt"
13
+ time_file_path: "results/vlog/teddy_travel/script/time_scripts.txt"
14
+ video_transition: False
15
+
16
+ # model config:
17
+ model: UNet
18
+ num_frames: 16
19
+ image_size: [320, 512]
20
+ negative_prompt: "white background"
21
+
22
+ # sample config:
23
+ ref_cfg_scale: 0.3
24
+ seed: 3407
25
+ guidance_scale: 7.5
26
+ cfg_scale: 8.0
27
+ sample_method: 'ddim'
28
+ num_sampling_steps: 100
29
+ researve_frame: 3
30
+ mask_type: "first3"
31
+ use_mask: True
32
+ use_fp16: True
33
+ enable_xformers_memory_efficient_attention: True
34
+ do_classifier_free_guidance: True
35
+ fps: 8
36
+ sample_num:
37
+
38
+ # model speedup
39
+ use_compile: False
configs/vlog_write_script.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # script path
2
+ story_path: "./results/vlog/teddy_travel_/story.txt"
3
+ only_one_protagonist: False
configs/with_mask_ref_sample.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # path config:
2
+ ckpt: "pretrained/ShowMaker.pt"
3
+ pretrained_model_path: "pretrained/stable-diffusion-v1-4/"
4
+ image_encoder_path: "pretrained/OpenCLIP-ViT-H-14"
5
+ input_path: 'input/i2v/Planet_hits_earth.png'
6
+ ref_path: 'input/i2v/Planet_hits_earth.png'
7
+ save_path: "results/mask_ref/"
8
+
9
+ # model config:
10
+ model: UNet
11
+ num_frames: 16
12
+ # image_size: [320, 512]
13
+ image_size: [240, 560]
14
+
15
+ # model speedup
16
+ use_fp16: True
17
+ enable_xformers_memory_efficient_attention: True
18
+
19
+ # sample config:
20
+ seed: 3407
21
+ cfg_scale: 8.0
22
+ ref_cfg_scale: 0.5
23
+ sample_method: 'ddim'
24
+ num_sampling_steps: 100
25
+ text_prompt: [
26
+ # "Cinematic photograph. View of piloting aaero.",
27
+ # "A fish swims past an oriental woman.",
28
+ # "A big drop of water falls on a rose petal.",
29
+ # "Underwater environment cosmetic bottles.".
30
+ "Planet hits earth.",
31
+ ]
32
+ additional_prompt: ""
33
+ negative_prompt: ""
34
+ do_classifier_free_guidance: True
35
+ mask_type: "first1"
36
+ use_mask: True
configs/with_mask_sample.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # path config:
2
+ ckpt: "pretrained/ShowMaker.pt"
3
+ pretrained_model_path: "pretrained/OpenCLIP-ViT-H-14"
4
+ input_path: 'input/i2v/Planet_hits_earth.png'
5
+ save_path: "results/mask_no_ref/"
6
+
7
+ # model config:
8
+ model: UNet
9
+ num_frames: 16
10
+ # image_size: [320, 512]
11
+ image_size: [240, 560]
12
+
13
+ # model speedup
14
+ use_fp16: True
15
+ enable_xformers_memory_efficient_attention: True
16
+
17
+ # sample config:
18
+ seed: 3407
19
+ cfg_scale: 8.0
20
+ sample_method: 'ddim'
21
+ num_sampling_steps: 100
22
+ text_prompt: [
23
+ # "Cinematic photograph. View of piloting aaero.",
24
+ # "A fish swims past an oriental woman.",
25
+ # "A big drop of water falls on a rose petal.",
26
+ # "Underwater environment cosmetic bottles.".
27
+ "Planet hits earth.",
28
+ ]
29
+ additional_prompt: ""
30
+ negative_prompt: ""
31
+ do_classifier_free_guidance: True
32
+ mask_type: "first1"
33
+ use_mask: True
datasets/__pycache__/video_transforms.cpython-310.pyc ADDED
Binary file (12.4 kB). View file
 
datasets/video_transforms.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import random
3
+ import numbers
4
+ from torchvision.transforms import RandomCrop, RandomResizedCrop
5
+ from PIL import Image
6
+
7
+ def _is_tensor_video_clip(clip):
8
+ if not torch.is_tensor(clip):
9
+ raise TypeError("clip should be Tensor. Got %s" % type(clip))
10
+
11
+ if not clip.ndimension() == 4:
12
+ raise ValueError("clip should be 4D. Got %dD" % clip.dim())
13
+
14
+ return True
15
+
16
+
17
+ def center_crop_arr(pil_image, image_size):
18
+ """
19
+ Center cropping implementation from ADM.
20
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
21
+ """
22
+ while min(*pil_image.size) >= 2 * image_size:
23
+ pil_image = pil_image.resize(
24
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
25
+ )
26
+
27
+ scale = image_size / min(*pil_image.size)
28
+ pil_image = pil_image.resize(
29
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
30
+ )
31
+
32
+ arr = np.array(pil_image)
33
+ crop_y = (arr.shape[0] - image_size) // 2
34
+ crop_x = (arr.shape[1] - image_size) // 2
35
+ return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
36
+
37
+
38
+ def crop(clip, i, j, h, w):
39
+ """
40
+ Args:
41
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
42
+ """
43
+ if len(clip.size()) != 4:
44
+ raise ValueError("clip should be a 4D tensor")
45
+ return clip[..., i : i + h, j : j + w]
46
+
47
+
48
+ def resize(clip, target_size, interpolation_mode):
49
+ if len(target_size) != 2:
50
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
51
+ return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
52
+
53
+ def resize_scale(clip, target_size, interpolation_mode):
54
+ if len(target_size) != 2:
55
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
56
+ H, W = clip.size(-2), clip.size(-1)
57
+ scale_ = target_size[0] / min(H, W)
58
+ return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
59
+
60
+ def resize_with_scale_factor(clip, scale_factor, interpolation_mode):
61
+ return torch.nn.functional.interpolate(clip, scale_factor=scale_factor, mode=interpolation_mode, align_corners=False)
62
+
63
+ def resize_scale_with_height(clip, target_size, interpolation_mode):
64
+ H, W = clip.size(-2), clip.size(-1)
65
+ scale_ = target_size / H
66
+ return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
67
+
68
+ def resize_scale_with_weight(clip, target_size, interpolation_mode):
69
+ H, W = clip.size(-2), clip.size(-1)
70
+ scale_ = target_size / W
71
+ return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
72
+
73
+
74
+ def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
75
+ """
76
+ Do spatial cropping and resizing to the video clip
77
+ Args:
78
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
79
+ i (int): i in (i,j) i.e coordinates of the upper left corner.
80
+ j (int): j in (i,j) i.e coordinates of the upper left corner.
81
+ h (int): Height of the cropped region.
82
+ w (int): Width of the cropped region.
83
+ size (tuple(int, int)): height and width of resized clip
84
+ Returns:
85
+ clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
86
+ """
87
+ if not _is_tensor_video_clip(clip):
88
+ raise ValueError("clip should be a 4D torch.tensor")
89
+ clip = crop(clip, i, j, h, w)
90
+ clip = resize(clip, size, interpolation_mode)
91
+ return clip
92
+
93
+
94
+ def center_crop(clip, crop_size):
95
+ if not _is_tensor_video_clip(clip):
96
+ raise ValueError("clip should be a 4D torch.tensor")
97
+ h, w = clip.size(-2), clip.size(-1)
98
+ # print(clip.shape)
99
+ th, tw = crop_size
100
+ if h < th or w < tw:
101
+ # print(h, w)
102
+ raise ValueError("height {} and width {} must be no smaller than crop_size".format(h, w))
103
+
104
+ i = int(round((h - th) / 2.0))
105
+ j = int(round((w - tw) / 2.0))
106
+ return crop(clip, i, j, th, tw)
107
+
108
+
109
+ def center_crop_using_short_edge(clip):
110
+ if not _is_tensor_video_clip(clip):
111
+ raise ValueError("clip should be a 4D torch.tensor")
112
+ h, w = clip.size(-2), clip.size(-1)
113
+ if h < w:
114
+ th, tw = h, h
115
+ i = 0
116
+ j = int(round((w - tw) / 2.0))
117
+ else:
118
+ th, tw = w, w
119
+ i = int(round((h - th) / 2.0))
120
+ j = 0
121
+ return crop(clip, i, j, th, tw)
122
+
123
+
124
+ def random_shift_crop(clip):
125
+ '''
126
+ Slide along the long edge, with the short edge as crop size
127
+ '''
128
+ if not _is_tensor_video_clip(clip):
129
+ raise ValueError("clip should be a 4D torch.tensor")
130
+ h, w = clip.size(-2), clip.size(-1)
131
+
132
+ if h <= w:
133
+ long_edge = w
134
+ short_edge = h
135
+ else:
136
+ long_edge = h
137
+ short_edge =w
138
+
139
+ th, tw = short_edge, short_edge
140
+
141
+ i = torch.randint(0, h - th + 1, size=(1,)).item()
142
+ j = torch.randint(0, w - tw + 1, size=(1,)).item()
143
+ return crop(clip, i, j, th, tw)
144
+
145
+
146
+ def to_tensor(clip):
147
+ """
148
+ Convert tensor data type from uint8 to float, divide value by 255.0 and
149
+ permute the dimensions of clip tensor
150
+ Args:
151
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
152
+ Return:
153
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
154
+ """
155
+ _is_tensor_video_clip(clip)
156
+ if not clip.dtype == torch.uint8:
157
+ raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
158
+ # return clip.float().permute(3, 0, 1, 2) / 255.0
159
+ return clip.float() / 255.0
160
+
161
+
162
+ def normalize(clip, mean, std, inplace=False):
163
+ """
164
+ Args:
165
+ clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
166
+ mean (tuple): pixel RGB mean. Size is (3)
167
+ std (tuple): pixel standard deviation. Size is (3)
168
+ Returns:
169
+ normalized clip (torch.tensor): Size is (T, C, H, W)
170
+ """
171
+ if not _is_tensor_video_clip(clip):
172
+ raise ValueError("clip should be a 4D torch.tensor")
173
+ if not inplace:
174
+ clip = clip.clone()
175
+ mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
176
+ # print(mean)
177
+ std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
178
+ clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
179
+ return clip
180
+
181
+
182
+ def hflip(clip):
183
+ """
184
+ Args:
185
+ clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
186
+ Returns:
187
+ flipped clip (torch.tensor): Size is (T, C, H, W)
188
+ """
189
+ if not _is_tensor_video_clip(clip):
190
+ raise ValueError("clip should be a 4D torch.tensor")
191
+ return clip.flip(-1)
192
+
193
+
194
+ class RandomCropVideo:
195
+ def __init__(self, size):
196
+ if isinstance(size, numbers.Number):
197
+ self.size = (int(size), int(size))
198
+ else:
199
+ self.size = size
200
+
201
+ def __call__(self, clip):
202
+ """
203
+ Args:
204
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
205
+ Returns:
206
+ torch.tensor: randomly cropped video clip.
207
+ size is (T, C, OH, OW)
208
+ """
209
+ i, j, h, w = self.get_params(clip)
210
+ return crop(clip, i, j, h, w)
211
+
212
+ def get_params(self, clip):
213
+ h, w = clip.shape[-2:]
214
+ th, tw = self.size
215
+
216
+ if h < th or w < tw:
217
+ raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
218
+
219
+ if w == tw and h == th:
220
+ return 0, 0, h, w
221
+
222
+ i = torch.randint(0, h - th + 1, size=(1,)).item()
223
+ j = torch.randint(0, w - tw + 1, size=(1,)).item()
224
+
225
+ return i, j, th, tw
226
+
227
+ def __repr__(self) -> str:
228
+ return f"{self.__class__.__name__}(size={self.size})"
229
+
230
+ class CenterCropResizeVideo:
231
+ '''
232
+ First use the short side for cropping length,
233
+ center crop video, then resize to the specified size
234
+ '''
235
+ def __init__(
236
+ self,
237
+ size,
238
+ interpolation_mode="bilinear",
239
+ ):
240
+ if isinstance(size, tuple):
241
+ if len(size) != 2:
242
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
243
+ self.size = size
244
+ else:
245
+ self.size = (size, size)
246
+
247
+ self.interpolation_mode = interpolation_mode
248
+
249
+
250
+ def __call__(self, clip):
251
+ """
252
+ Args:
253
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
254
+ Returns:
255
+ torch.tensor: scale resized / center cropped video clip.
256
+ size is (T, C, crop_size, crop_size)
257
+ """
258
+ # print(clip.shape)
259
+ clip_center_crop = center_crop_using_short_edge(clip)
260
+ # print(clip_center_crop.shape) 320 512
261
+ clip_center_crop_resize = resize(clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode)
262
+ return clip_center_crop_resize
263
+
264
+ def __repr__(self) -> str:
265
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
266
+
267
+
268
+ class CenterCropVideo:
269
+ def __init__(
270
+ self,
271
+ size,
272
+ interpolation_mode="bilinear",
273
+ ):
274
+ if isinstance(size, tuple):
275
+ if len(size) != 2:
276
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
277
+ self.size = size
278
+ else:
279
+ self.size = (size, size)
280
+
281
+ self.interpolation_mode = interpolation_mode
282
+
283
+
284
+ def __call__(self, clip):
285
+ """
286
+ Args:
287
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
288
+ Returns:
289
+ torch.tensor: center cropped video clip.
290
+ size is (T, C, crop_size, crop_size)
291
+ """
292
+ clip_center_crop = center_crop(clip, self.size)
293
+ return clip_center_crop
294
+
295
+ def __repr__(self) -> str:
296
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
297
+
298
+
299
+ class NormalizeVideo:
300
+ """
301
+ Normalize the video clip by mean subtraction and division by standard deviation
302
+ Args:
303
+ mean (3-tuple): pixel RGB mean
304
+ std (3-tuple): pixel RGB standard deviation
305
+ inplace (boolean): whether do in-place normalization
306
+ """
307
+
308
+ def __init__(self, mean, std, inplace=False):
309
+ self.mean = mean
310
+ self.std = std
311
+ self.inplace = inplace
312
+
313
+ def __call__(self, clip):
314
+ """
315
+ Args:
316
+ clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
317
+ """
318
+ return normalize(clip, self.mean, self.std, self.inplace)
319
+
320
+ def __repr__(self) -> str:
321
+ return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
322
+
323
+
324
+ class ToTensorVideo:
325
+ """
326
+ Convert tensor data type from uint8 to float, divide value by 255.0 and
327
+ permute the dimensions of clip tensor
328
+ """
329
+
330
+ def __init__(self):
331
+ pass
332
+
333
+ def __call__(self, clip):
334
+ """
335
+ Args:
336
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
337
+ Return:
338
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
339
+ """
340
+ return to_tensor(clip)
341
+
342
+ def __repr__(self) -> str:
343
+ return self.__class__.__name__
344
+
345
+
346
+ class ResizeVideo():
347
+ '''
348
+ First use the short side for cropping length,
349
+ center crop video, then resize to the specified size
350
+ '''
351
+ def __init__(
352
+ self,
353
+ size,
354
+ interpolation_mode="bilinear",
355
+ ):
356
+ if isinstance(size, tuple):
357
+ if len(size) != 2:
358
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
359
+ self.size = size
360
+ else:
361
+ self.size = (size, size)
362
+
363
+ self.interpolation_mode = interpolation_mode
364
+
365
+
366
+ def __call__(self, clip):
367
+ """
368
+ Args:
369
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
370
+ Returns:
371
+ torch.tensor: scale resized / center cropped video clip.
372
+ size is (T, C, crop_size, crop_size)
373
+ """
374
+ clip_resize = resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
375
+ return clip_resize
376
+
377
+ def __repr__(self) -> str:
378
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
379
+
380
+ # ------------------------------------------------------------
381
+ # --------------------- Sampling ---------------------------
382
+ # ------------------------------------------------------------
diffusion/__init__.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ from . import gaussian_diffusion as gd
7
+ from .respace import SpacedDiffusion, space_timesteps
8
+
9
+
10
+ def create_diffusion(
11
+ timestep_respacing,
12
+ noise_schedule="linear",
13
+ use_kl=False,
14
+ sigma_small=False,
15
+ predict_xstart=False,
16
+ # learn_sigma=True,
17
+ learn_sigma=False, # for unet
18
+ rescale_learned_sigmas=False,
19
+ diffusion_steps=1000
20
+ ):
21
+ betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
22
+ if use_kl:
23
+ loss_type = gd.LossType.RESCALED_KL
24
+ elif rescale_learned_sigmas:
25
+ loss_type = gd.LossType.RESCALED_MSE
26
+ else:
27
+ loss_type = gd.LossType.MSE
28
+ if timestep_respacing is None or timestep_respacing == "":
29
+ timestep_respacing = [diffusion_steps]
30
+ return SpacedDiffusion(
31
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
32
+ betas=betas,
33
+ model_mean_type=(
34
+ gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
35
+ ),
36
+ model_var_type=(
37
+ (
38
+ gd.ModelVarType.FIXED_LARGE
39
+ if not sigma_small
40
+ else gd.ModelVarType.FIXED_SMALL
41
+ )
42
+ if not learn_sigma
43
+ else gd.ModelVarType.LEARNED_RANGE
44
+ ),
45
+ loss_type=loss_type
46
+ # rescale_timesteps=rescale_timesteps,
47
+ )
diffusion/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.01 kB). View file
 
diffusion/__pycache__/diffusion_utils.cpython-310.pyc ADDED
Binary file (2.83 kB). View file
 
diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc ADDED
Binary file (25 kB). View file
 
diffusion/__pycache__/respace.cpython-310.pyc ADDED
Binary file (4.98 kB). View file
 
diffusion/diffusion_utils.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import torch as th
7
+ import numpy as np
8
+
9
+
10
+ def normal_kl(mean1, logvar1, mean2, logvar2):
11
+ """
12
+ Compute the KL divergence between two gaussians.
13
+ Shapes are automatically broadcasted, so batches can be compared to
14
+ scalars, among other use cases.
15
+ """
16
+ tensor = None
17
+ for obj in (mean1, logvar1, mean2, logvar2):
18
+ if isinstance(obj, th.Tensor):
19
+ tensor = obj
20
+ break
21
+ assert tensor is not None, "at least one argument must be a Tensor"
22
+
23
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
24
+ # Tensors, but it does not work for th.exp().
25
+ logvar1, logvar2 = [
26
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27
+ for x in (logvar1, logvar2)
28
+ ]
29
+
30
+ return 0.5 * (
31
+ -1.0
32
+ + logvar2
33
+ - logvar1
34
+ + th.exp(logvar1 - logvar2)
35
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36
+ )
37
+
38
+
39
+ def approx_standard_normal_cdf(x):
40
+ """
41
+ A fast approximation of the cumulative distribution function of the
42
+ standard normal.
43
+ """
44
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45
+
46
+
47
+ def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48
+ """
49
+ Compute the log-likelihood of a continuous Gaussian distribution.
50
+ :param x: the targets
51
+ :param means: the Gaussian mean Tensor.
52
+ :param log_scales: the Gaussian log stddev Tensor.
53
+ :return: a tensor like x of log probabilities (in nats).
54
+ """
55
+ centered_x = x - means
56
+ inv_stdv = th.exp(-log_scales)
57
+ normalized_x = centered_x * inv_stdv
58
+ log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
59
+ return log_probs
60
+
61
+
62
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
63
+ """
64
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
65
+ given image.
66
+ :param x: the target images. It is assumed that this was uint8 values,
67
+ rescaled to the range [-1, 1].
68
+ :param means: the Gaussian mean Tensor.
69
+ :param log_scales: the Gaussian log stddev Tensor.
70
+ :return: a tensor like x of log probabilities (in nats).
71
+ """
72
+ assert x.shape == means.shape == log_scales.shape
73
+ centered_x = x - means
74
+ inv_stdv = th.exp(-log_scales)
75
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
76
+ cdf_plus = approx_standard_normal_cdf(plus_in)
77
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
78
+ cdf_min = approx_standard_normal_cdf(min_in)
79
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
80
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
81
+ cdf_delta = cdf_plus - cdf_min
82
+ log_probs = th.where(
83
+ x < -0.999,
84
+ log_cdf_plus,
85
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
86
+ )
87
+ assert log_probs.shape == x.shape
88
+ return log_probs
diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,931 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+
7
+ import math
8
+
9
+ import numpy as np
10
+ import torch as th
11
+ import enum
12
+
13
+ from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
14
+
15
+
16
+ def mean_flat(tensor):
17
+ """
18
+ Take the mean over all non-batch dimensions.
19
+ """
20
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
21
+
22
+
23
+ class ModelMeanType(enum.Enum):
24
+ """
25
+ Which type of output the model predicts.
26
+ """
27
+
28
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
29
+ START_X = enum.auto() # the model predicts x_0
30
+ EPSILON = enum.auto() # the model predicts epsilon
31
+
32
+
33
+ class ModelVarType(enum.Enum):
34
+ """
35
+ What is used as the model's output variance.
36
+ The LEARNED_RANGE option has been added to allow the model to predict
37
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
38
+ """
39
+
40
+ LEARNED = enum.auto()
41
+ FIXED_SMALL = enum.auto()
42
+ FIXED_LARGE = enum.auto()
43
+ LEARNED_RANGE = enum.auto()
44
+
45
+
46
+ class LossType(enum.Enum):
47
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
48
+ RESCALED_MSE = (
49
+ enum.auto()
50
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
51
+ KL = enum.auto() # use the variational lower-bound
52
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
53
+
54
+ def is_vb(self):
55
+ return self == LossType.KL or self == LossType.RESCALED_KL
56
+
57
+
58
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
59
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
60
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
61
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
62
+ return betas
63
+
64
+
65
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
66
+ """
67
+ This is the deprecated API for creating beta schedules.
68
+ See get_named_beta_schedule() for the new library of schedules.
69
+ """
70
+ if beta_schedule == "quad":
71
+ betas = (
72
+ np.linspace(
73
+ beta_start ** 0.5,
74
+ beta_end ** 0.5,
75
+ num_diffusion_timesteps,
76
+ dtype=np.float64,
77
+ )
78
+ ** 2
79
+ )
80
+ elif beta_schedule == "linear":
81
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
82
+ elif beta_schedule == "warmup10":
83
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
84
+ elif beta_schedule == "warmup50":
85
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
86
+ elif beta_schedule == "const":
87
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
88
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
89
+ betas = 1.0 / np.linspace(
90
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
91
+ )
92
+ else:
93
+ raise NotImplementedError(beta_schedule)
94
+ assert betas.shape == (num_diffusion_timesteps,)
95
+ return betas
96
+
97
+
98
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
99
+ """
100
+ Get a pre-defined beta schedule for the given name.
101
+ The beta schedule library consists of beta schedules which remain similar
102
+ in the limit of num_diffusion_timesteps.
103
+ Beta schedules may be added, but should not be removed or changed once
104
+ they are committed to maintain backwards compatibility.
105
+ """
106
+ if schedule_name == "linear":
107
+ # Linear schedule from Ho et al, extended to work for any number of
108
+ # diffusion steps.
109
+ scale = 1000 / num_diffusion_timesteps
110
+ return get_beta_schedule(
111
+ "linear",
112
+ beta_start=scale * 0.0001,
113
+ beta_end=scale * 0.02,
114
+ # diffuser stable diffusion
115
+ # beta_start=scale * 0.00085,
116
+ # beta_end=scale * 0.012,
117
+ num_diffusion_timesteps=num_diffusion_timesteps,
118
+ )
119
+ elif schedule_name == "squaredcos_cap_v2":
120
+ return betas_for_alpha_bar(
121
+ num_diffusion_timesteps,
122
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
123
+ )
124
+ else:
125
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
126
+
127
+
128
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
129
+ """
130
+ Create a beta schedule that discretizes the given alpha_t_bar function,
131
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
132
+ :param num_diffusion_timesteps: the number of betas to produce.
133
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
134
+ produces the cumulative product of (1-beta) up to that
135
+ part of the diffusion process.
136
+ :param max_beta: the maximum beta to use; use values lower than 1 to
137
+ prevent singularities.
138
+ """
139
+ betas = []
140
+ for i in range(num_diffusion_timesteps):
141
+ t1 = i / num_diffusion_timesteps
142
+ t2 = (i + 1) / num_diffusion_timesteps
143
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
144
+ return np.array(betas)
145
+
146
+
147
+ class GaussianDiffusion:
148
+ """
149
+ Utilities for training and sampling diffusion models.
150
+ Original ported from this codebase:
151
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
152
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
153
+ starting at T and going to 1.
154
+ """
155
+
156
+ def __init__(
157
+ self,
158
+ *,
159
+ betas,
160
+ model_mean_type,
161
+ model_var_type,
162
+ loss_type
163
+ ):
164
+
165
+ self.model_mean_type = model_mean_type
166
+ self.model_var_type = model_var_type
167
+ self.loss_type = loss_type
168
+
169
+ # Use float64 for accuracy.
170
+ betas = np.array(betas, dtype=np.float64)
171
+ self.betas = betas
172
+ assert len(betas.shape) == 1, "betas must be 1-D"
173
+ assert (betas > 0).all() and (betas <= 1).all()
174
+
175
+ self.num_timesteps = int(betas.shape[0])
176
+
177
+ alphas = 1.0 - betas
178
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
179
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
180
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
181
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
182
+
183
+ # calculations for diffusion q(x_t | x_{t-1}) and others
184
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
185
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
186
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
187
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
188
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
189
+
190
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
191
+ self.posterior_variance = (
192
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
193
+ )
194
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
195
+ self.posterior_log_variance_clipped = np.log(
196
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
197
+ ) if len(self.posterior_variance) > 1 else np.array([])
198
+
199
+ self.posterior_mean_coef1 = (
200
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
201
+ )
202
+ self.posterior_mean_coef2 = (
203
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
204
+ )
205
+
206
+ def q_mean_variance(self, x_start, t):
207
+ """
208
+ Get the distribution q(x_t | x_0).
209
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
210
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
211
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
212
+ """
213
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
214
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
215
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
216
+ return mean, variance, log_variance
217
+
218
+ def q_sample(self, x_start, t, noise=None):
219
+ """
220
+ Diffuse the data for a given number of diffusion steps.
221
+ In other words, sample from q(x_t | x_0).
222
+ :param x_start: the initial data batch.
223
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
224
+ :param noise: if specified, the split-out normal noise.
225
+ :return: A noisy version of x_start.
226
+ """
227
+ if noise is None:
228
+ noise = th.randn_like(x_start)
229
+ assert noise.shape == x_start.shape
230
+ return (
231
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
232
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
233
+ )
234
+
235
+ def q_posterior_mean_variance(self, x_start, x_t, t):
236
+ """
237
+ Compute the mean and variance of the diffusion posterior:
238
+ q(x_{t-1} | x_t, x_0)
239
+ """
240
+ assert x_start.shape == x_t.shape
241
+ posterior_mean = (
242
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
243
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
244
+ )
245
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
246
+ posterior_log_variance_clipped = _extract_into_tensor(
247
+ self.posterior_log_variance_clipped, t, x_t.shape
248
+ )
249
+ assert (
250
+ posterior_mean.shape[0]
251
+ == posterior_variance.shape[0]
252
+ == posterior_log_variance_clipped.shape[0]
253
+ == x_start.shape[0]
254
+ )
255
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
256
+
257
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None,
258
+ mask=None, x_start=None, use_concat=False):
259
+ """
260
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
261
+ the initial x, x_0.
262
+ :param model: the model, which takes a signal and a batch of timesteps
263
+ as input.
264
+ :param x: the [N x C x ...] tensor at time t.
265
+ :param t: a 1-D Tensor of timesteps.
266
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
267
+ :param denoised_fn: if not None, a function which applies to the
268
+ x_start prediction before it is used to sample. Applies before
269
+ clip_denoised.
270
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
271
+ pass to the model. This can be used for conditioning.
272
+ :return: a dict with the following keys:
273
+ - 'mean': the model mean output.
274
+ - 'variance': the model variance output.
275
+ - 'log_variance': the log of 'variance'.
276
+ - 'pred_xstart': the prediction for x_0.
277
+ """
278
+ if model_kwargs is None:
279
+ model_kwargs = {}
280
+
281
+ B, F, C = x.shape[:3]
282
+ assert t.shape == (B,)
283
+ if use_concat:
284
+ model_output = model(th.concat([x, mask, x_start], dim=1), t, **model_kwargs)
285
+ else:
286
+ model_output = model(x, t, **model_kwargs)
287
+ try:
288
+ model_output = model_output.sample # for tav unet
289
+ except:
290
+ pass
291
+ # model_output = model(x, t, **model_kwargs)
292
+ if isinstance(model_output, tuple):
293
+ model_output, extra = model_output
294
+ else:
295
+ extra = None
296
+
297
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
298
+ assert model_output.shape == (B, F, C * 2, *x.shape[3:])
299
+ model_output, model_var_values = th.split(model_output, C, dim=2)
300
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
301
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
302
+ # The model_var_values is [-1, 1] for [min_var, max_var].
303
+ frac = (model_var_values + 1) / 2
304
+ model_log_variance = frac * max_log + (1 - frac) * min_log
305
+ model_variance = th.exp(model_log_variance)
306
+ else:
307
+ model_variance, model_log_variance = {
308
+ # for fixedlarge, we set the initial (log-)variance like so
309
+ # to get a better decoder log likelihood.
310
+ ModelVarType.FIXED_LARGE: (
311
+ np.append(self.posterior_variance[1], self.betas[1:]),
312
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
313
+ ),
314
+ ModelVarType.FIXED_SMALL: (
315
+ self.posterior_variance,
316
+ self.posterior_log_variance_clipped,
317
+ ),
318
+ }[self.model_var_type]
319
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
320
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
321
+
322
+ def process_xstart(x):
323
+ if denoised_fn is not None:
324
+ x = denoised_fn(x)
325
+ if clip_denoised:
326
+ return x.clamp(-1, 1)
327
+ return x
328
+
329
+ if self.model_mean_type == ModelMeanType.START_X:
330
+ pred_xstart = process_xstart(model_output)
331
+ else:
332
+ pred_xstart = process_xstart(
333
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
334
+ )
335
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
336
+
337
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
338
+ return {
339
+ "mean": model_mean,
340
+ "variance": model_variance,
341
+ "log_variance": model_log_variance,
342
+ "pred_xstart": pred_xstart,
343
+ "extra": extra,
344
+ }
345
+
346
+ def _predict_xstart_from_eps(self, x_t, t, eps):
347
+ assert x_t.shape == eps.shape
348
+ return (
349
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
350
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
351
+ )
352
+
353
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
354
+ return (
355
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
356
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
357
+
358
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
359
+ """
360
+ Compute the mean for the previous step, given a function cond_fn that
361
+ computes the gradient of a conditional log probability with respect to
362
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
363
+ condition on y.
364
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
365
+ """
366
+ gradient = cond_fn(x, t, **model_kwargs)
367
+ new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
368
+ return new_mean
369
+
370
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
371
+ """
372
+ Compute what the p_mean_variance output would have been, should the
373
+ model's score function be conditioned by cond_fn.
374
+ See condition_mean() for details on cond_fn.
375
+ Unlike condition_mean(), this instead uses the conditioning strategy
376
+ from Song et al (2020).
377
+ """
378
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
379
+
380
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
381
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
382
+
383
+ out = p_mean_var.copy()
384
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
385
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
386
+ return out
387
+
388
+ def p_sample(
389
+ self,
390
+ model,
391
+ x,
392
+ t,
393
+ clip_denoised=True,
394
+ denoised_fn=None,
395
+ cond_fn=None,
396
+ model_kwargs=None,
397
+ mask=None,
398
+ x_start=None,
399
+ use_concat=False
400
+ ):
401
+ """
402
+ Sample x_{t-1} from the model at the given timestep.
403
+ :param model: the model to sample from.
404
+ :param x: the current tensor at x_{t-1}.
405
+ :param t: the value of t, starting at 0 for the first diffusion step.
406
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
407
+ :param denoised_fn: if not None, a function which applies to the
408
+ x_start prediction before it is used to sample.
409
+ :param cond_fn: if not None, this is a gradient function that acts
410
+ similarly to the model.
411
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
412
+ pass to the model. This can be used for conditioning.
413
+ :return: a dict containing the following keys:
414
+ - 'sample': a random sample from the model.
415
+ - 'pred_xstart': a prediction of x_0.
416
+ """
417
+ out = self.p_mean_variance(
418
+ model,
419
+ x,
420
+ t,
421
+ clip_denoised=clip_denoised,
422
+ denoised_fn=denoised_fn,
423
+ model_kwargs=model_kwargs,
424
+ mask=mask,
425
+ x_start=x_start,
426
+ use_concat=use_concat
427
+ )
428
+ noise = th.randn_like(x)
429
+ nonzero_mask = (
430
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
431
+ ) # no noise when t == 0
432
+ if cond_fn is not None:
433
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
434
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
435
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
436
+
437
+ def p_sample_loop(
438
+ self,
439
+ model,
440
+ shape,
441
+ noise=None,
442
+ clip_denoised=True,
443
+ denoised_fn=None,
444
+ cond_fn=None,
445
+ model_kwargs=None,
446
+ device=None,
447
+ progress=False,
448
+ mask=None,
449
+ x_start=None,
450
+ use_concat=False,
451
+ ):
452
+ """
453
+ Generate samples from the model.
454
+ :param model: the model module.
455
+ :param shape: the shape of the samples, (N, C, H, W).
456
+ :param noise: if specified, the noise from the encoder to sample.
457
+ Should be of the same shape as `shape`.
458
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
459
+ :param denoised_fn: if not None, a function which applies to the
460
+ x_start prediction before it is used to sample.
461
+ :param cond_fn: if not None, this is a gradient function that acts
462
+ similarly to the model.
463
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
464
+ pass to the model. This can be used for conditioning.
465
+ :param device: if specified, the device to create the samples on.
466
+ If not specified, use a model parameter's device.
467
+ :param progress: if True, show a tqdm progress bar.
468
+ :return: a non-differentiable batch of samples.
469
+ """
470
+ final = None
471
+ for sample in self.p_sample_loop_progressive(
472
+ model,
473
+ shape,
474
+ noise=noise,
475
+ clip_denoised=clip_denoised,
476
+ denoised_fn=denoised_fn,
477
+ cond_fn=cond_fn,
478
+ model_kwargs=model_kwargs,
479
+ device=device,
480
+ progress=progress,
481
+ mask=mask,
482
+ x_start=x_start,
483
+ use_concat=use_concat
484
+ ):
485
+ final = sample
486
+ return final["sample"]
487
+
488
+ def p_sample_loop_progressive(
489
+ self,
490
+ model,
491
+ shape,
492
+ noise=None,
493
+ clip_denoised=True,
494
+ denoised_fn=None,
495
+ cond_fn=None,
496
+ model_kwargs=None,
497
+ device=None,
498
+ progress=False,
499
+ mask=None,
500
+ x_start=None,
501
+ use_concat=False
502
+ ):
503
+ """
504
+ Generate samples from the model and yield intermediate samples from
505
+ each timestep of diffusion.
506
+ Arguments are the same as p_sample_loop().
507
+ Returns a generator over dicts, where each dict is the return value of
508
+ p_sample().
509
+ """
510
+ if device is None:
511
+ device = next(model.parameters()).device
512
+ assert isinstance(shape, (tuple, list))
513
+ if noise is not None:
514
+ img = noise
515
+ else:
516
+ img = th.randn(*shape, device=device)
517
+ indices = list(range(self.num_timesteps))[::-1]
518
+
519
+ if progress:
520
+ # Lazy import so that we don't depend on tqdm.
521
+ from tqdm.auto import tqdm
522
+
523
+ indices = tqdm(indices)
524
+
525
+ for i in indices:
526
+ t = th.tensor([i] * shape[0], device=device)
527
+ with th.no_grad():
528
+ out = self.p_sample(
529
+ model,
530
+ img,
531
+ t,
532
+ clip_denoised=clip_denoised,
533
+ denoised_fn=denoised_fn,
534
+ cond_fn=cond_fn,
535
+ model_kwargs=model_kwargs,
536
+ mask=mask,
537
+ x_start=x_start,
538
+ use_concat=use_concat
539
+ )
540
+ yield out
541
+ img = out["sample"]
542
+
543
+ def ddim_sample(
544
+ self,
545
+ model,
546
+ x,
547
+ t,
548
+ clip_denoised=True,
549
+ denoised_fn=None,
550
+ cond_fn=None,
551
+ model_kwargs=None,
552
+ eta=0.0,
553
+ mask=None,
554
+ x_start=None,
555
+ use_concat=False
556
+ ):
557
+ """
558
+ Sample x_{t-1} from the model using DDIM.
559
+ Same usage as p_sample().
560
+ """
561
+ out = self.p_mean_variance(
562
+ model,
563
+ x,
564
+ t,
565
+ clip_denoised=clip_denoised,
566
+ denoised_fn=denoised_fn,
567
+ model_kwargs=model_kwargs,
568
+ mask=mask,
569
+ x_start=x_start,
570
+ use_concat=use_concat
571
+ )
572
+ if cond_fn is not None:
573
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
574
+
575
+ # Usually our model outputs epsilon, but we re-derive it
576
+ # in case we used x_start or x_prev prediction.
577
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
578
+
579
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
580
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
581
+ sigma = (
582
+ eta
583
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
584
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
585
+ )
586
+ # Equation 12.
587
+ noise = th.randn_like(x)
588
+ mean_pred = (
589
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
590
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
591
+ )
592
+ nonzero_mask = (
593
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
594
+ ) # no noise when t == 0
595
+ sample = mean_pred + nonzero_mask * sigma * noise
596
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
597
+
598
+ def ddim_reverse_sample(
599
+ self,
600
+ model,
601
+ x,
602
+ t,
603
+ clip_denoised=True,
604
+ denoised_fn=None,
605
+ cond_fn=None,
606
+ model_kwargs=None,
607
+ eta=0.0,
608
+ ):
609
+ """
610
+ Sample x_{t+1} from the model using DDIM reverse ODE.
611
+ """
612
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
613
+ out = self.p_mean_variance(
614
+ model,
615
+ x,
616
+ t,
617
+ clip_denoised=clip_denoised,
618
+ denoised_fn=denoised_fn,
619
+ model_kwargs=model_kwargs,
620
+ )
621
+ if cond_fn is not None:
622
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
623
+ # Usually our model outputs epsilon, but we re-derive it
624
+ # in case we used x_start or x_prev prediction.
625
+ eps = (
626
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
627
+ - out["pred_xstart"]
628
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
629
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
630
+
631
+ # Equation 12. reversed
632
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
633
+
634
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
635
+
636
+ def ddim_sample_loop(
637
+ self,
638
+ model,
639
+ shape,
640
+ noise=None,
641
+ clip_denoised=True,
642
+ denoised_fn=None,
643
+ cond_fn=None,
644
+ model_kwargs=None,
645
+ device=None,
646
+ progress=False,
647
+ eta=0.0,
648
+ mask=None,
649
+ x_start=None,
650
+ use_concat=False
651
+ ):
652
+ """
653
+ Generate samples from the model using DDIM.
654
+ Same usage as p_sample_loop().
655
+ """
656
+ final = None
657
+ for sample in self.ddim_sample_loop_progressive(
658
+ model,
659
+ shape,
660
+ noise=noise,
661
+ clip_denoised=clip_denoised,
662
+ denoised_fn=denoised_fn,
663
+ cond_fn=cond_fn,
664
+ model_kwargs=model_kwargs,
665
+ device=device,
666
+ progress=progress,
667
+ eta=eta,
668
+ mask=mask,
669
+ x_start=x_start,
670
+ use_concat=use_concat
671
+ ):
672
+ final = sample
673
+ return final["sample"]
674
+
675
+ def ddim_sample_loop_progressive(
676
+ self,
677
+ model,
678
+ shape,
679
+ noise=None,
680
+ clip_denoised=True,
681
+ denoised_fn=None,
682
+ cond_fn=None,
683
+ model_kwargs=None,
684
+ device=None,
685
+ progress=False,
686
+ eta=0.0,
687
+ mask=None,
688
+ x_start=None,
689
+ use_concat=False
690
+ ):
691
+ """
692
+ Use DDIM to sample from the model and yield intermediate samples from
693
+ each timestep of DDIM.
694
+ Same usage as p_sample_loop_progressive().
695
+ """
696
+ if device is None:
697
+ device = next(model.parameters()).device
698
+ assert isinstance(shape, (tuple, list))
699
+ if noise is not None:
700
+ img = noise
701
+ else:
702
+ img = th.randn(*shape, device=device)
703
+ indices = list(range(self.num_timesteps))[::-1]
704
+
705
+ if progress:
706
+ # Lazy import so that we don't depend on tqdm.
707
+ from tqdm.auto import tqdm
708
+
709
+ indices = tqdm(indices)
710
+
711
+ for i in indices:
712
+ t = th.tensor([i] * shape[0], device=device)
713
+ with th.no_grad():
714
+ out = self.ddim_sample(
715
+ model,
716
+ img,
717
+ t,
718
+ clip_denoised=clip_denoised,
719
+ denoised_fn=denoised_fn,
720
+ cond_fn=cond_fn,
721
+ model_kwargs=model_kwargs,
722
+ eta=eta,
723
+ mask=mask,
724
+ x_start=x_start,
725
+ use_concat=use_concat
726
+ )
727
+ yield out
728
+ img = out["sample"]
729
+
730
+ def _vb_terms_bpd(
731
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
732
+ ):
733
+ """
734
+ Get a term for the variational lower-bound.
735
+ The resulting units are bits (rather than nats, as one might expect).
736
+ This allows for comparison to other papers.
737
+ :return: a dict with the following keys:
738
+ - 'output': a shape [N] tensor of NLLs or KLs.
739
+ - 'pred_xstart': the x_0 predictions.
740
+ """
741
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
742
+ x_start=x_start, x_t=x_t, t=t
743
+ )
744
+ out = self.p_mean_variance(
745
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
746
+ )
747
+ kl = normal_kl(
748
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
749
+ )
750
+ kl = mean_flat(kl) / np.log(2.0)
751
+
752
+ decoder_nll = -discretized_gaussian_log_likelihood(
753
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
754
+ )
755
+ assert decoder_nll.shape == x_start.shape
756
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
757
+
758
+ # At the first timestep return the decoder NLL,
759
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
760
+ output = th.where((t == 0), decoder_nll, kl)
761
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
762
+
763
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, use_mask=False):
764
+ """
765
+ Compute training losses for a single timestep.
766
+ :param model: the model to evaluate loss on.
767
+ :param x_start: the [N x C x ...] tensor of inputs.
768
+ :param t: a batch of timestep indices.
769
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
770
+ pass to the model. This can be used for conditioning.
771
+ :param noise: if specified, the specific Gaussian noise to try to remove.
772
+ :return: a dict with the key "loss" containing a tensor of shape [N].
773
+ Some mean or variance settings may also have other keys.
774
+ """
775
+ if model_kwargs is None:
776
+ model_kwargs = {}
777
+ if noise is None:
778
+ noise = th.randn_like(x_start)
779
+ x_t = self.q_sample(x_start, t, noise=noise)
780
+ if use_mask:
781
+ x_t = th.cat([x_t[:, :4], x_start[:, 4:]], dim=1)
782
+ terms = {}
783
+
784
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
785
+ terms["loss"] = self._vb_terms_bpd(
786
+ model=model,
787
+ x_start=x_start,
788
+ x_t=x_t,
789
+ t=t,
790
+ clip_denoised=False,
791
+ model_kwargs=model_kwargs,
792
+ )["output"]
793
+ if self.loss_type == LossType.RESCALED_KL:
794
+ terms["loss"] *= self.num_timesteps
795
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
796
+ model_output = model(x_t, t, **model_kwargs)
797
+ try:
798
+ # model_output = model(x_t, t, **model_kwargs).sample
799
+ model_output = model_output.sample # for tav unet
800
+ except:
801
+ pass
802
+ # model_output = model(x_t, t, **model_kwargs)
803
+
804
+ if self.model_var_type in [
805
+ ModelVarType.LEARNED,
806
+ ModelVarType.LEARNED_RANGE,
807
+ ]:
808
+ B, F, C = x_t.shape[:3]
809
+ assert model_output.shape == (B, F, C * 2, *x_t.shape[3:])
810
+ model_output, model_var_values = th.split(model_output, C, dim=2)
811
+ # Learn the variance using the variational bound, but don't let
812
+ # it affect our mean prediction.
813
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=2)
814
+ terms["vb"] = self._vb_terms_bpd(
815
+ model=lambda *args, r=frozen_out: r,
816
+ x_start=x_start,
817
+ x_t=x_t,
818
+ t=t,
819
+ clip_denoised=False,
820
+ )["output"]
821
+ if self.loss_type == LossType.RESCALED_MSE:
822
+ # Divide by 1000 for equivalence with initial implementation.
823
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
824
+ terms["vb"] *= self.num_timesteps / 1000.0
825
+
826
+ target = {
827
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
828
+ x_start=x_start, x_t=x_t, t=t
829
+ )[0],
830
+ ModelMeanType.START_X: x_start,
831
+ ModelMeanType.EPSILON: noise,
832
+ }[self.model_mean_type]
833
+ # assert model_output.shape == target.shape == x_start.shape
834
+ if use_mask:
835
+ terms["mse"] = mean_flat((target[:,:4] - model_output) ** 2)
836
+ else:
837
+ terms["mse"] = mean_flat((target - model_output) ** 2)
838
+ if "vb" in terms:
839
+ terms["loss"] = terms["mse"] + terms["vb"]
840
+ else:
841
+ terms["loss"] = terms["mse"]
842
+ else:
843
+ raise NotImplementedError(self.loss_type)
844
+
845
+ return terms
846
+
847
+ def _prior_bpd(self, x_start):
848
+ """
849
+ Get the prior KL term for the variational lower-bound, measured in
850
+ bits-per-dim.
851
+ This term can't be optimized, as it only depends on the encoder.
852
+ :param x_start: the [N x C x ...] tensor of inputs.
853
+ :return: a batch of [N] KL values (in bits), one per batch element.
854
+ """
855
+ batch_size = x_start.shape[0]
856
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
857
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
858
+ kl_prior = normal_kl(
859
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
860
+ )
861
+ return mean_flat(kl_prior) / np.log(2.0)
862
+
863
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
864
+ """
865
+ Compute the entire variational lower-bound, measured in bits-per-dim,
866
+ as well as other related quantities.
867
+ :param model: the model to evaluate loss on.
868
+ :param x_start: the [N x C x ...] tensor of inputs.
869
+ :param clip_denoised: if True, clip denoised samples.
870
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
871
+ pass to the model. This can be used for conditioning.
872
+ :return: a dict containing the following keys:
873
+ - total_bpd: the total variational lower-bound, per batch element.
874
+ - prior_bpd: the prior term in the lower-bound.
875
+ - vb: an [N x T] tensor of terms in the lower-bound.
876
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
877
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
878
+ """
879
+ device = x_start.device
880
+ batch_size = x_start.shape[0]
881
+
882
+ vb = []
883
+ xstart_mse = []
884
+ mse = []
885
+ for t in list(range(self.num_timesteps))[::-1]:
886
+ t_batch = th.tensor([t] * batch_size, device=device)
887
+ noise = th.randn_like(x_start)
888
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
889
+ # Calculate VLB term at the current timestep
890
+ with th.no_grad():
891
+ out = self._vb_terms_bpd(
892
+ model,
893
+ x_start=x_start,
894
+ x_t=x_t,
895
+ t=t_batch,
896
+ clip_denoised=clip_denoised,
897
+ model_kwargs=model_kwargs,
898
+ )
899
+ vb.append(out["output"])
900
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
901
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
902
+ mse.append(mean_flat((eps - noise) ** 2))
903
+
904
+ vb = th.stack(vb, dim=1)
905
+ xstart_mse = th.stack(xstart_mse, dim=1)
906
+ mse = th.stack(mse, dim=1)
907
+
908
+ prior_bpd = self._prior_bpd(x_start)
909
+ total_bpd = vb.sum(dim=1) + prior_bpd
910
+ return {
911
+ "total_bpd": total_bpd,
912
+ "prior_bpd": prior_bpd,
913
+ "vb": vb,
914
+ "xstart_mse": xstart_mse,
915
+ "mse": mse,
916
+ }
917
+
918
+
919
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
920
+ """
921
+ Extract values from a 1-D numpy array for a batch of indices.
922
+ :param arr: the 1-D numpy array.
923
+ :param timesteps: a tensor of indices into the array to extract.
924
+ :param broadcast_shape: a larger shape of K dimensions with the batch
925
+ dimension equal to the length of timesteps.
926
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
927
+ """
928
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
929
+ while len(res.shape) < len(broadcast_shape):
930
+ res = res[..., None]
931
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
diffusion/respace.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+ import torch
6
+ import numpy as np
7
+ import torch as th
8
+
9
+ from .gaussian_diffusion import GaussianDiffusion
10
+
11
+
12
+ def space_timesteps(num_timesteps, section_counts):
13
+ """
14
+ Create a list of timesteps to use from an original diffusion process,
15
+ given the number of timesteps we want to take from equally-sized portions
16
+ of the original process.
17
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
18
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
19
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
20
+ If the stride is a string starting with "ddim", then the fixed striding
21
+ from the DDIM paper is used, and only one section is allowed.
22
+ :param num_timesteps: the number of diffusion steps in the original
23
+ process to divide up.
24
+ :param section_counts: either a list of numbers, or a string containing
25
+ comma-separated numbers, indicating the step count
26
+ per section. As a special case, use "ddimN" where N
27
+ is a number of steps to use the striding from the
28
+ DDIM paper.
29
+ :return: a set of diffusion steps from the original process to use.
30
+ """
31
+ if isinstance(section_counts, str):
32
+ if section_counts.startswith("ddim"):
33
+ desired_count = int(section_counts[len("ddim") :])
34
+ for i in range(1, num_timesteps):
35
+ if len(range(0, num_timesteps, i)) == desired_count:
36
+ return set(range(0, num_timesteps, i))
37
+ raise ValueError(
38
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
39
+ )
40
+ section_counts = [int(x) for x in section_counts.split(",")]
41
+ size_per = num_timesteps // len(section_counts)
42
+ extra = num_timesteps % len(section_counts)
43
+ start_idx = 0
44
+ all_steps = []
45
+ for i, section_count in enumerate(section_counts):
46
+ size = size_per + (1 if i < extra else 0)
47
+ if size < section_count:
48
+ raise ValueError(
49
+ f"cannot divide section of {size} steps into {section_count}"
50
+ )
51
+ if section_count <= 1:
52
+ frac_stride = 1
53
+ else:
54
+ frac_stride = (size - 1) / (section_count - 1)
55
+ cur_idx = 0.0
56
+ taken_steps = []
57
+ for _ in range(section_count):
58
+ taken_steps.append(start_idx + round(cur_idx))
59
+ cur_idx += frac_stride
60
+ all_steps += taken_steps
61
+ start_idx += size
62
+ return set(all_steps)
63
+
64
+
65
+ class SpacedDiffusion(GaussianDiffusion):
66
+ """
67
+ A diffusion process which can skip steps in a base diffusion process.
68
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
69
+ original diffusion process to retain.
70
+ :param kwargs: the kwargs to create the base diffusion process.
71
+ """
72
+
73
+ def __init__(self, use_timesteps, **kwargs):
74
+ self.use_timesteps = set(use_timesteps)
75
+ self.timestep_map = []
76
+ self.original_num_steps = len(kwargs["betas"])
77
+
78
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79
+ last_alpha_cumprod = 1.0
80
+ new_betas = []
81
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82
+ if i in self.use_timesteps:
83
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84
+ last_alpha_cumprod = alpha_cumprod
85
+ self.timestep_map.append(i)
86
+ kwargs["betas"] = np.array(new_betas)
87
+ super().__init__(**kwargs)
88
+
89
+ def p_mean_variance(
90
+ self, model, *args, **kwargs
91
+ ): # pylint: disable=signature-differs
92
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93
+
94
+ # @torch.compile
95
+ def training_losses(
96
+ self, model, *args, **kwargs
97
+ ): # pylint: disable=signature-differs
98
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
99
+
100
+ def condition_mean(self, cond_fn, *args, **kwargs):
101
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
102
+
103
+ def condition_score(self, cond_fn, *args, **kwargs):
104
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
105
+
106
+ def _wrap_model(self, model):
107
+ if isinstance(model, _WrappedModel):
108
+ return model
109
+ return _WrappedModel(
110
+ model, self.timestep_map, self.original_num_steps
111
+ )
112
+
113
+ def _scale_timesteps(self, t):
114
+ # Scaling is done by the wrapped model.
115
+ return t
116
+
117
+
118
+ class _WrappedModel:
119
+ def __init__(self, model, timestep_map, original_num_steps):
120
+ self.model = model
121
+ self.timestep_map = timestep_map
122
+ # self.rescale_timesteps = rescale_timesteps
123
+ self.original_num_steps = original_num_steps
124
+
125
+ def __call__(self, x, ts, **kwargs):
126
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
127
+ new_ts = map_tensor[ts]
128
+ # if self.rescale_timesteps:
129
+ # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
130
+ return self.model(x, new_ts, **kwargs)
diffusion/timestep_sampler.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ from abc import ABC, abstractmethod
7
+
8
+ import numpy as np
9
+ import torch as th
10
+ import torch.distributed as dist
11
+
12
+
13
+ def create_named_schedule_sampler(name, diffusion):
14
+ """
15
+ Create a ScheduleSampler from a library of pre-defined samplers.
16
+ :param name: the name of the sampler.
17
+ :param diffusion: the diffusion object to sample for.
18
+ """
19
+ if name == "uniform":
20
+ return UniformSampler(diffusion)
21
+ elif name == "loss-second-moment":
22
+ return LossSecondMomentResampler(diffusion)
23
+ else:
24
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
25
+
26
+
27
+ class ScheduleSampler(ABC):
28
+ """
29
+ A distribution over timesteps in the diffusion process, intended to reduce
30
+ variance of the objective.
31
+ By default, samplers perform unbiased importance sampling, in which the
32
+ objective's mean is unchanged.
33
+ However, subclasses may override sample() to change how the resampled
34
+ terms are reweighted, allowing for actual changes in the objective.
35
+ """
36
+
37
+ @abstractmethod
38
+ def weights(self):
39
+ """
40
+ Get a numpy array of weights, one per diffusion step.
41
+ The weights needn't be normalized, but must be positive.
42
+ """
43
+
44
+ def sample(self, batch_size, device):
45
+ """
46
+ Importance-sample timesteps for a batch.
47
+ :param batch_size: the number of timesteps.
48
+ :param device: the torch device to save to.
49
+ :return: a tuple (timesteps, weights):
50
+ - timesteps: a tensor of timestep indices.
51
+ - weights: a tensor of weights to scale the resulting losses.
52
+ """
53
+ w = self.weights()
54
+ p = w / np.sum(w)
55
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56
+ indices = th.from_numpy(indices_np).long().to(device)
57
+ weights_np = 1 / (len(p) * p[indices_np])
58
+ weights = th.from_numpy(weights_np).float().to(device)
59
+ return indices, weights
60
+
61
+
62
+ class UniformSampler(ScheduleSampler):
63
+ def __init__(self, diffusion):
64
+ self.diffusion = diffusion
65
+ self._weights = np.ones([diffusion.num_timesteps])
66
+
67
+ def weights(self):
68
+ return self._weights
69
+
70
+
71
+ class LossAwareSampler(ScheduleSampler):
72
+ def update_with_local_losses(self, local_ts, local_losses):
73
+ """
74
+ Update the reweighting using losses from a model.
75
+ Call this method from each rank with a batch of timesteps and the
76
+ corresponding losses for each of those timesteps.
77
+ This method will perform synchronization to make sure all of the ranks
78
+ maintain the exact same reweighting.
79
+ :param local_ts: an integer Tensor of timesteps.
80
+ :param local_losses: a 1D Tensor of losses.
81
+ """
82
+ batch_sizes = [
83
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
84
+ for _ in range(dist.get_world_size())
85
+ ]
86
+ dist.all_gather(
87
+ batch_sizes,
88
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89
+ )
90
+
91
+ # Pad all_gather batches to be the maximum batch size.
92
+ batch_sizes = [x.item() for x in batch_sizes]
93
+ max_bs = max(batch_sizes)
94
+
95
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
96
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
97
+ dist.all_gather(timestep_batches, local_ts)
98
+ dist.all_gather(loss_batches, local_losses)
99
+ timesteps = [
100
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101
+ ]
102
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103
+ self.update_with_all_losses(timesteps, losses)
104
+
105
+ @abstractmethod
106
+ def update_with_all_losses(self, ts, losses):
107
+ """
108
+ Update the reweighting using losses from a model.
109
+ Sub-classes should override this method to update the reweighting
110
+ using losses from the model.
111
+ This method directly updates the reweighting without synchronizing
112
+ between workers. It is called by update_with_local_losses from all
113
+ ranks with identical arguments. Thus, it should have deterministic
114
+ behavior to maintain state across workers.
115
+ :param ts: a list of int timesteps.
116
+ :param losses: a list of float losses, one per timestep.
117
+ """
118
+
119
+
120
+ class LossSecondMomentResampler(LossAwareSampler):
121
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122
+ self.diffusion = diffusion
123
+ self.history_per_term = history_per_term
124
+ self.uniform_prob = uniform_prob
125
+ self._loss_history = np.zeros(
126
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
127
+ )
128
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)
129
+
130
+ def weights(self):
131
+ if not self._warmed_up():
132
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134
+ weights /= np.sum(weights)
135
+ weights *= 1 - self.uniform_prob
136
+ weights += self.uniform_prob / len(weights)
137
+ return weights
138
+
139
+ def update_with_all_losses(self, ts, losses):
140
+ for t, loss in zip(ts, losses):
141
+ if self._loss_counts[t] == self.history_per_term:
142
+ # Shift out the oldest loss term.
143
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
144
+ self._loss_history[t, -1] = loss
145
+ else:
146
+ self._loss_history[t, self._loss_counts[t]] = loss
147
+ self._loss_counts[t] += 1
148
+
149
+ def _warmed_up(self):
150
+ return (self._loss_counts == self.history_per_term).all()
input/i2v/A_big_drop_of_water_falls_on_a_rose_petal.png ADDED

Git LFS Details

  • SHA256: 29c1a37e9328d53826f4dcac1fc9412b1c146f1c613ae70997b079ae0a2dd8c5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.88 MB
input/i2v/A_fish_swims_past_an_oriental_woman.png ADDED

Git LFS Details

  • SHA256: 243573749f66d6c8368ad8ef6443f165030b1aea9042270d541ae7247b7462d5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.6 MB
input/i2v/Cinematic_photograph_View_of_piloting_aaero.png ADDED

Git LFS Details

  • SHA256: af08c3dea09a87c9656244ab52deeba18d781d2fca73c11d9cf2c9aba6fba0c8
  • Pointer size: 132 Bytes
  • Size of remote file: 7.4 MB
input/i2v/Planet_hits_earth.png ADDED

Git LFS Details

  • SHA256: 1488efae90efa98e2101deb6ed804056ad45327afb03f046df2512436a49a13e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.45 MB
input/i2v/Underwater_environment_cosmetic_bottles.png ADDED

Git LFS Details

  • SHA256: f2dfb8d1084b41f410cbccb24c7a102a750bfa7f16b3bc0cd3d770ecf89fd4fd
  • Pointer size: 132 Bytes
  • Size of remote file: 1.8 MB
models/__init__.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ sys.path.append(os.path.split(sys.path[0])[0])
4
+
5
+ from .unet import UNet3DConditionModel
6
+ from torch.optim.lr_scheduler import LambdaLR
7
+
8
+ def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit
9
+ from torch.optim.lr_scheduler import LambdaLR
10
+ def fn(step):
11
+ if warmup_steps > 0:
12
+ return min(step / warmup_steps, 1)
13
+ else:
14
+ return 1
15
+ return LambdaLR(optimizer, fn)
16
+
17
+
18
+ def get_lr_scheduler(optimizer, name, **kwargs):
19
+ if name == 'warmup':
20
+ return customized_lr_scheduler(optimizer, **kwargs)
21
+ elif name == 'cosine':
22
+ from torch.optim.lr_scheduler import CosineAnnealingLR
23
+ return CosineAnnealingLR(optimizer, **kwargs)
24
+ else:
25
+ raise NotImplementedError(name)
26
+
27
+ def get_models(args):
28
+ if 'UNet' in args.model:
29
+ pretrained_model_path = args.pretrained_model_path
30
+ return UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", use_concat=args.use_mask)
31
+ else:
32
+ raise '{} Model Not Supported!'.format(args.model)
33
+
models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.29 kB). View file
 
models/__pycache__/attention.cpython-310.pyc ADDED
Binary file (22.3 kB). View file
 
models/__pycache__/clip.cpython-310.pyc ADDED
Binary file (3.65 kB). View file
 
models/__pycache__/resnet.cpython-310.pyc ADDED
Binary file (5.17 kB). View file
 
models/__pycache__/unet.cpython-310.pyc ADDED
Binary file (17.5 kB). View file
 
models/__pycache__/unet_blocks.cpython-310.pyc ADDED
Binary file (12.4 kB). View file
 
models/attention.py ADDED
@@ -0,0 +1,966 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+ import os
3
+ import sys
4
+ sys.path.append(os.path.split(sys.path[0])[0])
5
+ from dataclasses import dataclass
6
+ from typing import Optional
7
+
8
+ import math
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+ from copy import deepcopy
13
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
14
+ from diffusers.utils import BaseOutput
15
+ from diffusers.utils.import_utils import is_xformers_available
16
+ from diffusers.models.attention import FeedForward, AdaLayerNorm
17
+ from rotary_embedding_torch import RotaryEmbedding
18
+ from typing import Callable, Optional
19
+ from einops import rearrange, repeat
20
+
21
+ try:
22
+ from diffusers.models.modeling_utils import ModelMixin
23
+ except:
24
+ from diffusers.modeling_utils import ModelMixin # 0.11.1
25
+
26
+
27
+ @dataclass
28
+ class Transformer3DModelOutput(BaseOutput):
29
+ sample: torch.FloatTensor
30
+
31
+
32
+ if is_xformers_available():
33
+ import xformers
34
+ import xformers.ops
35
+ else:
36
+ xformers = None
37
+
38
+ def exists(x):
39
+ return x is not None
40
+
41
+
42
+ class CrossAttention(nn.Module):
43
+ r"""
44
+ copy from diffuser 0.11.1
45
+ A cross attention layer.
46
+ Parameters:
47
+ query_dim (`int`): The number of channels in the query.
48
+ cross_attention_dim (`int`, *optional*):
49
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
50
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
51
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
52
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
53
+ bias (`bool`, *optional*, defaults to False):
54
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
55
+ """
56
+
57
+ def __init__(
58
+ self,
59
+ query_dim: int,
60
+ cross_attention_dim: Optional[int] = None,
61
+ heads: int = 8,
62
+ dim_head: int = 64,
63
+ dropout: float = 0.0,
64
+ bias=False,
65
+ upcast_attention: bool = False,
66
+ upcast_softmax: bool = False,
67
+ added_kv_proj_dim: Optional[int] = None,
68
+ norm_num_groups: Optional[int] = None,
69
+ use_relative_position: bool = False,
70
+ ):
71
+ super().__init__()
72
+ # print('num head', heads)
73
+ inner_dim = dim_head * heads
74
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
75
+ self.upcast_attention = upcast_attention
76
+ self.upcast_softmax = upcast_softmax
77
+
78
+ self.scale = dim_head**-0.5
79
+
80
+ self.heads = heads
81
+ self.dim_head = dim_head
82
+ # for slice_size > 0 the attention score computation
83
+ # is split across the batch axis to save memory
84
+ # You can set slice_size with `set_attention_slice`
85
+ self.sliceable_head_dim = heads
86
+ self._slice_size = None
87
+ self._use_memory_efficient_attention_xformers = False
88
+ self.added_kv_proj_dim = added_kv_proj_dim
89
+
90
+ if norm_num_groups is not None:
91
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
92
+ else:
93
+ self.group_norm = None
94
+
95
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
96
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
97
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
98
+
99
+ if self.added_kv_proj_dim is not None:
100
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
101
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
102
+
103
+ self.to_out = nn.ModuleList([])
104
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
105
+ self.to_out.append(nn.Dropout(dropout))
106
+
107
+ # print(use_relative_position)
108
+ self.use_relative_position = use_relative_position
109
+ if self.use_relative_position:
110
+ self.rotary_emb = RotaryEmbedding(min(32, dim_head))
111
+
112
+ self.ip_transformed = False
113
+ self.ip_scale = 1
114
+
115
+ def ip_transform(self):
116
+ if self.ip_transformed is not True:
117
+ self.ip_to_k = deepcopy(self.to_k).to(next(self.parameters()).device)
118
+ self.ip_to_v = deepcopy(self.to_v).to(next(self.parameters()).device)
119
+ self.ip_transformed = True
120
+
121
+ def ip_train_set(self):
122
+ if self.ip_transformed is True:
123
+ self.ip_to_k.requires_grad_(True)
124
+ self.ip_to_v.requires_grad_(True)
125
+
126
+ def set_scale(self, scale):
127
+ self.ip_scale = scale
128
+
129
+ def reshape_heads_to_batch_dim(self, tensor):
130
+ batch_size, seq_len, dim = tensor.shape
131
+ head_size = self.heads
132
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
133
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
134
+ return tensor
135
+
136
+ def reshape_batch_dim_to_heads(self, tensor):
137
+ batch_size, seq_len, dim = tensor.shape
138
+ head_size = self.heads
139
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
140
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
141
+ return tensor
142
+
143
+ def reshape_for_scores(self, tensor):
144
+ # split heads and dims
145
+ # tensor should be [b (h w)] f (d nd)
146
+ batch_size, seq_len, dim = tensor.shape
147
+ head_size = self.heads
148
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
149
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
150
+ return tensor
151
+
152
+ def same_batch_dim_to_heads(self, tensor):
153
+ batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d
154
+ tensor = tensor.reshape(batch_size, seq_len, dim * head_size)
155
+ return tensor
156
+
157
+ def set_attention_slice(self, slice_size):
158
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
159
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
160
+
161
+ self._slice_size = slice_size
162
+
163
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None, ip_hidden_states=None):
164
+ batch_size, sequence_length, _ = hidden_states.shape
165
+
166
+ encoder_hidden_states = encoder_hidden_states
167
+
168
+ if self.group_norm is not None:
169
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
170
+
171
+ query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
172
+
173
+ dim = query.shape[-1]
174
+ if not self.use_relative_position:
175
+ query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d
176
+
177
+ if self.added_kv_proj_dim is not None:
178
+ key = self.to_k(hidden_states)
179
+ value = self.to_v(hidden_states)
180
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
181
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
182
+
183
+ key = self.reshape_heads_to_batch_dim(key)
184
+ value = self.reshape_heads_to_batch_dim(value)
185
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
186
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
187
+
188
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
189
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
190
+ else:
191
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
192
+ key = self.to_k(encoder_hidden_states)
193
+ value = self.to_v(encoder_hidden_states)
194
+
195
+ if not self.use_relative_position:
196
+ key = self.reshape_heads_to_batch_dim(key)
197
+ value = self.reshape_heads_to_batch_dim(value)
198
+
199
+ if self.ip_transformed is True and ip_hidden_states is not None:
200
+ # print(ip_hidden_states.dtype)
201
+ # print(self.ip_to_k.weight.dtype)
202
+ ip_key = self.ip_to_k(ip_hidden_states)
203
+ ip_value = self.ip_to_v(ip_hidden_states)
204
+
205
+ if not self.use_relative_position:
206
+ ip_key = self.reshape_heads_to_batch_dim(ip_key)
207
+ ip_value = self.reshape_heads_to_batch_dim(ip_value)
208
+
209
+ if attention_mask is not None:
210
+ if attention_mask.shape[-1] != query.shape[1]:
211
+ target_length = query.shape[1]
212
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
213
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
214
+
215
+ # attention, what we cannot get enough of
216
+ if self._use_memory_efficient_attention_xformers:
217
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
218
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
219
+ hidden_states = hidden_states.to(query.dtype)
220
+
221
+ if self.ip_transformed is True and ip_hidden_states is not None:
222
+ ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, attention_mask)
223
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
224
+
225
+ else:
226
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
227
+ hidden_states = self._attention(query, key, value, attention_mask)
228
+
229
+ if self.ip_transformed is True and ip_hidden_states is not None:
230
+ ip_hidden_states = self._attention(query, ip_key, ip_value, attention_mask)
231
+ else:
232
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
233
+
234
+ if self.ip_transformed is True and ip_hidden_states is not None:
235
+ ip_hidden_states = self._sliced_attention(query, ip_key, ip_value, sequence_length, dim, attention_mask)
236
+
237
+ if self.ip_transformed is True and ip_hidden_states is not None:
238
+ hidden_states = hidden_states + self.ip_scale * ip_hidden_states
239
+
240
+ # linear proj
241
+ hidden_states = self.to_out[0](hidden_states)
242
+
243
+ # dropout
244
+ hidden_states = self.to_out[1](hidden_states)
245
+ return hidden_states
246
+
247
+
248
+ def _attention(self, query, key, value, attention_mask=None):
249
+ if self.upcast_attention:
250
+ query = query.float()
251
+ key = key.float()
252
+
253
+ attention_scores = torch.baddbmm(
254
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
255
+ query,
256
+ key.transpose(-1, -2),
257
+ beta=0,
258
+ alpha=self.scale,
259
+ )
260
+
261
+ if attention_mask is not None:
262
+ attention_scores = attention_scores + attention_mask
263
+
264
+ if self.upcast_softmax:
265
+ attention_scores = attention_scores.float()
266
+
267
+ attention_probs = attention_scores.softmax(dim=-1)
268
+ attention_probs = attention_probs.to(value.dtype)
269
+ hidden_states = torch.bmm(attention_probs, value)
270
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
271
+ return hidden_states
272
+
273
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
274
+ batch_size_attention = query.shape[0]
275
+ hidden_states = torch.zeros(
276
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
277
+ )
278
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
279
+ for i in range(hidden_states.shape[0] // slice_size):
280
+ start_idx = i * slice_size
281
+ end_idx = (i + 1) * slice_size
282
+
283
+ query_slice = query[start_idx:end_idx]
284
+ key_slice = key[start_idx:end_idx]
285
+
286
+ if self.upcast_attention:
287
+ query_slice = query_slice.float()
288
+ key_slice = key_slice.float()
289
+
290
+ attn_slice = torch.baddbmm(
291
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
292
+ query_slice,
293
+ key_slice.transpose(-1, -2),
294
+ beta=0,
295
+ alpha=self.scale,
296
+ )
297
+
298
+ if attention_mask is not None:
299
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
300
+
301
+ if self.upcast_softmax:
302
+ attn_slice = attn_slice.float()
303
+
304
+ attn_slice = attn_slice.softmax(dim=-1)
305
+
306
+ # cast back to the original dtype
307
+ attn_slice = attn_slice.to(value.dtype)
308
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
309
+
310
+ hidden_states[start_idx:end_idx] = attn_slice
311
+
312
+ # reshape hidden_states
313
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
314
+ return hidden_states
315
+
316
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
317
+ # TODO attention_mask
318
+ query = query.contiguous()
319
+ key = key.contiguous()
320
+ value = value.contiguous()
321
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
322
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
323
+ return hidden_states
324
+
325
+
326
+ class Transformer3DModel(ModelMixin, ConfigMixin):
327
+ @register_to_config
328
+ def __init__(
329
+ self,
330
+ num_attention_heads: int = 16,
331
+ attention_head_dim: int = 88,
332
+ in_channels: Optional[int] = None,
333
+ num_layers: int = 1,
334
+ dropout: float = 0.0,
335
+ norm_num_groups: int = 32,
336
+ cross_attention_dim: Optional[int] = None,
337
+ attention_bias: bool = False,
338
+ activation_fn: str = "geglu",
339
+ num_embeds_ada_norm: Optional[int] = None,
340
+ use_linear_projection: bool = False,
341
+ only_cross_attention: bool = False,
342
+ upcast_attention: bool = False,
343
+ use_first_frame: bool = False,
344
+ use_relative_position: bool = False,
345
+ rotary_emb: bool = None,
346
+ ):
347
+ super().__init__()
348
+ self.use_linear_projection = use_linear_projection
349
+ self.num_attention_heads = num_attention_heads
350
+ self.attention_head_dim = attention_head_dim
351
+ inner_dim = num_attention_heads * attention_head_dim
352
+
353
+ # Define input layers
354
+ self.in_channels = in_channels
355
+
356
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
357
+ if use_linear_projection:
358
+ self.proj_in = nn.Linear(in_channels, inner_dim)
359
+ else:
360
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
361
+
362
+ # Define transformers blocks
363
+ self.transformer_blocks = nn.ModuleList(
364
+ [
365
+ BasicTransformerBlock(
366
+ inner_dim,
367
+ num_attention_heads,
368
+ attention_head_dim,
369
+ dropout=dropout,
370
+ cross_attention_dim=cross_attention_dim,
371
+ activation_fn=activation_fn,
372
+ num_embeds_ada_norm=num_embeds_ada_norm,
373
+ attention_bias=attention_bias,
374
+ only_cross_attention=only_cross_attention,
375
+ upcast_attention=upcast_attention,
376
+ use_first_frame=use_first_frame,
377
+ use_relative_position=use_relative_position,
378
+ rotary_emb=rotary_emb,
379
+ )
380
+ for d in range(num_layers)
381
+ ]
382
+ )
383
+
384
+ # 4. Define output layers
385
+ if use_linear_projection:
386
+ self.proj_out = nn.Linear(in_channels, inner_dim)
387
+ else:
388
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
389
+
390
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, use_image_num=None, return_dict: bool = True, ip_hidden_states=None, encoder_temporal_hidden_states=None):
391
+ # Input
392
+ # if ip_hidden_states is not None:
393
+ # ip_hidden_states = ip_hidden_states.to(dtype=encoder_hidden_states.dtype)
394
+ # print(ip_hidden_states.shape)
395
+ # print(encoder_hidden_states.shape)
396
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
397
+ if self.training:
398
+ video_length = hidden_states.shape[2] - use_image_num
399
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
400
+ encoder_hidden_states_length = encoder_hidden_states.shape[1]
401
+ encoder_hidden_states_video = encoder_hidden_states[:, :encoder_hidden_states_length - use_image_num, ...]
402
+ encoder_hidden_states_video = repeat(encoder_hidden_states_video, 'b m n c -> b (m f) n c', f=video_length).contiguous()
403
+ encoder_hidden_states_image = encoder_hidden_states[:, encoder_hidden_states_length - use_image_num:, ...]
404
+ encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1)
405
+ encoder_hidden_states = rearrange(encoder_hidden_states, 'b m n c -> (b m) n c').contiguous()
406
+
407
+ if ip_hidden_states is not None:
408
+ ip_hidden_states_length = ip_hidden_states.shape[1]
409
+ ip_hidden_states_video = ip_hidden_states[:, :ip_hidden_states_length - use_image_num, ...]
410
+ ip_hidden_states_video = repeat(ip_hidden_states_video, 'b m n c -> b (m f) n c', f=video_length).contiguous()
411
+ ip_hidden_states_image = ip_hidden_states[:, ip_hidden_states_length - use_image_num:, ...]
412
+ ip_hidden_states = torch.cat([ip_hidden_states_video, ip_hidden_states_image], dim=1)
413
+ ip_hidden_states = rearrange(ip_hidden_states, 'b m n c -> (b m) n c').contiguous()
414
+
415
+ else:
416
+ video_length = hidden_states.shape[2]
417
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
418
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length).contiguous()
419
+
420
+ if encoder_temporal_hidden_states is not None:
421
+ encoder_temporal_hidden_states = repeat(encoder_temporal_hidden_states, 'b n c -> (b f) n c', f=video_length).contiguous()
422
+
423
+ if ip_hidden_states is not None:
424
+ ip_hidden_states = repeat(ip_hidden_states, 'b 1 n c -> (b f) n c', f=video_length).contiguous()
425
+
426
+ batch, channel, height, weight = hidden_states.shape
427
+ residual = hidden_states
428
+
429
+ hidden_states = self.norm(hidden_states)
430
+ if not self.use_linear_projection:
431
+ hidden_states = self.proj_in(hidden_states)
432
+ inner_dim = hidden_states.shape[1]
433
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
434
+ else:
435
+ inner_dim = hidden_states.shape[1]
436
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
437
+ hidden_states = self.proj_in(hidden_states)
438
+
439
+ # Blocks
440
+ for block in self.transformer_blocks:
441
+ hidden_states = block(
442
+ hidden_states,
443
+ encoder_hidden_states=encoder_hidden_states,
444
+ timestep=timestep,
445
+ video_length=video_length,
446
+ use_image_num=use_image_num,
447
+ ip_hidden_states=ip_hidden_states,
448
+ encoder_temporal_hidden_states=encoder_temporal_hidden_states
449
+ )
450
+
451
+ # Output
452
+ if not self.use_linear_projection:
453
+ hidden_states = (
454
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
455
+ )
456
+ hidden_states = self.proj_out(hidden_states)
457
+ else:
458
+ hidden_states = self.proj_out(hidden_states)
459
+ hidden_states = (
460
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
461
+ )
462
+
463
+ output = hidden_states + residual
464
+
465
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length + use_image_num).contiguous()
466
+ if not return_dict:
467
+ return (output,)
468
+
469
+ return Transformer3DModelOutput(sample=output)
470
+
471
+
472
+ class BasicTransformerBlock(nn.Module):
473
+ def __init__(
474
+ self,
475
+ dim: int,
476
+ num_attention_heads: int,
477
+ attention_head_dim: int,
478
+ dropout=0.0,
479
+ cross_attention_dim: Optional[int] = None,
480
+ activation_fn: str = "geglu",
481
+ num_embeds_ada_norm: Optional[int] = None,
482
+ attention_bias: bool = False,
483
+ only_cross_attention: bool = False,
484
+ upcast_attention: bool = False,
485
+ use_first_frame: bool = False,
486
+ use_relative_position: bool = False,
487
+ rotary_emb: bool = False,
488
+ ):
489
+ super().__init__()
490
+ self.only_cross_attention = only_cross_attention
491
+ # print(only_cross_attention)
492
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
493
+ # print(self.use_ada_layer_norm)
494
+ self.use_first_frame = use_first_frame
495
+
496
+ self.dim = dim
497
+ self.cross_attention_dim = cross_attention_dim
498
+ self.num_attention_heads = num_attention_heads
499
+ self.attention_head_dim = attention_head_dim
500
+ self.dropout = dropout
501
+ self.attention_bias = attention_bias
502
+ self.upcast_attention = upcast_attention
503
+
504
+ # Spatial-Attn
505
+ self.attn1 = CrossAttention(
506
+ query_dim=dim,
507
+ heads=num_attention_heads,
508
+ dim_head=attention_head_dim,
509
+ dropout=dropout,
510
+ bias=attention_bias,
511
+ cross_attention_dim=None,
512
+ upcast_attention=upcast_attention,
513
+ )
514
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
515
+
516
+ # Text Cross-Attn
517
+ if cross_attention_dim is not None:
518
+ self.attn2 = CrossAttention(
519
+ query_dim=dim,
520
+ cross_attention_dim=cross_attention_dim,
521
+ heads=num_attention_heads,
522
+ dim_head=attention_head_dim,
523
+ dropout=dropout,
524
+ bias=attention_bias,
525
+ upcast_attention=upcast_attention,
526
+ )
527
+ else:
528
+ self.attn2 = None
529
+
530
+ if cross_attention_dim is not None:
531
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
532
+ else:
533
+ self.norm2 = None
534
+
535
+ # Temp
536
+ self.attn_temp = TemporalAttention(
537
+ query_dim=dim,
538
+ heads=num_attention_heads,
539
+ dim_head=attention_head_dim,
540
+ dropout=dropout,
541
+ bias=attention_bias,
542
+ cross_attention_dim=None,
543
+ upcast_attention=upcast_attention,
544
+ rotary_emb=rotary_emb,
545
+ )
546
+ self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
547
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
548
+
549
+ # Feed-forward
550
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
551
+ self.norm3 = nn.LayerNorm(dim)
552
+
553
+ self.tca_transformed = False
554
+
555
+ def tca_transform(self):
556
+ if self.tca_transformed is not True:
557
+ self.cross_attn_temp = CrossAttention(
558
+ query_dim=self.dim * 16,
559
+ cross_attention_dim=self.cross_attention_dim,
560
+ heads=self.num_attention_heads,
561
+ dim_head=self.attention_head_dim,
562
+ dropout=self.dropout,
563
+ bias=self.attention_bias,
564
+ upcast_attention=self.upcast_attention,
565
+ )
566
+ self.cross_norm_temp = AdaLayerNorm(self.dim * 16, self.num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(self.dim * 16)
567
+ nn.init.zeros_(self.cross_attn_temp.to_out[0].weight.data)
568
+ self.tca_transformed = True
569
+
570
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, op=None):
571
+
572
+ if not is_xformers_available():
573
+ print("Here is how to install it")
574
+ raise ModuleNotFoundError(
575
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
576
+ " xformers",
577
+ name="xformers",
578
+ )
579
+ elif not torch.cuda.is_available():
580
+ raise ValueError(
581
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
582
+ " available for GPU "
583
+ )
584
+ else:
585
+ try:
586
+ # Make sure we can run the memory efficient attention
587
+ _ = xformers.ops.memory_efficient_attention(
588
+ torch.randn((1, 2, 40), device="cuda"),
589
+ torch.randn((1, 2, 40), device="cuda"),
590
+ torch.randn((1, 2, 40), device="cuda"),
591
+ )
592
+ except Exception as e:
593
+ raise e
594
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
595
+ if self.attn2 is not None:
596
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
597
+
598
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None, use_image_num=None, ip_hidden_states=None, encoder_temporal_hidden_states=None):
599
+ # SparseCausal-Attention
600
+ norm_hidden_states = (
601
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
602
+ )
603
+
604
+ if self.only_cross_attention:
605
+ hidden_states = (
606
+ self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
607
+ )
608
+ else:
609
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, use_image_num=use_image_num) + hidden_states
610
+
611
+ if self.attn2 is not None:
612
+ # Cross-Attention
613
+ norm_hidden_states = (
614
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
615
+ )
616
+ hidden_states = (
617
+ self.attn2(
618
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, ip_hidden_states=ip_hidden_states
619
+ )
620
+ + hidden_states
621
+ )
622
+
623
+ # Temporal Attention
624
+ if self.training:
625
+ d = hidden_states.shape[1]
626
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous()
627
+ hidden_states_video = hidden_states[:, :video_length, :]
628
+ hidden_states_image = hidden_states[:, video_length:, :]
629
+ norm_hidden_states_video = (
630
+ self.norm_temp(hidden_states_video, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states_video)
631
+ )
632
+ hidden_states_video = self.attn_temp(norm_hidden_states_video) + hidden_states_video
633
+
634
+ # Temporal Cross Attention
635
+ if self.tca_transformed is True:
636
+ hidden_states_video = rearrange(hidden_states_video, "(b d) f c -> b d (f c)", d=d).contiguous()
637
+ norm_hidden_states_video = (
638
+ self.cross_norm_temp(hidden_states_video, timestep) if self.use_ada_layer_norm else self.cross_norm_temp(hidden_states_video)
639
+ )
640
+ temp_encoder_hidden_states = rearrange(encoder_hidden_states, "(b f) d c -> b f d c", f=video_length + use_image_num).contiguous()
641
+ temp_encoder_hidden_states = temp_encoder_hidden_states[:, 0:1].squeeze(dim=1)
642
+ hidden_states_video = self.cross_attn_temp(norm_hidden_states_video, encoder_hidden_states=temp_encoder_hidden_states, attention_mask=attention_mask) + hidden_states_video
643
+ hidden_states_video = rearrange(hidden_states_video, "b d (f c) -> (b d) f c", f=video_length).contiguous()
644
+
645
+ hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
646
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous()
647
+ else:
648
+ d = hidden_states.shape[1]
649
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous()
650
+ norm_hidden_states = (
651
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
652
+ )
653
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
654
+
655
+ # Temporal Cross Attention
656
+ if self.tca_transformed is True:
657
+ hidden_states = rearrange(hidden_states, "(b d) f c -> b d (f c)", d=d).contiguous()
658
+ norm_hidden_states = (
659
+ self.cross_norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.cross_norm_temp(hidden_states)
660
+ )
661
+ if encoder_temporal_hidden_states is not None:
662
+ encoder_hidden_states = encoder_temporal_hidden_states
663
+ temp_encoder_hidden_states = rearrange(encoder_hidden_states, "(b f) d c -> b f d c", f=video_length + use_image_num).contiguous()
664
+ temp_encoder_hidden_states = temp_encoder_hidden_states[:, 0:1].squeeze(dim=1)
665
+ hidden_states = self.cross_attn_temp(norm_hidden_states, encoder_hidden_states=temp_encoder_hidden_states, attention_mask=attention_mask) + hidden_states
666
+ hidden_states = rearrange(hidden_states, "b d (f c) -> (b f) d c", f=video_length + use_image_num, d=d).contiguous()
667
+ else:
668
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous()
669
+
670
+ # Feed-forward
671
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
672
+
673
+ return hidden_states
674
+
675
+
676
+ class SparseCausalAttention(CrossAttention):
677
+ def forward_video(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
678
+ batch_size, sequence_length, _ = hidden_states.shape
679
+
680
+ encoder_hidden_states = encoder_hidden_states
681
+
682
+ if self.group_norm is not None:
683
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
684
+
685
+ query = self.to_q(hidden_states)
686
+ dim = query.shape[-1]
687
+ query = self.reshape_heads_to_batch_dim(query)
688
+
689
+ if self.added_kv_proj_dim is not None:
690
+ raise NotImplementedError
691
+
692
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
693
+ key = self.to_k(encoder_hidden_states)
694
+ value = self.to_v(encoder_hidden_states)
695
+
696
+ former_frame_index = torch.arange(video_length) - 1
697
+ former_frame_index[0] = 0
698
+
699
+ key = rearrange(key, "(b f) d c -> b f d c", f=video_length).contiguous()
700
+ key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)
701
+ key = rearrange(key, "b f d c -> (b f) d c").contiguous()
702
+
703
+ value = rearrange(value, "(b f) d c -> b f d c", f=video_length).contiguous()
704
+ value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2)
705
+ value = rearrange(value, "b f d c -> (b f) d c").contiguous()
706
+
707
+ key = self.reshape_heads_to_batch_dim(key)
708
+ value = self.reshape_heads_to_batch_dim(value)
709
+
710
+ if attention_mask is not None:
711
+ if attention_mask.shape[-1] != query.shape[1]:
712
+ target_length = query.shape[1]
713
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
714
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
715
+
716
+ # attention, what we cannot get enough of
717
+ if self._use_memory_efficient_attention_xformers:
718
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
719
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
720
+ hidden_states = hidden_states.to(query.dtype)
721
+ else:
722
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
723
+ hidden_states = self._attention(query, key, value, attention_mask)
724
+ else:
725
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
726
+
727
+ # linear proj
728
+ hidden_states = self.to_out[0](hidden_states)
729
+
730
+ # dropout
731
+ hidden_states = self.to_out[1](hidden_states)
732
+ return hidden_states
733
+
734
+ def forward_image(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None):
735
+ batch_size, sequence_length, _ = hidden_states.shape
736
+
737
+ encoder_hidden_states = encoder_hidden_states
738
+
739
+ if self.group_norm is not None:
740
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
741
+
742
+ query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
743
+ dim = query.shape[-1]
744
+ if not self.use_relative_position:
745
+ query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d
746
+
747
+ if self.added_kv_proj_dim is not None:
748
+ key = self.to_k(hidden_states)
749
+ value = self.to_v(hidden_states)
750
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
751
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
752
+
753
+ key = self.reshape_heads_to_batch_dim(key)
754
+ value = self.reshape_heads_to_batch_dim(value)
755
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
756
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
757
+
758
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
759
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
760
+ else:
761
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
762
+ key = self.to_k(encoder_hidden_states)
763
+ value = self.to_v(encoder_hidden_states)
764
+
765
+ if not self.use_relative_position:
766
+ key = self.reshape_heads_to_batch_dim(key)
767
+ value = self.reshape_heads_to_batch_dim(value)
768
+
769
+ if attention_mask is not None:
770
+ if attention_mask.shape[-1] != query.shape[1]:
771
+ target_length = query.shape[1]
772
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
773
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
774
+
775
+ # attention, what we cannot get enough of
776
+ if self._use_memory_efficient_attention_xformers:
777
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
778
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
779
+ hidden_states = hidden_states.to(query.dtype)
780
+ else:
781
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
782
+ hidden_states = self._attention(query, key, value, attention_mask)
783
+ else:
784
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
785
+
786
+ # linear proj
787
+ hidden_states = self.to_out[0](hidden_states)
788
+
789
+ # dropout
790
+ hidden_states = self.to_out[1](hidden_states)
791
+ return hidden_states
792
+
793
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, use_image_num=None):
794
+ if self.training:
795
+ # print(use_image_num)
796
+ hidden_states = rearrange(hidden_states, "(b f) d c -> b f d c", f=video_length + use_image_num).contiguous()
797
+ hidden_states_video = hidden_states[:, :video_length, ...]
798
+ hidden_states_image = hidden_states[:, video_length:, ...]
799
+ hidden_states_video = rearrange(hidden_states_video, 'b f d c -> (b f) d c').contiguous()
800
+ hidden_states_image = rearrange(hidden_states_image, 'b f d c -> (b f) d c').contiguous()
801
+ hidden_states_video = self.forward_video(hidden_states=hidden_states_video,
802
+ encoder_hidden_states=encoder_hidden_states,
803
+ attention_mask=attention_mask,
804
+ video_length=video_length)
805
+ hidden_states_image = self.forward_image(hidden_states=hidden_states_image,
806
+ encoder_hidden_states=encoder_hidden_states,
807
+ attention_mask=attention_mask)
808
+ hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=0)
809
+ return hidden_states
810
+ # exit()
811
+ else:
812
+ return self.forward_video(hidden_states=hidden_states,
813
+ encoder_hidden_states=encoder_hidden_states,
814
+ attention_mask=attention_mask,
815
+ video_length=video_length)
816
+
817
+ class TemporalAttention(CrossAttention):
818
+ def __init__(self,
819
+ query_dim: int,
820
+ cross_attention_dim: Optional[int] = None,
821
+ heads: int = 8,
822
+ dim_head: int = 64,
823
+ dropout: float = 0.0,
824
+ bias=False,
825
+ upcast_attention: bool = False,
826
+ upcast_softmax: bool = False,
827
+ added_kv_proj_dim: Optional[int] = None,
828
+ norm_num_groups: Optional[int] = None,
829
+ rotary_emb=None):
830
+ super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, added_kv_proj_dim, norm_num_groups)
831
+ # relative time positional embeddings
832
+ self.time_rel_pos_bias = RelativePositionBias(heads=heads, max_distance=32) # realistically will not be able to generate that many frames of video... yet
833
+ self.rotary_emb = rotary_emb
834
+
835
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
836
+ time_rel_pos_bias = self.time_rel_pos_bias(hidden_states.shape[1], device=hidden_states.device)
837
+ batch_size, sequence_length, _ = hidden_states.shape
838
+
839
+ encoder_hidden_states = encoder_hidden_states
840
+
841
+ if self.group_norm is not None:
842
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
843
+
844
+ query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
845
+ dim = query.shape[-1]
846
+
847
+ if self.added_kv_proj_dim is not None:
848
+ key = self.to_k(hidden_states)
849
+ value = self.to_v(hidden_states)
850
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
851
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
852
+
853
+ key = self.reshape_heads_to_batch_dim(key)
854
+ value = self.reshape_heads_to_batch_dim(value)
855
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
856
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
857
+
858
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
859
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
860
+ else:
861
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
862
+ key = self.to_k(encoder_hidden_states)
863
+ value = self.to_v(encoder_hidden_states)
864
+
865
+ if attention_mask is not None:
866
+ if attention_mask.shape[-1] != query.shape[1]:
867
+ target_length = query.shape[1]
868
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
869
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
870
+
871
+ # attention, what we cannot get enough of
872
+ if self._use_memory_efficient_attention_xformers:
873
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
874
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
875
+ hidden_states = hidden_states.to(query.dtype)
876
+ else:
877
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
878
+ hidden_states = self._attention(query, key, value, attention_mask, time_rel_pos_bias)
879
+ else:
880
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
881
+
882
+ # linear proj
883
+ hidden_states = self.to_out[0](hidden_states)
884
+
885
+ # dropout
886
+ hidden_states = self.to_out[1](hidden_states)
887
+ return hidden_states
888
+
889
+
890
+ def _attention(self, query, key, value, attention_mask=None, time_rel_pos_bias=None):
891
+ if self.upcast_attention:
892
+ query = query.float()
893
+ key = key.float()
894
+
895
+ query = self.scale * rearrange(query, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
896
+ key = rearrange(key, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
897
+ value = rearrange(value, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
898
+
899
+ # torch.baddbmm only accepte 3-D tensor
900
+ # https://runebook.dev/zh/docs/pytorch/generated/torch.baddbmm
901
+ # attention_scores = self.scale * torch.matmul(query, key.transpose(-1, -2))
902
+ if exists(self.rotary_emb):
903
+ query = self.rotary_emb.rotate_queries_or_keys(query)
904
+ key = self.rotary_emb.rotate_queries_or_keys(key)
905
+
906
+ attention_scores = torch.einsum('... h i d, ... h j d -> ... h i j', query, key)
907
+
908
+ attention_scores = attention_scores + time_rel_pos_bias
909
+
910
+ if attention_mask is not None:
911
+ # add attention mask
912
+ attention_scores = attention_scores + attention_mask
913
+
914
+ # vdm
915
+ attention_scores = attention_scores - attention_scores.amax(dim = -1, keepdim = True).detach()
916
+
917
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
918
+ # print(attention_probs[0][0])
919
+
920
+ # cast back to the original dtype
921
+ attention_probs = attention_probs.to(value.dtype)
922
+
923
+ # compute attention output
924
+ hidden_states = torch.einsum('... h i j, ... h j d -> ... h i d', attention_probs, value)
925
+ hidden_states = rearrange(hidden_states, 'b h f d -> b f (h d)')
926
+ return hidden_states
927
+
928
+ class RelativePositionBias(nn.Module):
929
+ def __init__(
930
+ self,
931
+ heads=8,
932
+ num_buckets=32,
933
+ max_distance=128,
934
+ ):
935
+ super().__init__()
936
+ self.num_buckets = num_buckets
937
+ self.max_distance = max_distance
938
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
939
+
940
+ @staticmethod
941
+ def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
942
+ ret = 0
943
+ n = -relative_position
944
+
945
+ num_buckets //= 2
946
+ ret += (n < 0).long() * num_buckets
947
+ n = torch.abs(n)
948
+
949
+ max_exact = num_buckets // 2
950
+ is_small = n < max_exact
951
+
952
+ val_if_large = max_exact + (
953
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
954
+ ).long()
955
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
956
+
957
+ ret += torch.where(is_small, n, val_if_large)
958
+ return ret
959
+
960
+ def forward(self, n, device):
961
+ q_pos = torch.arange(n, dtype = torch.long, device = device)
962
+ k_pos = torch.arange(n, dtype = torch.long, device = device)
963
+ rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
964
+ rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
965
+ values = self.relative_attention_bias(rp_bucket)
966
+ return rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames
models/clip.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy
2
+ import torch.nn as nn
3
+ from transformers import CLIPTokenizer, CLIPTextModel
4
+
5
+ import transformers
6
+ transformers.logging.set_verbosity_error()
7
+
8
+ """
9
+ Will encounter following warning:
10
+ - This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task
11
+ or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
12
+ - This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model
13
+ that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
14
+
15
+ https://github.com/CompVis/stable-diffusion/issues/97
16
+ according to this issue, this warning is safe.
17
+
18
+ This is expected since the vision backbone of the CLIP model is not needed to run Stable Diffusion.
19
+ You can safely ignore the warning, it is not an error.
20
+
21
+ This clip usage is from U-ViT and same with Stable Diffusion.
22
+ """
23
+
24
+ class AbstractEncoder(nn.Module):
25
+ def __init__(self):
26
+ super().__init__()
27
+
28
+ def encode(self, *args, **kwargs):
29
+ raise NotImplementedError
30
+
31
+
32
+ class FrozenCLIPEmbedder(AbstractEncoder):
33
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
34
+ # def __init__(self, version="openai/clip-vit-huge-patch14", device="cuda", max_length=77):
35
+ def __init__(self, path, device="cuda", max_length=77):
36
+ super().__init__()
37
+ self.tokenizer = CLIPTokenizer.from_pretrained(path, subfolder="tokenizer")
38
+ self.transformer = CLIPTextModel.from_pretrained(path, subfolder='text_encoder')
39
+ self.device = device
40
+ self.max_length = max_length
41
+ self.freeze()
42
+
43
+ def freeze(self):
44
+ self.transformer = self.transformer.eval()
45
+ for param in self.parameters():
46
+ param.requires_grad = False
47
+
48
+ def forward(self, text):
49
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
50
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
51
+ tokens = batch_encoding["input_ids"].to(self.device)
52
+ outputs = self.transformer(input_ids=tokens)
53
+
54
+ z = outputs.last_hidden_state
55
+ return z
56
+
57
+ def encode(self, text):
58
+ return self(text)
59
+
60
+
61
+ class TextEmbedder(nn.Module):
62
+ """
63
+ Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance.
64
+ """
65
+ def __init__(self, path, dropout_prob=0.1):
66
+ super().__init__()
67
+ self.text_encodder = FrozenCLIPEmbedder(path=path)
68
+ self.dropout_prob = dropout_prob
69
+
70
+ def token_drop(self, text_prompts, force_drop_ids=None):
71
+ """
72
+ Drops text to enable classifier-free guidance.
73
+ """
74
+ if force_drop_ids is None:
75
+ drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob
76
+ else:
77
+ # TODO
78
+ drop_ids = force_drop_ids == 1
79
+ labels = list(numpy.where(drop_ids, "", text_prompts))
80
+ # print(labels)
81
+ return labels
82
+
83
+ def forward(self, text_prompts, train, force_drop_ids=None):
84
+ use_dropout = self.dropout_prob > 0
85
+ if (train and use_dropout) or (force_drop_ids is not None):
86
+ text_prompts = self.token_drop(text_prompts, force_drop_ids)
87
+ embeddings = self.text_encodder(text_prompts)
88
+ return embeddings
89
+
90
+
91
+ if __name__ == '__main__':
92
+
93
+ r"""
94
+ Returns:
95
+
96
+ Examples from CLIPTextModel:
97
+
98
+ ```python
99
+ >>> from transformers import AutoTokenizer, CLIPTextModel
100
+
101
+ >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
102
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
103
+
104
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
105
+
106
+ >>> outputs = model(**inputs)
107
+ >>> last_hidden_state = outputs.last_hidden_state
108
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
109
+ ```"""
110
+
111
+ import torch
112
+
113
+ device = "cuda" if torch.cuda.is_available() else "cpu"
114
+
115
+ text_encoder = TextEmbedder(path='/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-2-1-base',
116
+ dropout_prob=0.00001).to(device)
117
+
118
+ text_prompt = [["a photo of a cat", "a photo of a cat"], ["a photo of a dog", "a photo of a cat"], ['a photo of a dog human', "a photo of a cat"]]
119
+ # text_prompt = ('None', 'None', 'None')
120
+ output = text_encoder(text_prompts=text_prompt, train=False)
121
+ # print(output)
122
+ print(output.shape)
123
+ # print(output.shape)
models/resnet.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+ import os
3
+ import sys
4
+ sys.path.append(os.path.split(sys.path[0])[0])
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from einops import rearrange
11
+
12
+
13
+ class InflatedConv3d(nn.Conv2d):
14
+ def forward(self, x):
15
+ video_length = x.shape[2]
16
+
17
+ x = rearrange(x, "b c f h w -> (b f) c h w")
18
+ x = super().forward(x)
19
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
20
+
21
+ return x
22
+
23
+
24
+ class Upsample3D(nn.Module):
25
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
26
+ super().__init__()
27
+ self.channels = channels
28
+ self.out_channels = out_channels or channels
29
+ self.use_conv = use_conv
30
+ self.use_conv_transpose = use_conv_transpose
31
+ self.name = name
32
+
33
+ conv = None
34
+ if use_conv_transpose:
35
+ raise NotImplementedError
36
+ elif use_conv:
37
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
38
+
39
+ if name == "conv":
40
+ self.conv = conv
41
+ else:
42
+ self.Conv2d_0 = conv
43
+
44
+ def forward(self, hidden_states, output_size=None):
45
+ assert hidden_states.shape[1] == self.channels
46
+
47
+ if self.use_conv_transpose:
48
+ raise NotImplementedError
49
+
50
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
51
+ dtype = hidden_states.dtype
52
+ if dtype == torch.bfloat16:
53
+ hidden_states = hidden_states.to(torch.float32)
54
+
55
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
56
+ if hidden_states.shape[0] >= 64:
57
+ hidden_states = hidden_states.contiguous()
58
+
59
+ # if `output_size` is passed we force the interpolation output
60
+ # size and do not make use of `scale_factor=2`
61
+ if output_size is None:
62
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
63
+ else:
64
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
65
+
66
+ # If the input is bfloat16, we cast back to bfloat16
67
+ if dtype == torch.bfloat16:
68
+ hidden_states = hidden_states.to(dtype)
69
+
70
+ if self.use_conv:
71
+ if self.name == "conv":
72
+ hidden_states = self.conv(hidden_states)
73
+ else:
74
+ hidden_states = self.Conv2d_0(hidden_states)
75
+
76
+ return hidden_states
77
+
78
+
79
+ class Downsample3D(nn.Module):
80
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
81
+ super().__init__()
82
+ self.channels = channels
83
+ self.out_channels = out_channels or channels
84
+ self.use_conv = use_conv
85
+ self.padding = padding
86
+ stride = 2
87
+ self.name = name
88
+
89
+ if use_conv:
90
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
91
+ else:
92
+ raise NotImplementedError
93
+
94
+ if name == "conv":
95
+ self.Conv2d_0 = conv
96
+ self.conv = conv
97
+ elif name == "Conv2d_0":
98
+ self.conv = conv
99
+ else:
100
+ self.conv = conv
101
+
102
+ def forward(self, hidden_states):
103
+ assert hidden_states.shape[1] == self.channels
104
+ if self.use_conv and self.padding == 0:
105
+ raise NotImplementedError
106
+
107
+ assert hidden_states.shape[1] == self.channels
108
+ hidden_states = self.conv(hidden_states)
109
+
110
+ return hidden_states
111
+
112
+
113
+ class ResnetBlock3D(nn.Module):
114
+ def __init__(
115
+ self,
116
+ *,
117
+ in_channels,
118
+ out_channels=None,
119
+ conv_shortcut=False,
120
+ dropout=0.0,
121
+ temb_channels=512,
122
+ groups=32,
123
+ groups_out=None,
124
+ pre_norm=True,
125
+ eps=1e-6,
126
+ non_linearity="swish",
127
+ time_embedding_norm="default",
128
+ output_scale_factor=1.0,
129
+ use_in_shortcut=None,
130
+ ):
131
+ super().__init__()
132
+ self.pre_norm = pre_norm
133
+ self.pre_norm = True
134
+ self.in_channels = in_channels
135
+ out_channels = in_channels if out_channels is None else out_channels
136
+ self.out_channels = out_channels
137
+ self.use_conv_shortcut = conv_shortcut
138
+ self.time_embedding_norm = time_embedding_norm
139
+ self.output_scale_factor = output_scale_factor
140
+
141
+ if groups_out is None:
142
+ groups_out = groups
143
+
144
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
145
+
146
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
147
+
148
+ if temb_channels is not None:
149
+ if self.time_embedding_norm == "default":
150
+ time_emb_proj_out_channels = out_channels
151
+ elif self.time_embedding_norm == "scale_shift":
152
+ time_emb_proj_out_channels = out_channels * 2
153
+ else:
154
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
155
+
156
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
157
+ else:
158
+ self.time_emb_proj = None
159
+
160
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
161
+ self.dropout = torch.nn.Dropout(dropout)
162
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
163
+
164
+ if non_linearity == "swish":
165
+ self.nonlinearity = lambda x: F.silu(x)
166
+ elif non_linearity == "mish":
167
+ self.nonlinearity = Mish()
168
+ elif non_linearity == "silu":
169
+ self.nonlinearity = nn.SiLU()
170
+
171
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
172
+
173
+ self.conv_shortcut = None
174
+ if self.use_in_shortcut:
175
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
176
+
177
+ def forward(self, input_tensor, temb):
178
+ hidden_states = input_tensor
179
+
180
+ hidden_states = self.norm1(hidden_states)
181
+ hidden_states = self.nonlinearity(hidden_states)
182
+
183
+ hidden_states = self.conv1(hidden_states)
184
+
185
+ if temb is not None:
186
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
187
+
188
+ if temb is not None and self.time_embedding_norm == "default":
189
+ hidden_states = hidden_states + temb
190
+
191
+ hidden_states = self.norm2(hidden_states)
192
+
193
+ if temb is not None and self.time_embedding_norm == "scale_shift":
194
+ scale, shift = torch.chunk(temb, 2, dim=1)
195
+ hidden_states = hidden_states * (1 + scale) + shift
196
+
197
+ hidden_states = self.nonlinearity(hidden_states)
198
+
199
+ hidden_states = self.dropout(hidden_states)
200
+ hidden_states = self.conv2(hidden_states)
201
+
202
+ if self.conv_shortcut is not None:
203
+ input_tensor = self.conv_shortcut(input_tensor)
204
+
205
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
206
+
207
+ return output_tensor
208
+
209
+
210
+ class Mish(torch.nn.Module):
211
+ def forward(self, hidden_states):
212
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
models/unet.py ADDED
@@ -0,0 +1,699 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import os
7
+ import sys
8
+ sys.path.append(os.path.split(sys.path[0])[0])
9
+
10
+ import math
11
+ import json
12
+ import torch
13
+ import einops
14
+ import torch.nn as nn
15
+ import torch.utils.checkpoint
16
+
17
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
18
+ from diffusers.utils import BaseOutput, logging
19
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
20
+ from einops import rearrange
21
+
22
+ try:
23
+ from diffusers.models.modeling_utils import ModelMixin
24
+ except:
25
+ from diffusers.modeling_utils import ModelMixin # 0.11.1
26
+
27
+ try:
28
+ from .unet_blocks import (
29
+ CrossAttnDownBlock3D,
30
+ CrossAttnUpBlock3D,
31
+ DownBlock3D,
32
+ UNetMidBlock3DCrossAttn,
33
+ UpBlock3D,
34
+ get_down_block,
35
+ get_up_block,
36
+ )
37
+ from .resnet import InflatedConv3d
38
+ except:
39
+ from unet_blocks import (
40
+ CrossAttnDownBlock3D,
41
+ CrossAttnUpBlock3D,
42
+ DownBlock3D,
43
+ UNetMidBlock3DCrossAttn,
44
+ UpBlock3D,
45
+ get_down_block,
46
+ get_up_block,
47
+ )
48
+ from resnet import InflatedConv3d
49
+
50
+ from rotary_embedding_torch import RotaryEmbedding
51
+
52
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
53
+
54
+ class RelativePositionBias(nn.Module):
55
+ def __init__(
56
+ self,
57
+ heads=8,
58
+ num_buckets=32,
59
+ max_distance=128,
60
+ ):
61
+ super().__init__()
62
+ self.num_buckets = num_buckets
63
+ self.max_distance = max_distance
64
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
65
+
66
+ @staticmethod
67
+ def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
68
+ ret = 0
69
+ n = -relative_position
70
+
71
+ num_buckets //= 2
72
+ ret += (n < 0).long() * num_buckets
73
+ n = torch.abs(n)
74
+
75
+ max_exact = num_buckets // 2
76
+ is_small = n < max_exact
77
+
78
+ val_if_large = max_exact + (
79
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
80
+ ).long()
81
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
82
+
83
+ ret += torch.where(is_small, n, val_if_large)
84
+ return ret
85
+
86
+ def forward(self, n, device):
87
+ q_pos = torch.arange(n, dtype = torch.long, device = device)
88
+ k_pos = torch.arange(n, dtype = torch.long, device = device)
89
+ rel_pos = einops.rearrange(k_pos, 'j -> 1 j') - einops.rearrange(q_pos, 'i -> i 1')
90
+ rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
91
+ values = self.relative_attention_bias(rp_bucket)
92
+ return einops.rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames
93
+
94
+ @dataclass
95
+ class UNet3DConditionOutput(BaseOutput):
96
+ sample: torch.FloatTensor
97
+
98
+
99
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
100
+ _supports_gradient_checkpointing = True
101
+
102
+ @register_to_config
103
+ def __init__(
104
+ self,
105
+ sample_size: Optional[int] = None, # 64
106
+ in_channels: int = 4,
107
+ out_channels: int = 4,
108
+ center_input_sample: bool = False,
109
+ flip_sin_to_cos: bool = True,
110
+ freq_shift: int = 0,
111
+ down_block_types: Tuple[str] = (
112
+ "CrossAttnDownBlock3D",
113
+ "CrossAttnDownBlock3D",
114
+ "CrossAttnDownBlock3D",
115
+ "DownBlock3D",
116
+ ),
117
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
118
+ up_block_types: Tuple[str] = (
119
+ "UpBlock3D",
120
+ "CrossAttnUpBlock3D",
121
+ "CrossAttnUpBlock3D",
122
+ "CrossAttnUpBlock3D"
123
+ ),
124
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
125
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
126
+ layers_per_block: int = 2,
127
+ downsample_padding: int = 1,
128
+ mid_block_scale_factor: float = 1,
129
+ act_fn: str = "silu",
130
+ norm_num_groups: int = 32,
131
+ norm_eps: float = 1e-5,
132
+ cross_attention_dim: int = 1280,
133
+ attention_head_dim: Union[int, Tuple[int]] = 8,
134
+ dual_cross_attention: bool = False,
135
+ use_linear_projection: bool = False,
136
+ class_embed_type: Optional[str] = None,
137
+ num_class_embeds: Optional[int] = None,
138
+ upcast_attention: bool = False,
139
+ resnet_time_scale_shift: str = "default",
140
+ use_first_frame: bool = False,
141
+ use_relative_position: bool = False,
142
+ ):
143
+ super().__init__()
144
+
145
+ # print(use_first_frame)
146
+
147
+ self.sample_size = sample_size
148
+ time_embed_dim = block_out_channels[0] * 4
149
+
150
+ # input
151
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
152
+
153
+ # time
154
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
155
+ timestep_input_dim = block_out_channels[0]
156
+
157
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
158
+
159
+ # class embedding
160
+ if class_embed_type is None and num_class_embeds is not None:
161
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
162
+ elif class_embed_type == "timestep":
163
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
164
+ elif class_embed_type == "identity":
165
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
166
+ else:
167
+ self.class_embedding = None
168
+
169
+ self.down_blocks = nn.ModuleList([])
170
+ self.mid_block = None
171
+ self.up_blocks = nn.ModuleList([])
172
+
173
+ if isinstance(only_cross_attention, bool):
174
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
175
+
176
+ if isinstance(attention_head_dim, int):
177
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
178
+
179
+ rotary_emb = RotaryEmbedding(32)
180
+
181
+ # down
182
+ output_channel = block_out_channels[0]
183
+ for i, down_block_type in enumerate(down_block_types):
184
+ input_channel = output_channel
185
+ output_channel = block_out_channels[i]
186
+ is_final_block = i == len(block_out_channels) - 1
187
+
188
+ down_block = get_down_block(
189
+ down_block_type,
190
+ num_layers=layers_per_block,
191
+ in_channels=input_channel,
192
+ out_channels=output_channel,
193
+ temb_channels=time_embed_dim,
194
+ add_downsample=not is_final_block,
195
+ resnet_eps=norm_eps,
196
+ resnet_act_fn=act_fn,
197
+ resnet_groups=norm_num_groups,
198
+ cross_attention_dim=cross_attention_dim,
199
+ attn_num_head_channels=attention_head_dim[i],
200
+ downsample_padding=downsample_padding,
201
+ dual_cross_attention=dual_cross_attention,
202
+ use_linear_projection=use_linear_projection,
203
+ only_cross_attention=only_cross_attention[i],
204
+ upcast_attention=upcast_attention,
205
+ resnet_time_scale_shift=resnet_time_scale_shift,
206
+ use_first_frame=use_first_frame,
207
+ use_relative_position=use_relative_position,
208
+ rotary_emb=rotary_emb,
209
+ )
210
+ self.down_blocks.append(down_block)
211
+
212
+ # mid
213
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
214
+ self.mid_block = UNetMidBlock3DCrossAttn(
215
+ in_channels=block_out_channels[-1],
216
+ temb_channels=time_embed_dim,
217
+ resnet_eps=norm_eps,
218
+ resnet_act_fn=act_fn,
219
+ output_scale_factor=mid_block_scale_factor,
220
+ resnet_time_scale_shift=resnet_time_scale_shift,
221
+ cross_attention_dim=cross_attention_dim,
222
+ attn_num_head_channels=attention_head_dim[-1],
223
+ resnet_groups=norm_num_groups,
224
+ dual_cross_attention=dual_cross_attention,
225
+ use_linear_projection=use_linear_projection,
226
+ upcast_attention=upcast_attention,
227
+ use_first_frame=use_first_frame,
228
+ use_relative_position=use_relative_position,
229
+ rotary_emb=rotary_emb,
230
+ )
231
+ else:
232
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
233
+
234
+ # count how many layers upsample the videos
235
+ self.num_upsamplers = 0
236
+
237
+ # up
238
+ reversed_block_out_channels = list(reversed(block_out_channels))
239
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
240
+ only_cross_attention = list(reversed(only_cross_attention))
241
+ output_channel = reversed_block_out_channels[0]
242
+ for i, up_block_type in enumerate(up_block_types):
243
+ is_final_block = i == len(block_out_channels) - 1
244
+
245
+ prev_output_channel = output_channel
246
+ output_channel = reversed_block_out_channels[i]
247
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
248
+
249
+ # add upsample block for all BUT final layer
250
+ if not is_final_block:
251
+ add_upsample = True
252
+ self.num_upsamplers += 1
253
+ else:
254
+ add_upsample = False
255
+
256
+ up_block = get_up_block(
257
+ up_block_type,
258
+ num_layers=layers_per_block + 1,
259
+ in_channels=input_channel,
260
+ out_channels=output_channel,
261
+ prev_output_channel=prev_output_channel,
262
+ temb_channels=time_embed_dim,
263
+ add_upsample=add_upsample,
264
+ resnet_eps=norm_eps,
265
+ resnet_act_fn=act_fn,
266
+ resnet_groups=norm_num_groups,
267
+ cross_attention_dim=cross_attention_dim,
268
+ attn_num_head_channels=reversed_attention_head_dim[i],
269
+ dual_cross_attention=dual_cross_attention,
270
+ use_linear_projection=use_linear_projection,
271
+ only_cross_attention=only_cross_attention[i],
272
+ upcast_attention=upcast_attention,
273
+ resnet_time_scale_shift=resnet_time_scale_shift,
274
+ use_first_frame=use_first_frame,
275
+ use_relative_position=use_relative_position,
276
+ rotary_emb=rotary_emb,
277
+ )
278
+ self.up_blocks.append(up_block)
279
+ prev_output_channel = output_channel
280
+
281
+ # out
282
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
283
+ self.conv_act = nn.SiLU()
284
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
285
+
286
+ # relative time positional embeddings
287
+ self.use_relative_position = use_relative_position
288
+ if self.use_relative_position:
289
+ self.time_rel_pos_bias = RelativePositionBias(heads=8, max_distance=32) # realistically will not be able to generate that many frames of video... yet
290
+
291
+ def set_attention_slice(self, slice_size):
292
+ r"""
293
+ Enable sliced attention computation.
294
+
295
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
296
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
297
+
298
+ Args:
299
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
300
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
301
+ `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
302
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
303
+ must be a multiple of `slice_size`.
304
+ """
305
+ sliceable_head_dims = []
306
+
307
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
308
+ if hasattr(module, "set_attention_slice"):
309
+ sliceable_head_dims.append(module.sliceable_head_dim)
310
+
311
+ for child in module.children():
312
+ fn_recursive_retrieve_slicable_dims(child)
313
+
314
+ # retrieve number of attention layers
315
+ for module in self.children():
316
+ fn_recursive_retrieve_slicable_dims(module)
317
+
318
+ num_slicable_layers = len(sliceable_head_dims)
319
+
320
+ if slice_size == "auto":
321
+ # half the attention head size is usually a good trade-off between
322
+ # speed and memory
323
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
324
+ elif slice_size == "max":
325
+ # make smallest slice possible
326
+ slice_size = num_slicable_layers * [1]
327
+
328
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
329
+
330
+ if len(slice_size) != len(sliceable_head_dims):
331
+ raise ValueError(
332
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
333
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
334
+ )
335
+
336
+ for i in range(len(slice_size)):
337
+ size = slice_size[i]
338
+ dim = sliceable_head_dims[i]
339
+ if size is not None and size > dim:
340
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
341
+
342
+ # Recursively walk through all the children.
343
+ # Any children which exposes the set_attention_slice method
344
+ # gets the message
345
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
346
+ if hasattr(module, "set_attention_slice"):
347
+ module.set_attention_slice(slice_size.pop())
348
+
349
+ for child in module.children():
350
+ fn_recursive_set_attention_slice(child, slice_size)
351
+
352
+ reversed_slice_size = list(reversed(slice_size))
353
+ for module in self.children():
354
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
355
+
356
+ def _set_gradient_checkpointing(self, module, value=False):
357
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
358
+ module.gradient_checkpointing = value
359
+
360
+ def forward(
361
+ self,
362
+ sample: torch.FloatTensor,
363
+ timestep: Union[torch.Tensor, float, int],
364
+ encoder_hidden_states: torch.Tensor = None,
365
+ class_labels: Optional[torch.Tensor] = None,
366
+ attention_mask: Optional[torch.Tensor] = None,
367
+ use_image_num: int = 0,
368
+ return_dict: bool = True,
369
+ ip_hidden_states = None,
370
+ encoder_temporal_hidden_states = None
371
+ ) -> Union[UNet3DConditionOutput, Tuple]:
372
+ r"""
373
+ Args:
374
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
375
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
376
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
377
+ return_dict (`bool`, *optional*, defaults to `True`):
378
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
379
+
380
+ Returns:
381
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
382
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
383
+ returning a tuple, the first element is the sample tensor.
384
+ """
385
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
386
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
387
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
388
+ # on the fly if necessary.
389
+ if ip_hidden_states is not None:
390
+ b = ip_hidden_states.shape[0]
391
+ ip_hidden_states = rearrange(ip_hidden_states, 'b n c -> (b n) c')
392
+ ip_hidden_states = self.image_proj_model(ip_hidden_states)
393
+ ip_hidden_states = rearrange(ip_hidden_states, '(b n) m c -> b n m c', b=b)
394
+ default_overall_up_factor = 2**self.num_upsamplers
395
+
396
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
397
+ forward_upsample_size = False
398
+ upsample_size = None
399
+
400
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
401
+ logger.info("Forward upsample size to force interpolation output size.")
402
+ forward_upsample_size = True
403
+
404
+ # prepare attention_mask
405
+ if attention_mask is not None:
406
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
407
+ attention_mask = attention_mask.unsqueeze(1)
408
+
409
+ # center input if necessary
410
+ if self.config.center_input_sample:
411
+ sample = 2 * sample - 1.0
412
+
413
+ # time
414
+ timesteps = timestep
415
+ if not torch.is_tensor(timesteps):
416
+ # This would be a good case for the `match` statement (Python 3.10+)
417
+ is_mps = sample.device.type == "mps"
418
+ if isinstance(timestep, float):
419
+ dtype = torch.float32 if is_mps else torch.float64
420
+ else:
421
+ dtype = torch.int32 if is_mps else torch.int64
422
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
423
+ elif len(timesteps.shape) == 0:
424
+ timesteps = timesteps[None].to(sample.device)
425
+
426
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
427
+ timesteps = timesteps.expand(sample.shape[0])
428
+
429
+ t_emb = self.time_proj(timesteps)
430
+
431
+ # timesteps does not contain any weights and will always return f32 tensors
432
+ # but time_embedding might actually be running in fp16. so we need to cast here.
433
+ # there might be better ways to encapsulate this.
434
+ t_emb = t_emb.to(dtype=self.dtype)
435
+ emb = self.time_embedding(t_emb)
436
+
437
+ if self.class_embedding is not None:
438
+ if class_labels is None:
439
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
440
+
441
+ if self.config.class_embed_type == "timestep":
442
+ class_labels = self.time_proj(class_labels)
443
+
444
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
445
+ # print(emb.shape) # torch.Size([3, 1280])
446
+ # print(class_emb.shape) # torch.Size([3, 1280])
447
+ emb = emb + class_emb
448
+
449
+ if self.use_relative_position:
450
+ frame_rel_pos_bias = self.time_rel_pos_bias(sample.shape[2], device=sample.device)
451
+ else:
452
+ frame_rel_pos_bias = None
453
+
454
+ # pre-process
455
+ sample = self.conv_in(sample)
456
+
457
+ # down
458
+ down_block_res_samples = (sample,)
459
+ for downsample_block in self.down_blocks:
460
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
461
+ sample, res_samples = downsample_block(
462
+ hidden_states=sample,
463
+ temb=emb,
464
+ encoder_hidden_states=encoder_hidden_states,
465
+ attention_mask=attention_mask,
466
+ use_image_num=use_image_num,
467
+ ip_hidden_states=ip_hidden_states,
468
+ encoder_temporal_hidden_states=encoder_temporal_hidden_states
469
+ )
470
+ else:
471
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
472
+
473
+ down_block_res_samples += res_samples
474
+
475
+ # mid
476
+ sample = self.mid_block(
477
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states, encoder_temporal_hidden_states=encoder_temporal_hidden_states
478
+ )
479
+
480
+ # up
481
+ for i, upsample_block in enumerate(self.up_blocks):
482
+ is_final_block = i == len(self.up_blocks) - 1
483
+
484
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
485
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
486
+
487
+ # if we have not reached the final block and need to forward the
488
+ # upsample size, we do it here
489
+ if not is_final_block and forward_upsample_size:
490
+ upsample_size = down_block_res_samples[-1].shape[2:]
491
+
492
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
493
+ sample = upsample_block(
494
+ hidden_states=sample,
495
+ temb=emb,
496
+ res_hidden_states_tuple=res_samples,
497
+ encoder_hidden_states=encoder_hidden_states,
498
+ upsample_size=upsample_size,
499
+ attention_mask=attention_mask,
500
+ use_image_num=use_image_num,
501
+ ip_hidden_states=ip_hidden_states,
502
+ encoder_temporal_hidden_states=encoder_temporal_hidden_states
503
+ )
504
+ else:
505
+ sample = upsample_block(
506
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
507
+ )
508
+ # post-process
509
+ sample = self.conv_norm_out(sample)
510
+ sample = self.conv_act(sample)
511
+ sample = self.conv_out(sample)
512
+ # print(sample.shape)
513
+
514
+ if not return_dict:
515
+ return (sample,)
516
+ sample = UNet3DConditionOutput(sample=sample)
517
+ return sample
518
+
519
+ def forward_with_cfg(self,
520
+ x,
521
+ t,
522
+ encoder_hidden_states = None,
523
+ class_labels: Optional[torch.Tensor] = None,
524
+ cfg_scale=4.0,
525
+ use_fp16=False,
526
+ ip_hidden_states = None):
527
+ """
528
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
529
+ """
530
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
531
+ half = x[: len(x) // 2]
532
+ combined = torch.cat([half, half], dim=0)
533
+ if use_fp16:
534
+ combined = combined.to(dtype=torch.float16)
535
+ model_out = self.forward(combined, t, encoder_hidden_states, class_labels, ip_hidden_states=ip_hidden_states).sample
536
+ # For exact reproducibility reasons, we apply classifier-free guidance on only
537
+ # three channels by default. The standard approach to cfg applies it to all channels.
538
+ # This can be done by uncommenting the following line and commenting-out the line following that.
539
+ eps, rest = model_out[:, :4], model_out[:, 4:]
540
+ # eps, rest = model_out[:, :3], model_out[:, 3:] # b c f h w
541
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
542
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
543
+ eps = torch.cat([half_eps, half_eps], dim=0)
544
+ return torch.cat([eps, rest], dim=1)
545
+
546
+ @classmethod
547
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, use_concat=False):
548
+ if subfolder is not None:
549
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
550
+
551
+
552
+ # the content of the config file
553
+ # {
554
+ # "_class_name": "UNet2DConditionModel",
555
+ # "_diffusers_version": "0.2.2",
556
+ # "act_fn": "silu",
557
+ # "attention_head_dim": 8,
558
+ # "block_out_channels": [
559
+ # 320,
560
+ # 640,
561
+ # 1280,
562
+ # 1280
563
+ # ],
564
+ # "center_input_sample": false,
565
+ # "cross_attention_dim": 768,
566
+ # "down_block_types": [
567
+ # "CrossAttnDownBlock2D",
568
+ # "CrossAttnDownBlock2D",
569
+ # "CrossAttnDownBlock2D",
570
+ # "DownBlock2D"
571
+ # ],
572
+ # "downsample_padding": 1,
573
+ # "flip_sin_to_cos": true,
574
+ # "freq_shift": 0,
575
+ # "in_channels": 4,
576
+ # "layers_per_block": 2,
577
+ # "mid_block_scale_factor": 1,
578
+ # "norm_eps": 1e-05,
579
+ # "norm_num_groups": 32,
580
+ # "out_channels": 4,
581
+ # "sample_size": 64,
582
+ # "up_block_types": [
583
+ # "UpBlock2D",
584
+ # "CrossAttnUpBlock2D",
585
+ # "CrossAttnUpBlock2D",
586
+ # "CrossAttnUpBlock2D"
587
+ # ]
588
+ # }
589
+ config_file = os.path.join(pretrained_model_path, 'config.json')
590
+ if not os.path.isfile(config_file):
591
+ raise RuntimeError(f"{config_file} does not exist")
592
+ with open(config_file, "r") as f:
593
+ config = json.load(f)
594
+ config["_class_name"] = cls.__name__
595
+ config["down_block_types"] = [
596
+ "CrossAttnDownBlock3D",
597
+ "CrossAttnDownBlock3D",
598
+ "CrossAttnDownBlock3D",
599
+ "DownBlock3D"
600
+ ]
601
+ config["up_block_types"] = [
602
+ "UpBlock3D",
603
+ "CrossAttnUpBlock3D",
604
+ "CrossAttnUpBlock3D",
605
+ "CrossAttnUpBlock3D"
606
+ ]
607
+
608
+ # config["use_first_frame"] = True
609
+
610
+ config["use_first_frame"] = False
611
+ if use_concat:
612
+ config["in_channels"] = 9
613
+ # config["use_relative_position"] = True
614
+
615
+ # # tmp
616
+ # config["class_embed_type"] = "timestep"
617
+ # config["num_class_embeds"] = 100
618
+
619
+ from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin
620
+
621
+ # {'_class_name': 'UNet3DConditionModel',
622
+ # '_diffusers_version': '0.2.2',
623
+ # 'act_fn': 'silu',
624
+ # 'attention_head_dim': 8,
625
+ # 'block_out_channels': [320, 640, 1280, 1280],
626
+ # 'center_input_sample': False,
627
+ # 'cross_attention_dim': 768,
628
+ # 'down_block_types':
629
+ # ['CrossAttnDownBlock3D',
630
+ # 'CrossAttnDownBlock3D',
631
+ # 'CrossAttnDownBlock3D',
632
+ # 'DownBlock3D'],
633
+ # 'downsample_padding': 1,
634
+ # 'flip_sin_to_cos': True,
635
+ # 'freq_shift': 0,
636
+ # 'in_channels': 4,
637
+ # 'layers_per_block': 2,
638
+ # 'mid_block_scale_factor': 1,
639
+ # 'norm_eps': 1e-05,
640
+ # 'norm_num_groups': 32,
641
+ # 'out_channels': 4,
642
+ # 'sample_size': 64,
643
+ # 'up_block_types':
644
+ # ['UpBlock3D',
645
+ # 'CrossAttnUpBlock3D',
646
+ # 'CrossAttnUpBlock3D',
647
+ # 'CrossAttnUpBlock3D']}
648
+
649
+ model = cls.from_config(config)
650
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
651
+ if not os.path.isfile(model_file):
652
+ raise RuntimeError(f"{model_file} does not exist")
653
+ state_dict = torch.load(model_file, map_location="cpu")
654
+
655
+ if use_concat:
656
+ new_state_dict = {}
657
+ conv_in_weight = state_dict["conv_in.weight"]
658
+ new_conv_weight = torch.zeros((conv_in_weight.shape[0], 9, *conv_in_weight.shape[2:]), dtype=conv_in_weight.dtype)
659
+
660
+ for i, j in zip([0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 6, 7, 8]):
661
+ new_conv_weight[:, j] = conv_in_weight[:, i]
662
+ new_state_dict["conv_in.weight"] = new_conv_weight
663
+ new_state_dict["conv_in.bias"] = state_dict["conv_in.bias"]
664
+ for k, v in model.state_dict().items():
665
+ # print(k)
666
+ if '_temp.' in k:
667
+ new_state_dict.update({k: v})
668
+ if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross
669
+ k = k.replace('attn_fcross', 'attn1')
670
+ state_dict.update({k: state_dict[k]})
671
+ if 'norm_fcross' in k:
672
+ k = k.replace('norm_fcross', 'norm1')
673
+ state_dict.update({k: state_dict[k]})
674
+
675
+ if 'conv_in' in k:
676
+ continue
677
+ else:
678
+ new_state_dict[k] = v
679
+ # # tmp
680
+ # if 'class_embedding' in k:
681
+ # state_dict.update({k: v})
682
+ # breakpoint()
683
+ model.load_state_dict(new_state_dict)
684
+ else:
685
+ for k, v in model.state_dict().items():
686
+ # print(k)
687
+ if '_temp' in k:
688
+ state_dict.update({k: v})
689
+ if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross
690
+ k = k.replace('attn_fcross', 'attn1')
691
+ state_dict.update({k: state_dict[k]})
692
+ if 'norm_fcross' in k:
693
+ k = k.replace('norm_fcross', 'norm1')
694
+ state_dict.update({k: state_dict[k]})
695
+
696
+ model.load_state_dict(state_dict)
697
+
698
+ return model
699
+
models/unet_blocks.py ADDED
@@ -0,0 +1,650 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+ import os
3
+ import sys
4
+ sys.path.append(os.path.split(sys.path[0])[0])
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+ try:
10
+ from .attention import Transformer3DModel
11
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
12
+ except:
13
+ from attention import Transformer3DModel
14
+ from resnet import Downsample3D, ResnetBlock3D, Upsample3D
15
+
16
+
17
+ def get_down_block(
18
+ down_block_type,
19
+ num_layers,
20
+ in_channels,
21
+ out_channels,
22
+ temb_channels,
23
+ add_downsample,
24
+ resnet_eps,
25
+ resnet_act_fn,
26
+ attn_num_head_channels,
27
+ resnet_groups=None,
28
+ cross_attention_dim=None,
29
+ downsample_padding=None,
30
+ dual_cross_attention=False,
31
+ use_linear_projection=False,
32
+ only_cross_attention=False,
33
+ upcast_attention=False,
34
+ resnet_time_scale_shift="default",
35
+ use_first_frame=False,
36
+ use_relative_position=False,
37
+ rotary_emb=False,
38
+ ):
39
+ # print(down_block_type)
40
+ # print(use_first_frame)
41
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
42
+ if down_block_type == "DownBlock3D":
43
+ return DownBlock3D(
44
+ num_layers=num_layers,
45
+ in_channels=in_channels,
46
+ out_channels=out_channels,
47
+ temb_channels=temb_channels,
48
+ add_downsample=add_downsample,
49
+ resnet_eps=resnet_eps,
50
+ resnet_act_fn=resnet_act_fn,
51
+ resnet_groups=resnet_groups,
52
+ downsample_padding=downsample_padding,
53
+ resnet_time_scale_shift=resnet_time_scale_shift,
54
+ )
55
+ elif down_block_type == "CrossAttnDownBlock3D":
56
+ if cross_attention_dim is None:
57
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
58
+ return CrossAttnDownBlock3D(
59
+ num_layers=num_layers,
60
+ in_channels=in_channels,
61
+ out_channels=out_channels,
62
+ temb_channels=temb_channels,
63
+ add_downsample=add_downsample,
64
+ resnet_eps=resnet_eps,
65
+ resnet_act_fn=resnet_act_fn,
66
+ resnet_groups=resnet_groups,
67
+ downsample_padding=downsample_padding,
68
+ cross_attention_dim=cross_attention_dim,
69
+ attn_num_head_channels=attn_num_head_channels,
70
+ dual_cross_attention=dual_cross_attention,
71
+ use_linear_projection=use_linear_projection,
72
+ only_cross_attention=only_cross_attention,
73
+ upcast_attention=upcast_attention,
74
+ resnet_time_scale_shift=resnet_time_scale_shift,
75
+ use_first_frame=use_first_frame,
76
+ use_relative_position=use_relative_position,
77
+ rotary_emb=rotary_emb,
78
+ )
79
+ raise ValueError(f"{down_block_type} does not exist.")
80
+
81
+
82
+ def get_up_block(
83
+ up_block_type,
84
+ num_layers,
85
+ in_channels,
86
+ out_channels,
87
+ prev_output_channel,
88
+ temb_channels,
89
+ add_upsample,
90
+ resnet_eps,
91
+ resnet_act_fn,
92
+ attn_num_head_channels,
93
+ resnet_groups=None,
94
+ cross_attention_dim=None,
95
+ dual_cross_attention=False,
96
+ use_linear_projection=False,
97
+ only_cross_attention=False,
98
+ upcast_attention=False,
99
+ resnet_time_scale_shift="default",
100
+ use_first_frame=False,
101
+ use_relative_position=False,
102
+ rotary_emb=False,
103
+ ):
104
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
105
+ if up_block_type == "UpBlock3D":
106
+ return UpBlock3D(
107
+ num_layers=num_layers,
108
+ in_channels=in_channels,
109
+ out_channels=out_channels,
110
+ prev_output_channel=prev_output_channel,
111
+ temb_channels=temb_channels,
112
+ add_upsample=add_upsample,
113
+ resnet_eps=resnet_eps,
114
+ resnet_act_fn=resnet_act_fn,
115
+ resnet_groups=resnet_groups,
116
+ resnet_time_scale_shift=resnet_time_scale_shift,
117
+ )
118
+ elif up_block_type == "CrossAttnUpBlock3D":
119
+ if cross_attention_dim is None:
120
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
121
+ return CrossAttnUpBlock3D(
122
+ num_layers=num_layers,
123
+ in_channels=in_channels,
124
+ out_channels=out_channels,
125
+ prev_output_channel=prev_output_channel,
126
+ temb_channels=temb_channels,
127
+ add_upsample=add_upsample,
128
+ resnet_eps=resnet_eps,
129
+ resnet_act_fn=resnet_act_fn,
130
+ resnet_groups=resnet_groups,
131
+ cross_attention_dim=cross_attention_dim,
132
+ attn_num_head_channels=attn_num_head_channels,
133
+ dual_cross_attention=dual_cross_attention,
134
+ use_linear_projection=use_linear_projection,
135
+ only_cross_attention=only_cross_attention,
136
+ upcast_attention=upcast_attention,
137
+ resnet_time_scale_shift=resnet_time_scale_shift,
138
+ use_first_frame=use_first_frame,
139
+ use_relative_position=use_relative_position,
140
+ rotary_emb=rotary_emb,
141
+ )
142
+ raise ValueError(f"{up_block_type} does not exist.")
143
+
144
+
145
+ class UNetMidBlock3DCrossAttn(nn.Module):
146
+ def __init__(
147
+ self,
148
+ in_channels: int,
149
+ temb_channels: int,
150
+ dropout: float = 0.0,
151
+ num_layers: int = 1,
152
+ resnet_eps: float = 1e-6,
153
+ resnet_time_scale_shift: str = "default",
154
+ resnet_act_fn: str = "swish",
155
+ resnet_groups: int = 32,
156
+ resnet_pre_norm: bool = True,
157
+ attn_num_head_channels=1,
158
+ output_scale_factor=1.0,
159
+ cross_attention_dim=1280,
160
+ dual_cross_attention=False,
161
+ use_linear_projection=False,
162
+ upcast_attention=False,
163
+ use_first_frame=False,
164
+ use_relative_position=False,
165
+ rotary_emb=False,
166
+ ):
167
+ super().__init__()
168
+
169
+ self.has_cross_attention = True
170
+ self.attn_num_head_channels = attn_num_head_channels
171
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
172
+
173
+ # there is always at least one resnet
174
+ resnets = [
175
+ ResnetBlock3D(
176
+ in_channels=in_channels,
177
+ out_channels=in_channels,
178
+ temb_channels=temb_channels,
179
+ eps=resnet_eps,
180
+ groups=resnet_groups,
181
+ dropout=dropout,
182
+ time_embedding_norm=resnet_time_scale_shift,
183
+ non_linearity=resnet_act_fn,
184
+ output_scale_factor=output_scale_factor,
185
+ pre_norm=resnet_pre_norm,
186
+ )
187
+ ]
188
+ attentions = []
189
+
190
+ for _ in range(num_layers):
191
+ if dual_cross_attention:
192
+ raise NotImplementedError
193
+ attentions.append(
194
+ Transformer3DModel(
195
+ attn_num_head_channels,
196
+ in_channels // attn_num_head_channels,
197
+ in_channels=in_channels,
198
+ num_layers=1,
199
+ cross_attention_dim=cross_attention_dim,
200
+ norm_num_groups=resnet_groups,
201
+ use_linear_projection=use_linear_projection,
202
+ upcast_attention=upcast_attention,
203
+ use_first_frame=use_first_frame,
204
+ use_relative_position=use_relative_position,
205
+ rotary_emb=rotary_emb,
206
+ )
207
+ )
208
+ resnets.append(
209
+ ResnetBlock3D(
210
+ in_channels=in_channels,
211
+ out_channels=in_channels,
212
+ temb_channels=temb_channels,
213
+ eps=resnet_eps,
214
+ groups=resnet_groups,
215
+ dropout=dropout,
216
+ time_embedding_norm=resnet_time_scale_shift,
217
+ non_linearity=resnet_act_fn,
218
+ output_scale_factor=output_scale_factor,
219
+ pre_norm=resnet_pre_norm,
220
+ )
221
+ )
222
+
223
+ self.attentions = nn.ModuleList(attentions)
224
+ self.resnets = nn.ModuleList(resnets)
225
+
226
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None, ip_hidden_states=None, encoder_temporal_hidden_states=None):
227
+ hidden_states = self.resnets[0](hidden_states, temb)
228
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
229
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states, encoder_temporal_hidden_states=encoder_temporal_hidden_states).sample
230
+ hidden_states = resnet(hidden_states, temb)
231
+
232
+ return hidden_states
233
+
234
+
235
+ class CrossAttnDownBlock3D(nn.Module):
236
+ def __init__(
237
+ self,
238
+ in_channels: int,
239
+ out_channels: int,
240
+ temb_channels: int,
241
+ dropout: float = 0.0,
242
+ num_layers: int = 1,
243
+ resnet_eps: float = 1e-6,
244
+ resnet_time_scale_shift: str = "default",
245
+ resnet_act_fn: str = "swish",
246
+ resnet_groups: int = 32,
247
+ resnet_pre_norm: bool = True,
248
+ attn_num_head_channels=1,
249
+ cross_attention_dim=1280,
250
+ output_scale_factor=1.0,
251
+ downsample_padding=1,
252
+ add_downsample=True,
253
+ dual_cross_attention=False,
254
+ use_linear_projection=False,
255
+ only_cross_attention=False,
256
+ upcast_attention=False,
257
+ use_first_frame=False,
258
+ use_relative_position=False,
259
+ rotary_emb=False,
260
+ ):
261
+ super().__init__()
262
+ resnets = []
263
+ attentions = []
264
+
265
+ # print(use_first_frame)
266
+
267
+ self.has_cross_attention = True
268
+ self.attn_num_head_channels = attn_num_head_channels
269
+
270
+ for i in range(num_layers):
271
+ in_channels = in_channels if i == 0 else out_channels
272
+ resnets.append(
273
+ ResnetBlock3D(
274
+ in_channels=in_channels,
275
+ out_channels=out_channels,
276
+ temb_channels=temb_channels,
277
+ eps=resnet_eps,
278
+ groups=resnet_groups,
279
+ dropout=dropout,
280
+ time_embedding_norm=resnet_time_scale_shift,
281
+ non_linearity=resnet_act_fn,
282
+ output_scale_factor=output_scale_factor,
283
+ pre_norm=resnet_pre_norm,
284
+ )
285
+ )
286
+ if dual_cross_attention:
287
+ raise NotImplementedError
288
+ attentions.append(
289
+ Transformer3DModel(
290
+ attn_num_head_channels,
291
+ out_channels // attn_num_head_channels,
292
+ in_channels=out_channels,
293
+ num_layers=1,
294
+ cross_attention_dim=cross_attention_dim,
295
+ norm_num_groups=resnet_groups,
296
+ use_linear_projection=use_linear_projection,
297
+ only_cross_attention=only_cross_attention,
298
+ upcast_attention=upcast_attention,
299
+ use_first_frame=use_first_frame,
300
+ use_relative_position=use_relative_position,
301
+ rotary_emb=rotary_emb,
302
+ )
303
+ )
304
+ self.attentions = nn.ModuleList(attentions)
305
+ self.resnets = nn.ModuleList(resnets)
306
+
307
+ if add_downsample:
308
+ self.downsamplers = nn.ModuleList(
309
+ [
310
+ Downsample3D(
311
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
312
+ )
313
+ ]
314
+ )
315
+ else:
316
+ self.downsamplers = None
317
+
318
+ self.gradient_checkpointing = False
319
+
320
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None, ip_hidden_states=None, encoder_temporal_hidden_states=None):
321
+ output_states = ()
322
+
323
+ for resnet, attn in zip(self.resnets, self.attentions):
324
+ if self.training and self.gradient_checkpointing:
325
+
326
+ def create_custom_forward(module, return_dict=None):
327
+ def custom_forward(*inputs):
328
+ if return_dict is not None:
329
+ return module(*inputs, return_dict=return_dict)
330
+ else:
331
+ return module(*inputs)
332
+
333
+ return custom_forward
334
+
335
+ def create_custom_forward_attn(module, return_dict=None, use_image_num=None, ip_hidden_states=None):
336
+ def custom_forward(*inputs):
337
+ if return_dict is not None:
338
+ return module(*inputs, return_dict=return_dict, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states)
339
+ else:
340
+ return module(*inputs, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states)
341
+
342
+ return custom_forward
343
+
344
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
345
+ hidden_states = torch.utils.checkpoint.checkpoint(
346
+ create_custom_forward_attn(attn, return_dict=False, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states),
347
+ hidden_states,
348
+ encoder_hidden_states,
349
+ )[0]
350
+ else:
351
+ hidden_states = resnet(hidden_states, temb)
352
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states, encoder_temporal_hidden_states=encoder_temporal_hidden_states).sample
353
+
354
+ output_states += (hidden_states,)
355
+
356
+ if self.downsamplers is not None:
357
+ for downsampler in self.downsamplers:
358
+ hidden_states = downsampler(hidden_states)
359
+
360
+ output_states += (hidden_states,)
361
+
362
+ return hidden_states, output_states
363
+
364
+
365
+ class DownBlock3D(nn.Module):
366
+ def __init__(
367
+ self,
368
+ in_channels: int,
369
+ out_channels: int,
370
+ temb_channels: int,
371
+ dropout: float = 0.0,
372
+ num_layers: int = 1,
373
+ resnet_eps: float = 1e-6,
374
+ resnet_time_scale_shift: str = "default",
375
+ resnet_act_fn: str = "swish",
376
+ resnet_groups: int = 32,
377
+ resnet_pre_norm: bool = True,
378
+ output_scale_factor=1.0,
379
+ add_downsample=True,
380
+ downsample_padding=1,
381
+ ):
382
+ super().__init__()
383
+ resnets = []
384
+
385
+ for i in range(num_layers):
386
+ in_channels = in_channels if i == 0 else out_channels
387
+ resnets.append(
388
+ ResnetBlock3D(
389
+ in_channels=in_channels,
390
+ out_channels=out_channels,
391
+ temb_channels=temb_channels,
392
+ eps=resnet_eps,
393
+ groups=resnet_groups,
394
+ dropout=dropout,
395
+ time_embedding_norm=resnet_time_scale_shift,
396
+ non_linearity=resnet_act_fn,
397
+ output_scale_factor=output_scale_factor,
398
+ pre_norm=resnet_pre_norm,
399
+ )
400
+ )
401
+
402
+ self.resnets = nn.ModuleList(resnets)
403
+
404
+ if add_downsample:
405
+ self.downsamplers = nn.ModuleList(
406
+ [
407
+ Downsample3D(
408
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
409
+ )
410
+ ]
411
+ )
412
+ else:
413
+ self.downsamplers = None
414
+
415
+ self.gradient_checkpointing = False
416
+
417
+ def forward(self, hidden_states, temb=None):
418
+ output_states = ()
419
+
420
+ for resnet in self.resnets:
421
+ if self.training and self.gradient_checkpointing:
422
+
423
+ def create_custom_forward(module):
424
+ def custom_forward(*inputs):
425
+ return module(*inputs)
426
+
427
+ return custom_forward
428
+
429
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
430
+ else:
431
+ hidden_states = resnet(hidden_states, temb)
432
+
433
+ output_states += (hidden_states,)
434
+
435
+ if self.downsamplers is not None:
436
+ for downsampler in self.downsamplers:
437
+ hidden_states = downsampler(hidden_states)
438
+
439
+ output_states += (hidden_states,)
440
+
441
+ return hidden_states, output_states
442
+
443
+
444
+ class CrossAttnUpBlock3D(nn.Module):
445
+ def __init__(
446
+ self,
447
+ in_channels: int,
448
+ out_channels: int,
449
+ prev_output_channel: int,
450
+ temb_channels: int,
451
+ dropout: float = 0.0,
452
+ num_layers: int = 1,
453
+ resnet_eps: float = 1e-6,
454
+ resnet_time_scale_shift: str = "default",
455
+ resnet_act_fn: str = "swish",
456
+ resnet_groups: int = 32,
457
+ resnet_pre_norm: bool = True,
458
+ attn_num_head_channels=1,
459
+ cross_attention_dim=1280,
460
+ output_scale_factor=1.0,
461
+ add_upsample=True,
462
+ dual_cross_attention=False,
463
+ use_linear_projection=False,
464
+ only_cross_attention=False,
465
+ upcast_attention=False,
466
+ use_first_frame=False,
467
+ use_relative_position=False,
468
+ rotary_emb=False
469
+ ):
470
+ super().__init__()
471
+ resnets = []
472
+ attentions = []
473
+
474
+ self.has_cross_attention = True
475
+ self.attn_num_head_channels = attn_num_head_channels
476
+
477
+ for i in range(num_layers):
478
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
479
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
480
+
481
+ resnets.append(
482
+ ResnetBlock3D(
483
+ in_channels=resnet_in_channels + res_skip_channels,
484
+ out_channels=out_channels,
485
+ temb_channels=temb_channels,
486
+ eps=resnet_eps,
487
+ groups=resnet_groups,
488
+ dropout=dropout,
489
+ time_embedding_norm=resnet_time_scale_shift,
490
+ non_linearity=resnet_act_fn,
491
+ output_scale_factor=output_scale_factor,
492
+ pre_norm=resnet_pre_norm,
493
+ )
494
+ )
495
+ if dual_cross_attention:
496
+ raise NotImplementedError
497
+ attentions.append(
498
+ Transformer3DModel(
499
+ attn_num_head_channels,
500
+ out_channels // attn_num_head_channels,
501
+ in_channels=out_channels,
502
+ num_layers=1,
503
+ cross_attention_dim=cross_attention_dim,
504
+ norm_num_groups=resnet_groups,
505
+ use_linear_projection=use_linear_projection,
506
+ only_cross_attention=only_cross_attention,
507
+ upcast_attention=upcast_attention,
508
+ use_first_frame=use_first_frame,
509
+ use_relative_position=use_relative_position,
510
+ rotary_emb=rotary_emb,
511
+ )
512
+ )
513
+
514
+ self.attentions = nn.ModuleList(attentions)
515
+ self.resnets = nn.ModuleList(resnets)
516
+
517
+ if add_upsample:
518
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
519
+ else:
520
+ self.upsamplers = None
521
+
522
+ self.gradient_checkpointing = False
523
+
524
+ def forward(
525
+ self,
526
+ hidden_states,
527
+ res_hidden_states_tuple,
528
+ temb=None,
529
+ encoder_hidden_states=None,
530
+ upsample_size=None,
531
+ attention_mask=None,
532
+ use_image_num=None,
533
+ ip_hidden_states=None,
534
+ encoder_temporal_hidden_states=None
535
+ ):
536
+ for resnet, attn in zip(self.resnets, self.attentions):
537
+ # pop res hidden states
538
+ res_hidden_states = res_hidden_states_tuple[-1]
539
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
540
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
541
+
542
+ if self.training and self.gradient_checkpointing:
543
+
544
+ def create_custom_forward(module, return_dict=None):
545
+ def custom_forward(*inputs):
546
+ if return_dict is not None:
547
+ return module(*inputs, return_dict=return_dict)
548
+ else:
549
+ return module(*inputs)
550
+
551
+ return custom_forward
552
+
553
+ def create_custom_forward_attn(module, return_dict=None, use_image_num=None, ip_hidden_states=None):
554
+ def custom_forward(*inputs):
555
+ if return_dict is not None:
556
+ return module(*inputs, return_dict=return_dict, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states)
557
+ else:
558
+ return module(*inputs, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states)
559
+
560
+ return custom_forward
561
+
562
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
563
+ hidden_states = torch.utils.checkpoint.checkpoint(
564
+ create_custom_forward_attn(attn, return_dict=False, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states),
565
+ hidden_states,
566
+ encoder_hidden_states,
567
+ )[0]
568
+ else:
569
+ hidden_states = resnet(hidden_states, temb)
570
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states, encoder_temporal_hidden_states=encoder_temporal_hidden_states).sample
571
+
572
+ if self.upsamplers is not None:
573
+ for upsampler in self.upsamplers:
574
+ hidden_states = upsampler(hidden_states, upsample_size)
575
+
576
+ return hidden_states
577
+
578
+
579
+ class UpBlock3D(nn.Module):
580
+ def __init__(
581
+ self,
582
+ in_channels: int,
583
+ prev_output_channel: int,
584
+ out_channels: int,
585
+ temb_channels: int,
586
+ dropout: float = 0.0,
587
+ num_layers: int = 1,
588
+ resnet_eps: float = 1e-6,
589
+ resnet_time_scale_shift: str = "default",
590
+ resnet_act_fn: str = "swish",
591
+ resnet_groups: int = 32,
592
+ resnet_pre_norm: bool = True,
593
+ output_scale_factor=1.0,
594
+ add_upsample=True,
595
+ ):
596
+ super().__init__()
597
+ resnets = []
598
+
599
+ for i in range(num_layers):
600
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
601
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
602
+
603
+ resnets.append(
604
+ ResnetBlock3D(
605
+ in_channels=resnet_in_channels + res_skip_channels,
606
+ out_channels=out_channels,
607
+ temb_channels=temb_channels,
608
+ eps=resnet_eps,
609
+ groups=resnet_groups,
610
+ dropout=dropout,
611
+ time_embedding_norm=resnet_time_scale_shift,
612
+ non_linearity=resnet_act_fn,
613
+ output_scale_factor=output_scale_factor,
614
+ pre_norm=resnet_pre_norm,
615
+ )
616
+ )
617
+
618
+ self.resnets = nn.ModuleList(resnets)
619
+
620
+ if add_upsample:
621
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
622
+ else:
623
+ self.upsamplers = None
624
+
625
+ self.gradient_checkpointing = False
626
+
627
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
628
+ for resnet in self.resnets:
629
+ # pop res hidden states
630
+ res_hidden_states = res_hidden_states_tuple[-1]
631
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
632
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
633
+
634
+ if self.training and self.gradient_checkpointing:
635
+
636
+ def create_custom_forward(module):
637
+ def custom_forward(*inputs):
638
+ return module(*inputs)
639
+
640
+ return custom_forward
641
+
642
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
643
+ else:
644
+ hidden_states = resnet(hidden_states, temb)
645
+
646
+ if self.upsamplers is not None:
647
+ for upsampler in self.upsamplers:
648
+ hidden_states = upsampler(hidden_states, upsample_size)
649
+
650
+ return hidden_states
models/utils.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+
15
+ import numpy as np
16
+ import torch.nn as nn
17
+
18
+ from einops import repeat
19
+
20
+
21
+ #################################################################################
22
+ # Unet Utils #
23
+ #################################################################################
24
+
25
+ def checkpoint(func, inputs, params, flag):
26
+ """
27
+ Evaluate a function without caching intermediate activations, allowing for
28
+ reduced memory at the expense of extra compute in the backward pass.
29
+ :param func: the function to evaluate.
30
+ :param inputs: the argument sequence to pass to `func`.
31
+ :param params: a sequence of parameters `func` depends on but does not
32
+ explicitly take as arguments.
33
+ :param flag: if False, disable gradient checkpointing.
34
+ """
35
+ if flag:
36
+ args = tuple(inputs) + tuple(params)
37
+ return CheckpointFunction.apply(func, len(inputs), *args)
38
+ else:
39
+ return func(*inputs)
40
+
41
+
42
+ class CheckpointFunction(torch.autograd.Function):
43
+ @staticmethod
44
+ def forward(ctx, run_function, length, *args):
45
+ ctx.run_function = run_function
46
+ ctx.input_tensors = list(args[:length])
47
+ ctx.input_params = list(args[length:])
48
+
49
+ with torch.no_grad():
50
+ output_tensors = ctx.run_function(*ctx.input_tensors)
51
+ return output_tensors
52
+
53
+ @staticmethod
54
+ def backward(ctx, *output_grads):
55
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
56
+ with torch.enable_grad():
57
+ # Fixes a bug where the first op in run_function modifies the
58
+ # Tensor storage in place, which is not allowed for detach()'d
59
+ # Tensors.
60
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
61
+ output_tensors = ctx.run_function(*shallow_copies)
62
+ input_grads = torch.autograd.grad(
63
+ output_tensors,
64
+ ctx.input_tensors + ctx.input_params,
65
+ output_grads,
66
+ allow_unused=True,
67
+ )
68
+ del ctx.input_tensors
69
+ del ctx.input_params
70
+ del output_tensors
71
+ return (None, None) + input_grads
72
+
73
+
74
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
75
+ """
76
+ Create sinusoidal timestep embeddings.
77
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
78
+ These may be fractional.
79
+ :param dim: the dimension of the output.
80
+ :param max_period: controls the minimum frequency of the embeddings.
81
+ :return: an [N x dim] Tensor of positional embeddings.
82
+ """
83
+ if not repeat_only:
84
+ half = dim // 2
85
+ freqs = torch.exp(
86
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
87
+ ).to(device=timesteps.device)
88
+ args = timesteps[:, None].float() * freqs[None]
89
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
90
+ if dim % 2:
91
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
92
+ else:
93
+ embedding = repeat(timesteps, 'b -> b d', d=dim).contiguous()
94
+ return embedding
95
+
96
+
97
+ def zero_module(module):
98
+ """
99
+ Zero out the parameters of a module and return it.
100
+ """
101
+ for p in module.parameters():
102
+ p.detach().zero_()
103
+ return module
104
+
105
+
106
+ def scale_module(module, scale):
107
+ """
108
+ Scale the parameters of a module and return it.
109
+ """
110
+ for p in module.parameters():
111
+ p.detach().mul_(scale)
112
+ return module
113
+
114
+
115
+ def mean_flat(tensor):
116
+ """
117
+ Take the mean over all non-batch dimensions.
118
+ """
119
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
120
+
121
+
122
+ def normalization(channels):
123
+ """
124
+ Make a standard normalization layer.
125
+ :param channels: number of input channels.
126
+ :return: an nn.Module for normalization.
127
+ """
128
+ return GroupNorm32(32, channels)
129
+
130
+
131
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
132
+ class SiLU(nn.Module):
133
+ def forward(self, x):
134
+ return x * torch.sigmoid(x)
135
+
136
+
137
+ class GroupNorm32(nn.GroupNorm):
138
+ def forward(self, x):
139
+ return super().forward(x.float()).type(x.dtype)
140
+
141
+ def conv_nd(dims, *args, **kwargs):
142
+ """
143
+ Create a 1D, 2D, or 3D convolution module.
144
+ """
145
+ if dims == 1:
146
+ return nn.Conv1d(*args, **kwargs)
147
+ elif dims == 2:
148
+ return nn.Conv2d(*args, **kwargs)
149
+ elif dims == 3:
150
+ return nn.Conv3d(*args, **kwargs)
151
+ raise ValueError(f"unsupported dimensions: {dims}")
152
+
153
+
154
+ def linear(*args, **kwargs):
155
+ """
156
+ Create a linear module.
157
+ """
158
+ return nn.Linear(*args, **kwargs)
159
+
160
+
161
+ def avg_pool_nd(dims, *args, **kwargs):
162
+ """
163
+ Create a 1D, 2D, or 3D average pooling module.
164
+ """
165
+ if dims == 1:
166
+ return nn.AvgPool1d(*args, **kwargs)
167
+ elif dims == 2:
168
+ return nn.AvgPool2d(*args, **kwargs)
169
+ elif dims == 3:
170
+ return nn.AvgPool3d(*args, **kwargs)
171
+ raise ValueError(f"unsupported dimensions: {dims}")
172
+
173
+
174
+ # class HybridConditioner(nn.Module):
175
+
176
+ # def __init__(self, c_concat_config, c_crossattn_config):
177
+ # super().__init__()
178
+ # self.concat_conditioner = instantiate_from_config(c_concat_config)
179
+ # self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
180
+
181
+ # def forward(self, c_concat, c_crossattn):
182
+ # c_concat = self.concat_conditioner(c_concat)
183
+ # c_crossattn = self.crossattn_conditioner(c_crossattn)
184
+ # return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
185
+
186
+
187
+ def noise_like(shape, device, repeat=False):
188
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
189
+ noise = lambda: torch.randn(shape, device=device)
190
+ return repeat_noise() if repeat else noise()
191
+
192
+ def count_flops_attn(model, _x, y):
193
+ """
194
+ A counter for the `thop` package to count the operations in an
195
+ attention operation.
196
+ Meant to be used like:
197
+ macs, params = thop.profile(
198
+ model,
199
+ inputs=(inputs, timestamps),
200
+ custom_ops={QKVAttention: QKVAttention.count_flops},
201
+ )
202
+ """
203
+ b, c, *spatial = y[0].shape
204
+ num_spatial = int(np.prod(spatial))
205
+ # We perform two matmuls with the same number of ops.
206
+ # The first computes the weight matrix, the second computes
207
+ # the combination of the value vectors.
208
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
209
+ model.total_ops += torch.DoubleTensor([matmul_ops])
210
+
211
+ def count_params(model, verbose=False):
212
+ total_params = sum(p.numel() for p in model.parameters())
213
+ if verbose:
214
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
215
+ return total_params
requirements.txt ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ bark_ssg==1.3.4
2
+ decord==0.6.0
3
+ diffusers==0.25.0
4
+ einops==0.7.0
5
+ imageio==2.28.0
6
+ ipython==8.14.0
7
+ librosa==0.10.1
8
+ mmcv==2.1.0
9
+ moviepy==1.0.3
10
+ natsort==8.3.1
11
+ nltk==3.8.1
12
+ numpy==1.23.5
13
+ omegaconf==2.3.0
14
+ openai==0.27.8
15
+ opencv_python==4.7.0.72
16
+ Pillow==9.4.0
17
+ Pillow==10.2.0
18
+ pytorch_lightning==2.0.2
19
+ rotary_embedding_torch==0.2.3
20
+ soundfile==0.12.1
21
+ torch==2.0.0
22
+ torchvision==0.15.0
23
+ tqdm==4.65.0
24
+ transformers==4.28.1
25
+ xformers==0.0.19
results/mask_no_ref/Planet_hits_earth..mp4 ADDED
Binary file (326 kB). View file
 
results/mask_ref/Planet_hits_earth..mp4 ADDED
Binary file (345 kB). View file
 
results/vlog/teddy_travel/ref_img/teddy.jpg ADDED
results/vlog/teddy_travel/script/protagonist_place_reference.txt ADDED
Binary file (1.53 kB). View file
 
results/vlog/teddy_travel/script/protagonists_places.txt ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "id": 1,
4
+ "name": "Teddy",
5
+ "description": "A teddy bear with a dream of traveling the world"
6
+ },
7
+ {
8
+ "id": 2,
9
+ "name": "Eiffel Tower",
10
+ "description": "An iconic wrought-iron lattice tower located in Paris, France"
11
+ },
12
+ {
13
+ "id": 3,
14
+ "name": "Great Wall",
15
+ "description": "A vast, historic fortification system that stretches across the northern part of China"
16
+ },
17
+ {
18
+ "id": 4,
19
+ "name": "Pyramids",
20
+ "description": "Ancient monumental structures located in Egypt"
21
+ }
22
+ ]
results/vlog/teddy_travel/script/time_scripts.txt ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "video fragment id": 1,
4
+ "time": 2
5
+ },
6
+ {
7
+ "video fragment id": 2,
8
+ "time": 3
9
+ },
10
+ {
11
+ "video fragment id": 3,
12
+ "time": 3
13
+ },
14
+ {
15
+ "video fragment id": 4,
16
+ "time": 2
17
+ },
18
+ {
19
+ "video fragment id": 5,
20
+ "time": 2
21
+ },
22
+ {
23
+ "video fragment id": 6,
24
+ "time": 3
25
+ },
26
+ {
27
+ "video fragment id": 7,
28
+ "time": 2
29
+ },
30
+ {
31
+ "video fragment id": 8,
32
+ "time": 3
33
+ },
34
+ {
35
+ "video fragment id": 9,
36
+ "time": 2
37
+ },
38
+ {
39
+ "video fragment id": 10,
40
+ "time": 2
41
+ },
42
+ {
43
+ "video fragment id": 11,
44
+ "time": 3
45
+ },
46
+ {
47
+ "video fragment id": 12,
48
+ "time": 2
49
+ },
50
+ {
51
+ "video fragment id": 13,
52
+ "time": 2
53
+ },
54
+ {
55
+ "video fragment id": 14,
56
+ "time": 3
57
+ },
58
+ {
59
+ "video fragment id": 15,
60
+ "time": 3
61
+ },
62
+ {
63
+ "video fragment id": 16,
64
+ "time": 2
65
+ },
66
+ {
67
+ "video fragment id": 17,
68
+ "time": 3
69
+ },
70
+ {
71
+ "video fragment id": 18,
72
+ "time": 2
73
+ },
74
+ {
75
+ "video fragment id": 19,
76
+ "time": 3
77
+ },
78
+ {
79
+ "video fragment id": 20,
80
+ "time": 2
81
+ },
82
+ {
83
+ "video fragment id": 21,
84
+ "time": 3
85
+ },
86
+ {
87
+ "video fragment id": 22,
88
+ "time": 2
89
+ },
90
+ {
91
+ "video fragment id": 23,
92
+ "time": 3
93
+ }
94
+ ]
results/vlog/teddy_travel/script/video_prompts.txt ADDED
Binary file (2.61 kB). View file
 
results/vlog/teddy_travel/script/zh_video_prompts.txt ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "序号": 1,
4
+ "描述": "泰迪熊在孩子的房间里。",
5
+ },
6
+ {
7
+ "序号": 2,
8
+ "描述": "泰迪熊正在做梦。",
9
+ },
10
+ {
11
+ "序号": 3,
12
+ "描述": "梦想着旅行。",
13
+ },
14
+ {
15
+ "序号": 4,
16
+ "描述": "泰迪熊在机场。",
17
+ },
18
+ {
19
+ "序号": 5,
20
+ "描述": "泰迪熊从背包中探出头来。",
21
+ },
22
+ {
23
+ "序号": 6,
24
+ "描述": "泰迪熊在野餐毯上。",
25
+ },
26
+ {
27
+ "序号": 7,
28
+ "描述": "背景是埃菲尔铁塔。",
29
+ },
30
+ {
31
+ "序号": 8,
32
+ "描述": "泰迪熊正在享受巴黎野餐。",
33
+ },
34
+ {
35
+ "序号": 9,
36
+ "描述": "泰迪熊周围是羊角面包。",
37
+ },
38
+ {
39
+ "序号": 10,
40
+ "描述": "泰迪熊在长城顶部。",
41
+ },
42
+ {
43
+ "序号": 11,
44
+ "描述": "泰迪熊正在欣赏风景。",
45
+ },
46
+ {
47
+ "序号": 12,
48
+ "描述": "泰迪熊在埃及探索金字塔。",
49
+ },
50
+ {
51
+ "序号": 13,
52
+ "描述": "炎热的埃及阳光下。",
53
+ },
54
+ {
55
+ "序号": 14,
56
+ "描述": "泰迪熊找到了一个宝箱。",
57
+ },
58
+ {
59
+ "序号": 15,
60
+ "描述": "宝箱在金字塔内部。",
61
+ },
62
+ {
63
+ "序号": 16,
64
+ "描述": "泰迪熊回到卧室。",
65
+ },
66
+ {
67
+ "序号": 17,
68
+ "描述": "分享旅行故事。",
69
+ },
70
+ {
71
+ "序号": 18,
72
+ "描述": "一个小女孩在反应。",
73
+ },
74
+ {
75
+ "序号": 19,
76
+ "描述": "惊讶于泰迪熊的故事。",
77
+ },
78
+ {
79
+ "序号": 20,
80
+ "描述": "房间里满是纪念品。",
81
+ },
82
+ {
83
+ "序号": 21,
84
+ "描述": "来自泰迪熊旅行的纪念品。",
85
+ },
86
+ {
87
+ "序号": 22,
88
+ "描述": "泰迪熊正在看世界地图。",
89
+ },
90
+ {
91
+ "序号": 23,
92
+ "描述": "梦想着下一次的冒险。",
93
+ }
94
+
95
+ ]
results/vlog/teddy_travel/story.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Once upon a time, there was a teddy bear named Teddy who dreamed of traveling the world. One day, his dream came true to travel around the world. Teddy sat in the airport lobby and traveled to many places of interest. Along the way, Teddy visited the Eiffel Tower, the Great Wall, and the pyramids. In Paris, Teddy had a picnic and enjoyed some delicious croissants. At the Great Wall of China, he climbed to the top and marveled at the breathtaking view. And in Egypt, he explored the pyramids and even found a secret treasure hidden inside. After his exciting journey, Teddy was eventually reunited with his owner who was thrilled to hear about all of Teddy’s adventures. From that day on, Teddy always dreamed of traveling the world again and experiencing new and exciting things.
results/vlog/teddy_travel_/story.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Once upon a time, there was a teddy bear named Teddy who dreamed of traveling the world. One day, his dream came true to travel around the world. Teddy sat in the airport lobby and traveled to many places of interest. Along the way, Teddy visited the Eiffel Tower, the Great Wall, and the pyramids. In Paris, Teddy had a picnic and enjoyed some delicious croissants. At the Great Wall of China, he climbed to the top and marveled at the breathtaking view. And in Egypt, he explored the pyramids and even found a secret treasure hidden inside. After his exciting journey, Teddy was eventually reunited with his owner who was thrilled to hear about all of Teddy’s adventures. From that day on, Teddy always dreamed of traveling the world again and experiencing new and exciting things.
sample_scripts/vlog_read_script_sample.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ torch.backends.cuda.matmul.allow_tf32 = True
4
+ torch.backends.cudnn.allow_tf32 = True
5
+ import os
6
+ import sys
7
+ try:
8
+ import utils
9
+ from diffusion import create_diffusion
10
+ except:
11
+ sys.path.append(os.path.split(sys.path[0])[0])
12
+ import utils
13
+ from diffusion import create_diffusion
14
+ import argparse
15
+ import torchvision
16
+ from PIL import Image
17
+ from einops import rearrange
18
+ from models import get_models
19
+ from diffusers.models import AutoencoderKL
20
+ from models.clip import TextEmbedder
21
+ from omegaconf import OmegaConf
22
+ from pytorch_lightning import seed_everything
23
+ from utils import mask_generation_before
24
+ from diffusers.utils.import_utils import is_xformers_available
25
+ from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
26
+ from vlogger.videofusion import fusion
27
+ from vlogger.videocaption import captioning
28
+ from vlogger.videoaudio import make_audio, merge_video_audio, concatenate_videos
29
+ from vlogger.STEB.model_transform import ip_scale_set, ip_transform_model, tca_transform_model
30
+ from vlogger.planning_utils.gpt4_utils import (readscript,
31
+ readtimescript,
32
+ readprotagonistscript,
33
+ readreferencescript,
34
+ readzhscript)
35
+
36
+
37
+ def auto_inpainting(args,
38
+ video_input,
39
+ masked_video,
40
+ mask,
41
+ prompt,
42
+ image,
43
+ vae,
44
+ text_encoder,
45
+ image_encoder,
46
+ diffusion,
47
+ model,
48
+ device,
49
+ ):
50
+ image_prompt_embeds = None
51
+ if prompt is None:
52
+ prompt = ""
53
+ if image is not None:
54
+ clip_image = CLIPImageProcessor()(images=image, return_tensors="pt").pixel_values
55
+ clip_image_embeds = image_encoder(clip_image.to(device)).image_embeds
56
+ uncond_clip_image_embeds = torch.zeros_like(clip_image_embeds).to(device)
57
+ image_prompt_embeds = torch.cat([clip_image_embeds, uncond_clip_image_embeds], dim=0)
58
+ image_prompt_embeds = rearrange(image_prompt_embeds, '(b n) c -> b n c', b=2).contiguous()
59
+ model = ip_scale_set(model, args.ref_cfg_scale)
60
+ if args.use_fp16:
61
+ image_prompt_embeds = image_prompt_embeds.to(dtype=torch.float16)
62
+ b, f, c, h, w = video_input.shape
63
+ latent_h = video_input.shape[-2] // 8
64
+ latent_w = video_input.shape[-1] // 8
65
+
66
+ if args.use_fp16:
67
+ z = torch.randn(1, 4, 16, latent_h, latent_w, dtype=torch.float16, device=device) # b,c,f,h,w
68
+ masked_video = masked_video.to(dtype=torch.float16)
69
+ mask = mask.to(dtype=torch.float16)
70
+ else:
71
+ z = torch.randn(1, 4, 16, latent_h, latent_w, device=device) # b,c,f,h,w
72
+
73
+ masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous()
74
+ masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215)
75
+ masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous()
76
+ mask = torch.nn.functional.interpolate(mask[:,:,0,:], size=(latent_h, latent_w)).unsqueeze(1)
77
+ masked_video = torch.cat([masked_video] * 2)
78
+ mask = torch.cat([mask] * 2)
79
+ z = torch.cat([z] * 2)
80
+ prompt_all = [prompt] + [args.negative_prompt]
81
+
82
+ text_prompt = text_encoder(text_prompts=prompt_all, train=False)
83
+ model_kwargs = dict(encoder_hidden_states=text_prompt,
84
+ class_labels=None,
85
+ cfg_scale=args.cfg_scale,
86
+ use_fp16=args.use_fp16,
87
+ ip_hidden_states=image_prompt_embeds)
88
+
89
+ # Sample images:
90
+ samples = diffusion.ddim_sample_loop(model.forward_with_cfg,
91
+ z.shape,
92
+ z,
93
+ clip_denoised=False,
94
+ model_kwargs=model_kwargs,
95
+ progress=True,
96
+ device=device,
97
+ mask=mask,
98
+ x_start=masked_video,
99
+ use_concat=True,
100
+ )
101
+ samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32]
102
+ if args.use_fp16:
103
+ samples = samples.to(dtype=torch.float16)
104
+
105
+ video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32]
106
+ video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256]
107
+ return video_clip
108
+
109
+
110
+ def main(args):
111
+ # Setup PyTorch:
112
+ if args.seed:
113
+ torch.manual_seed(args.seed)
114
+ torch.set_grad_enabled(False)
115
+ device = "cuda" if torch.cuda.is_available() else "cpu"
116
+ seed_everything(args.seed)
117
+
118
+ model = get_models(args).to(device)
119
+ model = tca_transform_model(model).to(device)
120
+ model = ip_transform_model(model).to(device)
121
+ if args.enable_xformers_memory_efficient_attention:
122
+ if is_xformers_available():
123
+ model.enable_xformers_memory_efficient_attention()
124
+ else:
125
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
126
+ if args.use_compile:
127
+ model = torch.compile(model)
128
+
129
+ ckpt_path = args.ckpt
130
+ state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['ema']
131
+ model_dict = model.state_dict()
132
+ pretrained_dict = {}
133
+ for k, v in state_dict.items():
134
+ if k in model_dict:
135
+ pretrained_dict[k] = v
136
+ model_dict.update(pretrained_dict)
137
+ model.load_state_dict(model_dict)
138
+
139
+ model.eval() # important!
140
+ diffusion = create_diffusion(str(args.num_sampling_steps))
141
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device)
142
+ text_encoder = text_encoder = TextEmbedder(args.pretrained_model_path).to(device)
143
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path).to(device)
144
+ if args.use_fp16:
145
+ print('Warnning: using half percision for inferencing!')
146
+ vae.to(dtype=torch.float16)
147
+ model.to(dtype=torch.float16)
148
+ text_encoder.to(dtype=torch.float16)
149
+ print("model ready!\n", flush=True)
150
+
151
+
152
+ # load protagonist script
153
+ character_places = readprotagonistscript(args.protagonist_file_path)
154
+ print("protagonists ready!", flush=True)
155
+
156
+ # load script
157
+ video_list = readscript(args.script_file_path)
158
+ print("video script ready!", flush=True)
159
+
160
+ # load reference script
161
+ reference_lists = readreferencescript(video_list, character_places, args.reference_file_path)
162
+ print("reference script ready!", flush=True)
163
+
164
+ # load zh script
165
+ zh_video_list = readzhscript(args.zh_script_file_path)
166
+ print("zh script ready!", flush=True)
167
+
168
+ # load time script
169
+ key_list = []
170
+ for key, value in character_places.items():
171
+ key_list.append(key)
172
+ time_list = readtimescript(args.time_file_path)
173
+ print("time script ready!", flush=True)
174
+
175
+
176
+ # generation begin
177
+ sample_list = []
178
+ for i, text_prompt in enumerate(video_list):
179
+ sample_list.append([])
180
+ for time in range(time_list[i]):
181
+ if time == 0:
182
+ print('Generating the ({}) prompt'.format(text_prompt), flush=True)
183
+ if reference_lists[i][0] == 0 or reference_lists[i][0] > len(key_list):
184
+ pil_image = None
185
+ else:
186
+ pil_image = Image.open(args.reference_image_path[reference_lists[i][0] - 1])
187
+ pil_image.resize((256, 256))
188
+ video_input = torch.zeros([1, 16, 3, args.image_size[0], args.image_size[1]]).to(device)
189
+ mask = mask_generation_before("first0", video_input.shape, video_input.dtype, device) # b,f,c,h,w
190
+ masked_video = video_input * (mask == 0)
191
+ samples = auto_inpainting(args,
192
+ video_input,
193
+ masked_video,
194
+ mask,
195
+ text_prompt,
196
+ pil_image,
197
+ vae,
198
+ text_encoder,
199
+ image_encoder,
200
+ diffusion,
201
+ model,
202
+ device,
203
+ )
204
+ sample_list[i].append(samples)
205
+ else:
206
+ if sum(video.shape[0] for video in sample_list[i]) / args.fps >= time_list[i]:
207
+ break
208
+ print('Generating the ({}) prompt'.format(text_prompt), flush=True)
209
+ if reference_lists[i][0] == 0 or reference_lists[i][0] > len(key_list):
210
+ pil_image = None
211
+ else:
212
+ pil_image = Image.open(args.reference_image_path[reference_lists[i][0] - 1])
213
+ pil_image.resize((256, 256))
214
+ pre_video = sample_list[i][-1][-args.researve_frame:]
215
+ f, c, h, w = pre_video.shape
216
+ lat_video = torch.zeros(args.num_frames - args.researve_frame, c, h, w).to(device)
217
+ video_input = torch.concat([pre_video, lat_video], dim=0)
218
+ video_input = video_input.to(device).unsqueeze(0)
219
+ mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device)
220
+ masked_video = video_input * (mask == 0)
221
+ video_clip = auto_inpainting(args,
222
+ video_input,
223
+ masked_video,
224
+ mask,
225
+ text_prompt,
226
+ pil_image,
227
+ vae,
228
+ text_encoder,
229
+ image_encoder,
230
+ diffusion,
231
+ model,
232
+ device,
233
+ )
234
+ sample_list[i].append(video_clip[args.researve_frame:])
235
+ print(video_clip[args.researve_frame:].shape)
236
+
237
+ # transition
238
+ if args.video_transition and i != 0:
239
+ video_1 = sample_list[i - 1][-1][-1:]
240
+ video_2 = sample_list[i][0][:1]
241
+ f, c, h, w = video_1.shape
242
+ video_middle = torch.zeros(args.num_frames - 2, c, h, w).to(device)
243
+ video_input = torch.concat([video_1, video_middle, video_2], dim=0)
244
+ video_input = video_input.to(device).unsqueeze(0)
245
+ mask = mask_generation_before("onelast1", video_input.shape, video_input.dtype, device)
246
+ masked_video = masked_video = video_input * (mask == 0)
247
+ video_clip = auto_inpainting(args,
248
+ video_input,
249
+ masked_video,
250
+ mask,
251
+ "smooth transition, slow motion, slow changing.",
252
+ pil_image,
253
+ vae,
254
+ text_encoder,
255
+ image_encoder,
256
+ diffusion,
257
+ model,
258
+ device,
259
+ )
260
+ sample_list[i].insert(0, video_clip[1:-1])
261
+
262
+ # save videos
263
+ samples = torch.concat(sample_list[i], dim=0)
264
+ samples = samples[0: time_list[i] * args.fps]
265
+ if not os.path.exists(args.save_origin_video_path):
266
+ os.makedirs(args.save_origin_video_path)
267
+ video_ = ((samples * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1).contiguous()
268
+ torchvision.io.write_video(args.save_origin_video_path + "/" + f"{i}" + '.mp4', video_, fps=args.fps)
269
+
270
+ # post processing
271
+ fusion(args.save_origin_video_path)
272
+ captioning(args.script_file_path, args.zh_script_file_path, args.save_origin_video_path, args.save_caption_video_path)
273
+ fusion(args.save_caption_video_path)
274
+ make_audio(args.script_file_path, args.save_audio_path)
275
+ merge_video_audio(args.save_caption_video_path, args.save_audio_path, args.save_audio_caption_video_path)
276
+ concatenate_videos(args.save_audio_caption_video_path)
277
+ print('final video save path {}'.format(args.save_audio_caption_video_path))
278
+
279
+
280
+ if __name__ == "__main__":
281
+ parser = argparse.ArgumentParser()
282
+ parser.add_argument("--config", type=str, default="configs/vlog_read_script_sample.yaml")
283
+ args = parser.parse_args()
284
+ omega_conf = OmegaConf.load(args.config)
285
+ save_path = omega_conf.save_path
286
+ save_origin_video_path = os.path.join(save_path, "origin_video")
287
+ save_caption_video_path = os.path.join(save_path.rsplit('/', 1)[0], "caption_video")
288
+ save_audio_path = os.path.join(save_path.rsplit('/', 1)[0], "audio")
289
+ save_audio_caption_video_path = os.path.join(save_path.rsplit('/', 1)[0], "audio_caption_video")
290
+ if omega_conf.sample_num is not None:
291
+ for i in range(omega_conf.sample_num):
292
+ omega_conf.save_origin_video_path = save_origin_video_path + f'-{i}'
293
+ omega_conf.save_caption_video_path = save_caption_video_path + f'-{i}'
294
+ omega_conf.save_audio_path = save_audio_path + f'-{i}'
295
+ omega_conf.save_audio_caption_video_path = save_audio_caption_video_path + f'-{i}'
296
+ omega_conf.seed += i
297
+ main(omega_conf)
298
+ else:
299
+ omega_conf.save_origin_video_path = save_origin_video_path
300
+ omega_conf.save_caption_video_path = save_caption_video_path
301
+ omega_conf.save_audio_path = save_audio_path
302
+ omega_conf.save_audio_caption_video_path = save_audio_caption_video_path
303
+ main(omega_conf)
sample_scripts/vlog_write_script.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ os.environ['CURL_CA_BUNDLE'] = ''
4
+ import argparse
5
+ from omegaconf import OmegaConf
6
+ from diffusers import DiffusionPipeline
7
+ from vlogger.planning_utils.gpt4_utils import (ExtractProtagonist,
8
+ ExtractAProtagonist,
9
+ split_story,
10
+ patch_story_scripts,
11
+ refine_story_scripts,
12
+ protagonist_place_reference1,
13
+ translate_video_script,
14
+ time_scripts,
15
+ )
16
+
17
+
18
+ def main(args):
19
+ story_path = args.story_path
20
+ save_script_path = os.path.join(story_path.rsplit('/', 1)[0], "script")
21
+ if not os.path.exists(save_script_path):
22
+ os.makedirs(save_script_path)
23
+ with open(story_path, "r") as story_file:
24
+ story = story_file.read()
25
+
26
+ # summerize protagonists and places
27
+ protagonists_places_file_path = os.path.join(save_script_path, "protagonists_places.txt")
28
+ if args.only_one_protagonist:
29
+ character_places = ExtractAProtagonist(story, protagonists_places_file_path)
30
+ else:
31
+ character_places = ExtractProtagonist(story, protagonists_places_file_path)
32
+ print("Protagonists and places OK", flush=True)
33
+
34
+ # make script
35
+ script_file_path = os.path.join(save_script_path, "video_prompts.txt")
36
+ video_list = split_story(story, script_file_path)
37
+ video_list = patch_story_scripts(story, video_list, script_file_path)
38
+ video_list = refine_story_scripts(video_list, script_file_path)
39
+ print("Scripts OK", flush=True)
40
+
41
+ # think about the protagonist in each scene
42
+ reference_file_path = os.path.join(save_script_path, "protagonist_place_reference.txt")
43
+ reference_lists = protagonist_place_reference1(video_list, character_places, reference_file_path)
44
+ print("Reference protagonist OK", flush=True)
45
+
46
+ # translate the English script to Chinese
47
+ zh_file_path = os.path.join(save_script_path, "zh_video_prompts.txt")
48
+ zh_video_list = translate_video_script(video_list, zh_file_path)
49
+ print("Translation OK", flush=True)
50
+
51
+ # schedule the time of script
52
+ time_file_path = os.path.join(save_script_path, "time_scripts.txt")
53
+ time_list = time_scripts(video_list, time_file_path)
54
+ print("Time script OK", flush=True)
55
+
56
+ # make reference image
57
+ base = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0",
58
+ torch_dtype=torch.float16,
59
+ variant="fp16",
60
+ use_safetensors=True,
61
+ ).to("cuda")
62
+ refiner = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0",
63
+ text_encoder_2=base.text_encoder_2,
64
+ vae=base.vae,
65
+ torch_dtype=torch.float16,
66
+ use_safetensors=True,
67
+ variant="fp16",
68
+ ).to("cuda")
69
+ ref_dir_path = os.path.join(story_path.rsplit('/', 1)[0], "ref_img")
70
+ if not os.path.exists(ref_dir_path):
71
+ os.makedirs(ref_dir_path)
72
+ for key, value in character_places.items():
73
+ prompt = key + ", " + value
74
+ img_path = os.path.join(ref_dir_path, key + ".jpg")
75
+ image = base(prompt=prompt,
76
+ output_type="latent",
77
+ height=1024,
78
+ width=1024,
79
+ guidance_scale=7
80
+ ).images[0]
81
+ image = refiner(prompt=prompt, image=image[None, :]).images[0]
82
+ image.save(img_path)
83
+ print("Reference image OK", flush=True)
84
+
85
+
86
+ if __name__ == "__main__":
87
+ parser = argparse.ArgumentParser()
88
+ parser.add_argument("--config", type=str, default="configs/vlog_write_script.yaml")
89
+ args = parser.parse_args()
90
+ omega_conf = OmegaConf.load(args.config)
91
+ main(omega_conf)
sample_scripts/with_mask_ref_sample.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Sample new images from a pre-trained DiT.
9
+ """
10
+ import os
11
+ import sys
12
+ import math
13
+ try:
14
+ import utils
15
+ from diffusion import create_diffusion
16
+ except:
17
+ # sys.path.append(os.getcwd())
18
+ sys.path.append(os.path.split(sys.path[0])[0])
19
+ # sys.path[0]
20
+ # os.path.split(sys.path[0])
21
+ import utils
22
+
23
+ from diffusion import create_diffusion
24
+
25
+ import torch
26
+ torch.backends.cuda.matmul.allow_tf32 = True
27
+ torch.backends.cudnn.allow_tf32 = True
28
+ import argparse
29
+ import torchvision
30
+
31
+ from einops import rearrange
32
+ from models import get_models
33
+ from torchvision.utils import save_image
34
+ from diffusers.models import AutoencoderKL
35
+ from models.clip import TextEmbedder
36
+ from omegaconf import OmegaConf
37
+ from PIL import Image
38
+ import numpy as np
39
+ from torchvision import transforms
40
+ sys.path.append("..")
41
+ from datasets import video_transforms
42
+ from utils import mask_generation_before
43
+ from natsort import natsorted
44
+ from diffusers.utils.import_utils import is_xformers_available
45
+ from vlogger.STEB.model_transform import ip_scale_set, ip_transform_model, tca_transform_model
46
+ from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
47
+
48
+ def get_input(args):
49
+ input_path = args.input_path
50
+ transform_video = transforms.Compose([
51
+ video_transforms.ToTensorVideo(), # TCHW
52
+ video_transforms.ResizeVideo((args.image_h, args.image_w)),
53
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
54
+ ])
55
+ if input_path is not None:
56
+ print(f'loading video from {input_path}')
57
+ if os.path.isdir(input_path):
58
+ file_list = os.listdir(input_path)
59
+ video_frames = []
60
+ if args.mask_type.startswith('onelast'):
61
+ num = int(args.mask_type.split('onelast')[-1])
62
+ # get first and last frame
63
+ first_frame_path = os.path.join(input_path, natsorted(file_list)[0])
64
+ last_frame_path = os.path.join(input_path, natsorted(file_list)[-1])
65
+ first_frame = torch.as_tensor(np.array(Image.open(first_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0)
66
+ last_frame = torch.as_tensor(np.array(Image.open(last_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0)
67
+ for i in range(num):
68
+ video_frames.append(first_frame)
69
+ # add zeros to frames
70
+ num_zeros = args.num_frames-2*num
71
+ for i in range(num_zeros):
72
+ zeros = torch.zeros_like(first_frame)
73
+ video_frames.append(zeros)
74
+ for i in range(num):
75
+ video_frames.append(last_frame)
76
+ n = 0
77
+ video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
78
+ video_frames = transform_video(video_frames)
79
+ else:
80
+ for file in file_list:
81
+ if file.endswith('jpg') or file.endswith('png'):
82
+ image = torch.as_tensor(np.array(Image.open(file), dtype=np.uint8, copy=True)).unsqueeze(0)
83
+ video_frames.append(image)
84
+ else:
85
+ continue
86
+ n = 0
87
+ video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
88
+ video_frames = transform_video(video_frames)
89
+ return video_frames, n
90
+ elif os.path.isfile(input_path):
91
+ _, full_file_name = os.path.split(input_path)
92
+ file_name, extention = os.path.splitext(full_file_name)
93
+ if extention == '.jpg' or extention == '.png':
94
+ print("loading the input image")
95
+ video_frames = []
96
+ num = int(args.mask_type.split('first')[-1])
97
+ first_frame = torch.as_tensor(np.array(Image.open(input_path), dtype=np.uint8, copy=True)).unsqueeze(0)
98
+ for i in range(num):
99
+ video_frames.append(first_frame)
100
+ num_zeros = args.num_frames-num
101
+ for i in range(num_zeros):
102
+ zeros = torch.zeros_like(first_frame)
103
+ video_frames.append(zeros)
104
+ n = 0
105
+ video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
106
+ video_frames = transform_video(video_frames)
107
+ return video_frames, n
108
+ else:
109
+ raise TypeError(f'{extention} is not supported !!')
110
+ else:
111
+ raise ValueError('Please check your path input!!')
112
+ else:
113
+ raise ValueError('Need to give a video or some images')
114
+
115
+ def auto_inpainting(args,
116
+ video_input,
117
+ masked_video,
118
+ mask,
119
+ prompt,
120
+ image,
121
+ vae,
122
+ text_encoder,
123
+ image_encoder,
124
+ diffusion,
125
+ model,
126
+ device,
127
+ ):
128
+ image_prompt_embeds = None
129
+ if prompt is None:
130
+ prompt = ""
131
+ if image is not None:
132
+ clip_image = CLIPImageProcessor()(images=image, return_tensors="pt").pixel_values
133
+ clip_image_embeds = image_encoder(clip_image.to(device)).image_embeds
134
+ uncond_clip_image_embeds = torch.zeros_like(clip_image_embeds).to(device)
135
+ image_prompt_embeds = torch.cat([clip_image_embeds, uncond_clip_image_embeds], dim=0)
136
+ image_prompt_embeds = rearrange(image_prompt_embeds, '(b n) c -> b n c', b=2).contiguous()
137
+ model = ip_scale_set(model, args.ref_cfg_scale)
138
+ if args.use_fp16:
139
+ image_prompt_embeds = image_prompt_embeds.to(dtype=torch.float16)
140
+ b, f, c, h, w = video_input.shape
141
+ latent_h = video_input.shape[-2] // 8
142
+ latent_w = video_input.shape[-1] // 8
143
+
144
+ if args.use_fp16:
145
+ z = torch.randn(1, 4, 16, latent_h, latent_w, dtype=torch.float16, device=device) # b,c,f,h,w
146
+ masked_video = masked_video.to(dtype=torch.float16)
147
+ mask = mask.to(dtype=torch.float16)
148
+ else:
149
+ z = torch.randn(1, 4, 16, latent_h, latent_w, device=device) # b,c,f,h,w
150
+
151
+ masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous()
152
+ masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215)
153
+ masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous()
154
+ mask = torch.nn.functional.interpolate(mask[:,:,0,:], size=(latent_h, latent_w)).unsqueeze(1)
155
+ masked_video = torch.cat([masked_video] * 2)
156
+ mask = torch.cat([mask] * 2)
157
+ z = torch.cat([z] * 2)
158
+ prompt_all = [prompt] + [args.negative_prompt]
159
+
160
+ text_prompt = text_encoder(text_prompts=prompt_all, train=False)
161
+ model_kwargs = dict(encoder_hidden_states=text_prompt,
162
+ class_labels=None,
163
+ cfg_scale=args.cfg_scale,
164
+ use_fp16=args.use_fp16,
165
+ ip_hidden_states=image_prompt_embeds)
166
+
167
+ # Sample images:
168
+ samples = diffusion.ddim_sample_loop(
169
+ model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device, \
170
+ mask=mask, x_start=masked_video, use_concat=True
171
+ )
172
+ samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32]
173
+ if args.use_fp16:
174
+ samples = samples.to(dtype=torch.float16)
175
+
176
+ video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32]
177
+ video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256]
178
+ return video_clip
179
+
180
+ def main(args):
181
+ # Setup PyTorch:
182
+ if args.seed:
183
+ torch.manual_seed(args.seed)
184
+ torch.set_grad_enabled(False)
185
+ device = "cuda" if torch.cuda.is_available() else "cpu"
186
+ # device = "cpu"
187
+
188
+ if args.ckpt is None:
189
+ raise ValueError("Please specify a checkpoint path using --ckpt <path>")
190
+
191
+ # Load model:
192
+ latent_h = args.image_size[0] // 8
193
+ latent_w = args.image_size[1] // 8
194
+ args.image_h = args.image_size[0]
195
+ args.image_w = args.image_size[1]
196
+ args.latent_h = latent_h
197
+ args.latent_w = latent_w
198
+ print('loading model')
199
+ model = get_models(args).to(device)
200
+ model = tca_transform_model(model).to(device)
201
+ model = ip_transform_model(model).to(device)
202
+
203
+ if args.enable_xformers_memory_efficient_attention:
204
+ if is_xformers_available():
205
+ model.enable_xformers_memory_efficient_attention()
206
+ else:
207
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
208
+
209
+ # load model
210
+ ckpt_path = args.ckpt
211
+ state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['ema']
212
+ model_dict = model.state_dict()
213
+ pretrained_dict = {}
214
+ for k, v in state_dict.items():
215
+ if k in model_dict:
216
+ pretrained_dict[k] = v
217
+ model_dict.update(pretrained_dict)
218
+ model.load_state_dict(model_dict)
219
+
220
+ model.eval()
221
+ pretrained_model_path = args.pretrained_model_path
222
+ diffusion = create_diffusion(str(args.num_sampling_steps))
223
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
224
+ text_encoder = TextEmbedder(pretrained_model_path).to(device)
225
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path).to(device)
226
+ if args.use_fp16:
227
+ print('Warnning: using half percision for inferencing!')
228
+ vae.to(dtype=torch.float16)
229
+ model.to(dtype=torch.float16)
230
+ text_encoder.to(dtype=torch.float16)
231
+
232
+ # prompt:
233
+ prompt = args.text_prompt
234
+ if prompt ==[]:
235
+ prompt = args.input_path.split('/')[-1].split('.')[0].replace('_', ' ')
236
+ else:
237
+ prompt = prompt[0]
238
+ prompt_base = prompt.replace(' ','_')
239
+ prompt = prompt + args.additional_prompt
240
+
241
+ if not os.path.exists(os.path.join(args.save_path)):
242
+ os.makedirs(os.path.join(args.save_path))
243
+ video_input, researve_frames = get_input(args) # f,c,h,w
244
+ video_input = video_input.to(device).unsqueeze(0) # b,f,c,h,w
245
+ mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device) # b,f,c,h,w
246
+ masked_video = video_input * (mask == 0)
247
+
248
+ pil_image = Image.open(args.ref_path)
249
+ pil_image.resize((256, 256))
250
+
251
+ video_clip = auto_inpainting(args,
252
+ video_input,
253
+ masked_video,
254
+ mask,
255
+ prompt,
256
+ pil_image,
257
+ vae,
258
+ text_encoder,
259
+ image_encoder,
260
+ diffusion,
261
+ model,
262
+ device,
263
+ )
264
+ video_ = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
265
+ save_video_path = os.path.join(args.save_path, prompt_base+ '.mp4')
266
+ torchvision.io.write_video(save_video_path, video_, fps=8)
267
+ print(f'save in {save_video_path}')
268
+
269
+
270
+ if __name__ == "__main__":
271
+ parser = argparse.ArgumentParser()
272
+ parser.add_argument("--config", type=str, default="configs/with_mask_ref_sample.yaml")
273
+ args = parser.parse_args()
274
+ omega_conf = OmegaConf.load(args.config)
275
+ main(omega_conf)
sample_scripts/with_mask_sample.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Sample new images from a pre-trained DiT.
9
+ """
10
+ import os
11
+ import sys
12
+ import math
13
+ try:
14
+ import utils
15
+ from diffusion import create_diffusion
16
+ except:
17
+ # sys.path.append(os.getcwd())
18
+ sys.path.append(os.path.split(sys.path[0])[0])
19
+ # sys.path[0]
20
+ # os.path.split(sys.path[0])
21
+ import utils
22
+
23
+ from diffusion import create_diffusion
24
+
25
+ import torch
26
+ torch.backends.cuda.matmul.allow_tf32 = True
27
+ torch.backends.cudnn.allow_tf32 = True
28
+ import argparse
29
+ import torchvision
30
+
31
+ from einops import rearrange
32
+ from models import get_models
33
+ from torchvision.utils import save_image
34
+ from diffusers.models import AutoencoderKL
35
+ from models.clip import TextEmbedder
36
+ from omegaconf import OmegaConf
37
+ from PIL import Image
38
+ import numpy as np
39
+ from torchvision import transforms
40
+ sys.path.append("..")
41
+ from datasets import video_transforms
42
+ from utils import mask_generation_before
43
+ from natsort import natsorted
44
+ from diffusers.utils.import_utils import is_xformers_available
45
+ from vlogger.STEB.model_transform import tca_transform_model
46
+
47
+ def get_input(args):
48
+ input_path = args.input_path
49
+ transform_video = transforms.Compose([
50
+ video_transforms.ToTensorVideo(), # TCHW
51
+ video_transforms.ResizeVideo((args.image_h, args.image_w)),
52
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
53
+ ])
54
+ if input_path is not None:
55
+ print(f'loading video from {input_path}')
56
+ if os.path.isdir(input_path):
57
+ file_list = os.listdir(input_path)
58
+ video_frames = []
59
+ if args.mask_type.startswith('onelast'):
60
+ num = int(args.mask_type.split('onelast')[-1])
61
+ # get first and last frame
62
+ first_frame_path = os.path.join(input_path, natsorted(file_list)[0])
63
+ last_frame_path = os.path.join(input_path, natsorted(file_list)[-1])
64
+ first_frame = torch.as_tensor(np.array(Image.open(first_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0)
65
+ last_frame = torch.as_tensor(np.array(Image.open(last_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0)
66
+ for i in range(num):
67
+ video_frames.append(first_frame)
68
+ # add zeros to frames
69
+ num_zeros = args.num_frames-2*num
70
+ for i in range(num_zeros):
71
+ zeros = torch.zeros_like(first_frame)
72
+ video_frames.append(zeros)
73
+ for i in range(num):
74
+ video_frames.append(last_frame)
75
+ n = 0
76
+ video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
77
+ video_frames = transform_video(video_frames)
78
+ else:
79
+ for file in file_list:
80
+ if file.endswith('jpg') or file.endswith('png'):
81
+ image = torch.as_tensor(np.array(Image.open(file), dtype=np.uint8, copy=True)).unsqueeze(0)
82
+ video_frames.append(image)
83
+ else:
84
+ continue
85
+ n = 0
86
+ video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
87
+ video_frames = transform_video(video_frames)
88
+ return video_frames, n
89
+ elif os.path.isfile(input_path):
90
+ _, full_file_name = os.path.split(input_path)
91
+ file_name, extention = os.path.splitext(full_file_name)
92
+ if extention == '.jpg' or extention == '.png':
93
+ print("loading the input image")
94
+ video_frames = []
95
+ num = int(args.mask_type.split('first')[-1])
96
+ first_frame = torch.as_tensor(np.array(Image.open(input_path), dtype=np.uint8, copy=True)).unsqueeze(0)
97
+ for i in range(num):
98
+ video_frames.append(first_frame)
99
+ num_zeros = args.num_frames-num
100
+ for i in range(num_zeros):
101
+ zeros = torch.zeros_like(first_frame)
102
+ video_frames.append(zeros)
103
+ n = 0
104
+ video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
105
+ video_frames = transform_video(video_frames)
106
+ return video_frames, n
107
+ else:
108
+ raise TypeError(f'{extention} is not supported !!')
109
+ else:
110
+ raise ValueError('Please check your path input!!')
111
+ else:
112
+ raise ValueError('Need to give a video or some images')
113
+
114
+ def auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,):
115
+ b,f,c,h,w=video_input.shape
116
+ latent_h = args.image_size[0] // 8
117
+ latent_w = args.image_size[1] // 8
118
+
119
+ # prepare inputs
120
+ if args.use_fp16:
121
+ z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, dtype=torch.float16, device=device) # b,c,f,h,w
122
+ masked_video = masked_video.to(dtype=torch.float16)
123
+ mask = mask.to(dtype=torch.float16)
124
+ else:
125
+ z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, device=device) # b,c,f,h,w
126
+
127
+
128
+ masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous()
129
+ masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215)
130
+ masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous()
131
+ mask = torch.nn.functional.interpolate(mask[:,:,0,:], size=(latent_h, latent_w)).unsqueeze(1)
132
+
133
+ # classifier_free_guidance
134
+ if args.do_classifier_free_guidance:
135
+ masked_video = torch.cat([masked_video] * 2)
136
+ mask = torch.cat([mask] * 2)
137
+ z = torch.cat([z] * 2)
138
+ prompt_all = [prompt] + [args.negative_prompt]
139
+
140
+ else:
141
+ masked_video = masked_video
142
+ mask = mask
143
+ z = z
144
+ prompt_all = [prompt]
145
+
146
+ text_prompt = text_encoder(text_prompts=prompt_all, train=False)
147
+ model_kwargs = dict(encoder_hidden_states=text_prompt,
148
+ class_labels=None,
149
+ cfg_scale=args.cfg_scale,
150
+ use_fp16=args.use_fp16,) # tav unet
151
+
152
+ # Sample video:
153
+ if args.sample_method == 'ddim':
154
+ samples = diffusion.ddim_sample_loop(
155
+ model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device, \
156
+ mask=mask, x_start=masked_video, use_concat=args.use_mask
157
+ )
158
+ elif args.sample_method == 'ddpm':
159
+ samples = diffusion.p_sample_loop(
160
+ model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device, \
161
+ mask=mask, x_start=masked_video, use_concat=args.use_mask
162
+ )
163
+ samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32]
164
+ if args.use_fp16:
165
+ samples = samples.to(dtype=torch.float16)
166
+
167
+ video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32]
168
+ video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256]
169
+ return video_clip
170
+
171
+ def main(args):
172
+ # Setup PyTorch:
173
+ if args.seed:
174
+ torch.manual_seed(args.seed)
175
+ torch.set_grad_enabled(False)
176
+ device = "cuda" if torch.cuda.is_available() else "cpu"
177
+ # device = "cpu"
178
+
179
+ if args.ckpt is None:
180
+ raise ValueError("Please specify a checkpoint path using --ckpt <path>")
181
+
182
+ # Load model:
183
+ latent_h = args.image_size[0] // 8
184
+ latent_w = args.image_size[1] // 8
185
+ args.image_h = args.image_size[0]
186
+ args.image_w = args.image_size[1]
187
+ args.latent_h = latent_h
188
+ args.latent_w = latent_w
189
+ print('loading model')
190
+ model = get_models(args).to(device)
191
+ model = tca_transform_model(model).to(device)
192
+
193
+ if args.enable_xformers_memory_efficient_attention:
194
+ if is_xformers_available():
195
+ model.enable_xformers_memory_efficient_attention()
196
+ else:
197
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
198
+
199
+ # load model
200
+ ckpt_path = args.ckpt
201
+ state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['ema']
202
+ model_dict = model.state_dict()
203
+ pretrained_dict = {}
204
+ for k, v in state_dict.items():
205
+ if k in model_dict:
206
+ pretrained_dict[k] = v
207
+ model_dict.update(pretrained_dict)
208
+ model.load_state_dict(model_dict)
209
+
210
+ model.eval()
211
+ pretrained_model_path = args.pretrained_model_path
212
+ diffusion = create_diffusion(str(args.num_sampling_steps))
213
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
214
+ text_encoder = TextEmbedder(pretrained_model_path).to(device)
215
+ if args.use_fp16:
216
+ print('Warnning: using half percision for inferencing!')
217
+ vae.to(dtype=torch.float16)
218
+ model.to(dtype=torch.float16)
219
+ text_encoder.to(dtype=torch.float16)
220
+
221
+ # prompt:
222
+ prompt = args.text_prompt
223
+ if prompt ==[]:
224
+ prompt = args.input_path.split('/')[-1].split('.')[0].replace('_', ' ')
225
+ else:
226
+ prompt = prompt[0]
227
+ prompt_base = prompt.replace(' ','_')
228
+ prompt = prompt + args.additional_prompt
229
+
230
+ if not os.path.exists(os.path.join(args.save_path)):
231
+ os.makedirs(os.path.join(args.save_path))
232
+ video_input, researve_frames = get_input(args) # f,c,h,w
233
+ video_input = video_input.to(device).unsqueeze(0) # b,f,c,h,w
234
+ mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device) # b,f,c,h,w
235
+ masked_video = video_input * (mask == 0)
236
+
237
+ video_clip = auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,)
238
+ video_ = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
239
+ save_video_path = os.path.join(args.save_path, prompt_base+ '.mp4')
240
+ torchvision.io.write_video(save_video_path, video_, fps=8)
241
+ print(f'save in {save_video_path}')
242
+
243
+
244
+ if __name__ == "__main__":
245
+ parser = argparse.ArgumentParser()
246
+ parser.add_argument("--config", type=str, default="configs/with_mask_sample.yaml")
247
+ args = parser.parse_args()
248
+ omega_conf = OmegaConf.load(args.config)
249
+ main(omega_conf)