zxl committed
Commit 07c6a04 · 1 Parent(s): bd6e6ad

first commit

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50)
  1. CONTRIBUTING.md +37 -0
  2. LICENSE +0 -0
  3. README.md +16 -8
  4. app.py +508 -0
  5. docs/dsp.md +25 -0
  6. docs/pab.md +121 -0
  7. eval/pab/commom_metrics/README.md +6 -0
  8. eval/pab/commom_metrics/__init__.py +0 -0
  9. eval/pab/commom_metrics/calculate_lpips.py +97 -0
  10. eval/pab/commom_metrics/calculate_psnr.py +90 -0
  11. eval/pab/commom_metrics/calculate_ssim.py +116 -0
  12. eval/pab/commom_metrics/eval.py +160 -0
  13. eval/pab/experiments/__init__.py +0 -0
  14. eval/pab/experiments/attention_ablation.py +60 -0
  15. eval/pab/experiments/components_ablation.py +46 -0
  16. eval/pab/experiments/latte.py +57 -0
  17. eval/pab/experiments/opensora.py +44 -0
  18. eval/pab/experiments/opensora_plan.py +57 -0
  19. eval/pab/experiments/utils.py +22 -0
  20. eval/pab/vbench/VBench_full_info.json +0 -0
  21. eval/pab/vbench/cal_vbench.py +154 -0
  22. eval/pab/vbench/run_vbench.py +52 -0
  23. examples/cogvideo/sample.py +14 -0
  24. examples/latte/sample.py +24 -0
  25. examples/open_sora/sample.py +24 -0
  26. examples/open_sora_plan/sample.py +24 -0
  27. requirements.txt +25 -0
  28. setup.py +55 -0
  29. tests/__init__.py +0 -0
  30. videosys/__init__.py +19 -0
  31. videosys/core/__init__.py +0 -0
  32. videosys/core/comm.py +420 -0
  33. videosys/core/engine.py +132 -0
  34. videosys/core/mp_utils.py +270 -0
  35. videosys/core/pab_mgr.py +364 -0
  36. videosys/core/parallel_mgr.py +119 -0
  37. videosys/core/pipeline.py +34 -0
  38. videosys/core/shardformer/__init__.py +0 -0
  39. videosys/core/shardformer/t5/__init__.py +0 -0
  40. videosys/core/shardformer/t5/modeling.py +39 -0
  41. videosys/core/shardformer/t5/policy.py +68 -0
  42. videosys/datasets/dataloader.py +94 -0
  43. videosys/datasets/image_transform.py +42 -0
  44. videosys/datasets/video_transform.py +441 -0
  45. videosys/diffusion/__init__.py +41 -0
  46. videosys/diffusion/diffusion_utils.py +79 -0
  47. videosys/diffusion/gaussian_diffusion.py +829 -0
  48. videosys/diffusion/respace.py +119 -0
  49. videosys/diffusion/timestep_sampler.py +143 -0
  50. videosys/models/__init__.py +0 -0
CONTRIBUTING.md ADDED
@@ -0,0 +1,37 @@
+ ## Coding Standards
+
+ ### Unit Tests
+ We use [PyTest](https://docs.pytest.org/en/latest/) to execute tests. You can install pytest with `pip install pytest`. Because some of the tests initialize the distributed backend, GPUs are needed to run them.
+
+ To set up the environment for unit testing, first change your current directory to the root directory of your local ColossalAI repository, then run
+ ```bash
+ pip install -r requirements/requirements-test.txt
+ ```
+ If you encounter an error saying "Could not find a version that satisfies the requirement fbgemm-gpu==0.2.0", please downgrade your Python version to 3.8 or 3.9 and try again.
+
+ If you only want to run CPU tests, you can run
+
+ ```bash
+ pytest -m cpu tests/
+ ```
+
+ If you have 8 GPUs on your machine, you can run the full test suite
+
+ ```bash
+ pytest tests/
+ ```
+
+ If you do not have 8 GPUs on your machine, do not worry. Unit tests run automatically when you open a pull request against the main branch.
+
+
+ ### Code Style
+
+ We run static checks when you commit your code changes. Please make sure your changes pass all the tests and that the coding style meets our requirements. We use a pre-commit hook to keep the code aligned with the coding standard. To set up code style checking, follow the steps below.
+
+ ```shell
+ # these commands are executed under the Colossal-AI directory
+ pip install pre-commit
+ pre-commit install
+ ```
+
+ Code format checking will be automatically executed when you commit your changes.
LICENSE ADDED
The diff for this file is too large to render. See raw diff
 
README.md CHANGED
@@ -1,12 +1,20 @@
  ---
- title: Demo
- emoji: 🔥
- colorFrom: indigo
- colorTo: yellow
+ title: VideoSys-CogVideoX
+ emoji: 🎥
+ colorFrom: yellow
+ colorTo: green
  sdk: gradio
  sdk_version: 4.42.0
+ suggested_hardware: a10g-large
+ suggested_storage: large
+ app_port: 7860
  app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ models:
+ - THUDM/CogVideoX-2b
+ tags:
+ - cogvideox
+ - video-generation
+ - thudm
+ short_description: Text-to-Video
+ disable_embedding: false
+ ---
app.py ADDED
@@ -0,0 +1,508 @@
+ # # import gradio as gr
+ # # from videosys import CogVideoConfig, VideoSysEngine
+ # # import tempfile
+ # # import os
+ # # import logging
+ # # import uuid
+
+ # # logging.basicConfig(level=logging.INFO)
+ # # logger = logging.getLogger(__name__)
+
+ # # config = CogVideoConfig(world_size=1)
+ # # engine = VideoSysEngine(config)
+
+ # # def generate_video(prompt):
+ # #     try:
+ # #         video = engine.generate(prompt).video[0]
+
+ # #         # Use a temporary file and a unique identifier
+ # #         with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
+ # #             temp_filename = temp_file.name
+ # #             unique_filename = f"{uuid.uuid4().hex}.mp4"
+ # #             output_path = os.path.join(tempfile.gettempdir(), unique_filename)
+
+ # #             engine.save_video(video, output_path)
+
+ # #         return output_path
+ # #     except Exception as e:
+ # #         logger.error(f"An error occurred: {str(e)}")
+ # #         return None  # Return None instead of an error message
+
+ # # iface = gr.Interface(
+ # #     fn=generate_video,
+ # #     inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
+ # #     outputs=gr.Video(label="Generated Video"),
+ # #     title="CogVideoX-2b: Text-to-Video Generation",
+ # #     description="Enter a text prompt to generate a video using CogVideoX-2b."
+ # # )
+
+ # # iface.launch()
+
+
+ # from videosys import CogVideoConfig, VideoSysEngine
+ # from videosys.models.cogvideo.pipeline import CogVideoPABConfig
+ # import os
+
+ # import gradio as gr
+ # import numpy as np
+ # import torch
+ # from openai import OpenAI
+ # from time import time
+ # import tempfile
+ # import uuid
+ # import logging
+
+ # logging.basicConfig(level=logging.INFO)
+ # logger = logging.getLogger(__name__)
+
+ # dtype = torch.bfloat16
+ # sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
+
+ # For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
+ # There are a few rules to follow:
+
+ # You will only ever output a single video description per user request.
+
+ # When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
+ # Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
+
+ # Video descriptions must have the same num of words as examples below. Extra words will be ignored.
+ # """
+
+ # def convert_prompt(prompt: str, retry_times: int = 3) -> str:
+ #     if not os.environ.get("OPENAI_API_KEY"):
+ #         return prompt
+ #     client = OpenAI()
+ #     text = prompt.strip()
+
+ #     for i in range(retry_times):
+ #         response = client.chat.completions.create(
+ #             messages=[
+ #                 {"role": "system", "content": sys_prompt},
+ #                 {
+ #                     "role": "user",
+ #                     "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"',
+ #                 },
+ #                 {
+ #                     "role": "assistant",
+ #                     "content": "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.",
+ #                 },
+ #                 {
+ #                     "role": "user",
+ #                     "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"',
+ #                 },
+ #                 {
+ #                     "role": "assistant",
+ #                     "content": "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.",
+ #                 },
+ #                 {
+ #                     "role": "user",
+ #                     "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"',
+ #                 },
+ #                 {
+ #                     "role": "assistant",
+ #                     "content": "A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.",
+ #                 },
+ #                 {
+ #                     "role": "user",
+ #                     "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
+ #                 },
+ #             ],
+ #             model="glm-4-0520",
+ #             temperature=0.01,
+ #             top_p=0.7,
+ #             stream=False,
+ #             max_tokens=250,
+ #         )
+ #         if response.choices:
+ #             return response.choices[0].message.content
+ #     return prompt
+
+ # def load_model(enable_video_sys=False, pab_threshold=[100, 850], pab_gap=2):
+ #     pab_config = CogVideoPABConfig(full_threshold=pab_threshold, full_gap=pab_gap)
+ #     config = CogVideoConfig(world_size=1, enable_pab=enable_video_sys, pab_config=pab_config)
+ #     engine = VideoSysEngine(config)
+ #     return engine
+
+
+
+ # def generate(engine, prompt, num_inference_steps=50, guidance_scale=6.0):
+ #     try:
+ #         video = engine.generate(prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).video[0]
+
+ #         with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
+ #             temp_file.name
+ #             unique_filename = f"{uuid.uuid4().hex}.mp4"
+ #             output_path = os.path.join(tempfile.gettempdir(), unique_filename)
+
+ #             engine.save_video(video, output_path)
+ #         return output_path
+ #     except Exception as e:
+ #         logger.error(f"An error occurred: {str(e)}")
+ #         return None
+
+
+
+ # with gr.Blocks() as demo:
+ #     gr.Markdown("""
+ #     <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
+ #         VideoSys Huggingface Space🤗
+ #     </div>
+ #     <div style="text-align: center;">
+ #         <a href="https://github.com/NUS-HPC-AI-Lab/VideoSys">🌐 Github</a>
+ #     </div>
+
+ #     <div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
+ #         ⚠️ This demo is for academic research and experiential use only.
+ #         Users should strictly adhere to local laws and ethics.
+ #     </div>
+ #     <div style="text-align: center; font-size: 15px; font-weight: bold; color: magenta; margin-bottom: 20px;">
+ #         💡 This demo only demonstrates single-device inference. To experience the full power of VideoSys, please deploy it with multiple devices.
+ #     </div>
+ #     """)
+ #     with gr.Row():
+ #         with gr.Column():
+ #             prompt = gr.Textbox(label="Prompt (Less than 200 Words)", value="a bear hunting for prey", lines=5)
+ #             with gr.Row():
+ #                 gr.Markdown(
+ #                     "✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one."
+ #                 )
+ #                 enhance_button = gr.Button("✨ Enhance Prompt(Optional)")
+
+ #             with gr.Column():
+ #                 gr.Markdown(
+ #                     "**Optional Parameters** (default values are recommended)<br>"
+ #                     "Turn Inference Steps larger if you want more detailed video, but it will be slower.<br>"
+ #                     "50 steps are recommended for most cases. will cause 120 seconds for inference.<br>"
+ #                 )
+ #                 with gr.Row():
+ #                     num_inference_steps = gr.Number(label="Inference Steps", value=50)
+ #                     guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
+ #                     pab_gap = gr.Number(label="PAB Gap", value=2, precision=0)
+ #                     pab_threshold = gr.Textbox(label="PAB Threshold", value="100,850", lines=1)
+ #                 with gr.Row():
+ #                     generate_button = gr.Button("🎬 Generate Video")
+ #                     generate_button_vs = gr.Button("⚡️ Generate Video with VideoSys (Faster)")
+
+ #         with gr.Column():
+ #             with gr.Row():
+ #                 video_output = gr.Video(label="CogVideoX", width=720, height=480)
+ #             with gr.Row():
+ #                 download_video_button = gr.File(label="📥 Download Video", visible=False)
+ #                 elapsed_time = gr.Textbox(label="Elapsed Time", value="0s", visible=False)
+ #             with gr.Row():
+ #                 video_output_vs = gr.Video(label="CogVideoX with VideoSys", width=720, height=480)
+ #             with gr.Row():
+ #                 download_video_button_vs = gr.File(label="📥 Download Video", visible=False)
+ #                 elapsed_time_vs = gr.Textbox(label="Elapsed Time", value="0s", visible=False)
+
+ #     def generate_vanilla(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
+ #         # tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
+ #         engine = load_model()
+ #         t = time()
+ #         video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
+ #         elapsed_time = time() - t
+ #         video_update = gr.update(visible=True, value=video_path)
+ #         elapsed_time = gr.update(visible=True, value=f"{elapsed_time:.2f}s")
+
+ #         return video_path, video_update, elapsed_time
+
+ #     def generate_vs(prompt, num_inference_steps, guidance_scale, threshold, gap, progress=gr.Progress(track_tqdm=True)):
+ #         # tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
+ #         threshold = [int(i) for i in threshold.split(",")]
+ #         gap = int(gap)
+ #         engine = load_model(enable_video_sys=True, pab_threshold=threshold, pab_gap=gap)
+ #         t = time()
+ #         video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
+ #         elapsed_time = time() - t
+ #         video_update = gr.update(visible=True, value=video_path)
+ #         elapsed_time = gr.update(visible=True, value=f"{elapsed_time:.2f}s")
+
+ #         return video_path, video_update, elapsed_time
+
+
+ #     def enhance_prompt_func(prompt):
+ #         return convert_prompt(prompt, retry_times=1)
+
+ #     generate_button.click(
+ #         generate_vanilla,
+ #         inputs=[prompt, num_inference_steps, guidance_scale],
+ #         outputs=[video_output, download_video_button, elapsed_time],
+ #     )
+
+ #     generate_button_vs.click(
+ #         generate_vs,
+ #         inputs=[prompt, num_inference_steps, guidance_scale, pab_threshold, pab_gap],
+ #         outputs=[video_output_vs, download_video_button_vs, elapsed_time_vs],
+ #     )
+
+ #     enhance_button.click(enhance_prompt_func, inputs=[prompt], outputs=[prompt])
+
+ # if __name__ == "__main__":
+ #     demo.launch()
+
+
+
+ import gradio as gr
+ from videosys import CogVideoConfig, VideoSysEngine
+ from videosys.models.cogvideo.pipeline import CogVideoPABConfig
+ import os
+ import numpy as np
+ import torch
+ from openai import OpenAI
+ from time import time
+ import tempfile
+ import uuid
+ import logging
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ dtype = torch.bfloat16
+ sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
+
+ For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
+ There are a few rules to follow:
+
+ You will only ever output a single video description per user request.
+
+ When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
+ Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
+
+ Video descriptions must have the same num of words as examples below. Extra words will be ignored.
+ """
+
+ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
+     if not os.environ.get("OPENAI_API_KEY"):
+         return prompt
+     client = OpenAI()
+     text = prompt.strip()
+
+     for i in range(retry_times):
+         response = client.chat.completions.create(
+             messages=[
+                 {"role": "system", "content": sys_prompt},
+                 {
+                     "role": "user",
+                     "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"',
+                 },
+                 {
+                     "role": "assistant",
+                     "content": "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.",
+                 },
+                 {
+                     "role": "user",
+                     "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"',
+                 },
+                 {
+                     "role": "assistant",
+                     "content": "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.",
+                 },
+                 {
+                     "role": "user",
+                     "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"',
+                 },
+                 {
+                     "role": "assistant",
+                     "content": "A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.",
+                 },
+                 {
+                     "role": "user",
+                     "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
+                 },
+             ],
+             model="glm-4-0520",
+             temperature=0.01,
+             top_p=0.7,
+             stream=False,
+             max_tokens=250,
+         )
+         if response.choices:
+             return response.choices[0].message.content
+     return prompt
+
+ def load_model(enable_video_sys=False, pab_threshold=[100, 850], pab_gap=2):
+     pab_config = CogVideoPABConfig(full_threshold=pab_threshold, full_gap=pab_gap)
+     config = CogVideoConfig(world_size=1, enable_pab=enable_video_sys, pab_config=pab_config)
+     engine = VideoSysEngine(config)
+     return engine
+
+ def generate(engine, prompt, num_inference_steps=50, guidance_scale=6.0):
+     try:
+         video = engine.generate(prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).video[0]
+
+         with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
+             temp_file.name
+             unique_filename = f"{uuid.uuid4().hex}.mp4"
+             output_path = os.path.join(tempfile.gettempdir(), unique_filename)
+
+             engine.save_video(video, output_path)
+         return output_path
+     except Exception as e:
+         logger.error(f"An error occurred: {str(e)}")
+         return None
+
+ css = """
+ body {
+     font-family: Arial, sans-serif;
+     line-height: 1.6;
+     color: #333;
+     max-width: 1200px;
+     margin: 0 auto;
+     padding: 20px;
+ }
+
+ .container {
+     display: flex;
+     flex-direction: column;
+     gap: 20px;
+ }
+
+ .row {
+     display: flex;
+     flex-wrap: wrap;
+     gap: 20px;
+ }
+
+ .column {
+     flex: 1;
+     min-width: 0;
+ }
+
+ .textbox, .number-input, button {
+     width: 100%;
+     padding: 10px;
+     margin-bottom: 10px;
+     border: 1px solid #ddd;
+     border-radius: 4px;
+ }
+
+ button {
+     background-color: #4CAF50;
+     color: white;
+     border: none;
+     cursor: pointer;
+     transition: background-color 0.3s;
+ }
+
+ button:hover {
+     background-color: #45a049;
+ }
+
+ .video-output {
+     width: 100%;
+     max-width: 720px;
+     height: auto;
+     margin: 0 auto;
+ }
+
+ @media (max-width: 768px) {
+     .row {
+         flex-direction: column;
+     }
+
+     .column {
+         width: 100%;
+     }
+
+     .video-output {
+         width: 100%;
+         height: auto;
+     }
+ }
+ """
+
+ with gr.Blocks(css=css) as demo:
+     gr.HTML("""
+     <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
+         VideoSys Huggingface Space🤗
+     </div>
+     <div style="text-align: center;">
+         <a href="https://github.com/NUS-HPC-AI-Lab/VideoSys">🌐 Github</a>
+     </div>
+     <div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
+         ⚠️ This demo is for academic research and experiential use only.
+         Users should strictly adhere to local laws and ethics.
+     </div>
+     <div style="text-align: center; font-size: 15px; font-weight: bold; color: magenta; margin-bottom: 20px;">
+         💡 This demo only demonstrates single-device inference. To experience the full power of VideoSys, please deploy it with multiple devices.
+     </div>
+     """)
+
+     with gr.Row():
+         with gr.Column():
+             prompt = gr.Textbox(label="Prompt (Less than 200 Words)", value="a bear hunting for prey", lines=5)
+             with gr.Row():
+                 gr.Markdown(
+                     "✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one."
+                 )
+                 enhance_button = gr.Button("✨ Enhance Prompt(Optional)")
+
+             with gr.Column():
+                 gr.Markdown(
+                     "**Optional Parameters** (default values are recommended)<br>"
+                     "Turn Inference Steps larger if you want more detailed video, but it will be slower.<br>"
+                     "50 steps are recommended for most cases. will cause 120 seconds for inference.<br>"
+                 )
+                 with gr.Row():
+                     num_inference_steps = gr.Number(label="Inference Steps", value=50)
+                     guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
+                     pab_gap = gr.Number(label="PAB Gap", value=2, precision=0)
+                     pab_threshold = gr.Textbox(label="PAB Threshold", value="100,850", lines=1)
+                 with gr.Row():
+                     generate_button = gr.Button("🎬 Generate Video")
+                     generate_button_vs = gr.Button("⚡️ Generate Video with VideoSys (Faster)")
+
+         with gr.Column():
+             with gr.Row():
+                 video_output = gr.Video(label="CogVideoX", width=720, height=480)
+             with gr.Row():
+                 download_video_button = gr.File(label="📥 Download Video", visible=False)
+                 elapsed_time = gr.Textbox(label="Elapsed Time", value="0s", visible=False)
+             with gr.Row():
+                 video_output_vs = gr.Video(label="CogVideoX with VideoSys", width=720, height=480)
+             with gr.Row():
+                 download_video_button_vs = gr.File(label="📥 Download Video", visible=False)
+                 elapsed_time_vs = gr.Textbox(label="Elapsed Time", value="0s", visible=False)
+
+     def generate_vanilla(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
+         engine = load_model()
+         t = time()
+         video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
+         elapsed_time = time() - t
+         video_update = gr.update(visible=True, value=video_path)
+         elapsed_time = gr.update(visible=True, value=f"{elapsed_time:.2f}s")
+
+         return video_path, video_update, elapsed_time
+
+     def generate_vs(prompt, num_inference_steps, guidance_scale, threshold, gap, progress=gr.Progress(track_tqdm=True)):
+         threshold = [int(i) for i in threshold.split(",")]
+         gap = int(gap)
+         engine = load_model(enable_video_sys=True, pab_threshold=threshold, pab_gap=gap)
+         t = time()
+         video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
+         elapsed_time = time() - t
+         video_update = gr.update(visible=True, value=video_path)
+         elapsed_time = gr.update(visible=True, value=f"{elapsed_time:.2f}s")
+
+         return video_path, video_update, elapsed_time
+
+     def enhance_prompt_func(prompt):
+         return convert_prompt(prompt, retry_times=1)
+
+     generate_button.click(
+         generate_vanilla,
+         inputs=[prompt, num_inference_steps, guidance_scale],
+         outputs=[video_output, download_video_button, elapsed_time],
+     )
+
+     generate_button_vs.click(
+         generate_vs,
+         inputs=[prompt, num_inference_steps, guidance_scale, pab_threshold, pab_gap],
+         outputs=[video_output_vs, download_video_button_vs, elapsed_time_vs],
+     )
+
+     enhance_button.click(enhance_prompt_func, inputs=[prompt], outputs=[prompt])
+
+ if __name__ == "__main__":
+     demo.launch()
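Two small points in `app.py` may be easier to see in isolation: `generate_vs` parses the "PAB Threshold" text box with a bare `int()` comprehension, and `generate` opens a `NamedTemporaryFile(delete=False)` whose file is never used. The sketch below shows one possible way to handle both. The helper names `parse_pab_threshold` and `make_output_path` are hypothetical and not part of this commit or of VideoSys, and the rule that the threshold must be exactly two ascending integers is an assumption.

```python
import os
import tempfile
import uuid


def parse_pab_threshold(text):
    """Parse a "low,high" string such as "100,850" into [low, high].

    Hypothetical helper: the two-value, ascending-order rule is an assumption.
    """
    parts = [part.strip() for part in text.split(",")]
    if len(parts) != 2:
        raise ValueError(f"expected two comma-separated integers, got {text!r}")
    low, high = (int(part) for part in parts)  # int() raises ValueError on non-numeric input
    if low > high:
        raise ValueError(f"lower bound {low} exceeds upper bound {high}")
    return [low, high]


def make_output_path(suffix=".mp4"):
    """Build a unique path in the system temp directory without creating a file."""
    return os.path.join(tempfile.gettempdir(), f"{uuid.uuid4().hex}{suffix}")
```

With helpers like these, a malformed threshold would surface as a `ValueError` that the UI could report, and no empty temporary file would be left behind on each call.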
docs/dsp.md ADDED
@@ -0,0 +1,25 @@
+ # DSP
+
+ paper: https://arxiv.org/abs/2403.10266
+
+ ![dsp_overview](../assets/figures/dsp_overview.png)
+
+
+ DSP (Dynamic Sequence Parallelism) is a novel, elegant and highly efficient sequence parallelism method for [OpenSora](https://github.com/hpcaitech/Open-Sora), [Latte](https://github.com/Vchitect/Latte) and other multi-dimensional transformer architectures.
+
+ The key idea is to dynamically switch the parallelism dimension according to the current computation stage, leveraging the characteristics of multi-dimensional transformers. Compared with previous methods that split the head and sequence dimensions, it reduces communication cost by at least 75%.
+
+ It achieves **3x** speed for training and **2x** speed for inference in OpenSora compared with state-of-the-art sequence parallelism ([DeepSpeed Ulysses](https://arxiv.org/abs/2309.14509)). For a 10s (80-frame) 512x512 video, the inference latency of OpenSora is:
+
+ | Method | 1xH800 | 8xH800 (DS Ulysses) | 8xH800 (DSP) |
+ | ------ | ------ | ------ | ------ |
+ | Latency(s) | 106 | 45 | 22 |
+
+ The following is DSP's end-to-end throughput for training of OpenSora:
+
+ ![dsp_overview](../assets/figures/dsp_exp.png)
+
+
+ ### Usage
+
+ DSP is currently supported for OpenSora, OpenSoraPlan and Latte. To enable DSP, you only need to launch with multiple GPUs.
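The idea of switching the parallelism dimension between computation stages can be illustrated with a toy, single-process simulation. This is only a sketch under assumptions, not the VideoSys implementation: the functions `shard` and `all_to_all_switch` are hypothetical names, devices are modelled as plain Python lists, and the "all-to-all" is an in-memory regrouping rather than a real collective. It assumes a temporal-attention stage needs every frame of a spatial slice on one device, while a spatial-attention stage needs every position of a frame on one device.

```python
def shard(grid, num_devices, dim):
    """Split a T x S grid (list of rows) into equal shards along dim.

    dim=0 shards the temporal axis (rows), dim=1 shards the spatial axis (columns).
    """
    t, s = len(grid), len(grid[0])
    if dim == 0:
        step = t // num_devices
        return [grid[i * step:(i + 1) * step] for i in range(num_devices)]
    step = s // num_devices
    return [[row[i * step:(i + 1) * step] for row in grid] for i in range(num_devices)]


def all_to_all_switch(spatial_shards):
    """Regroup spatially sharded data into temporally sharded data.

    Input: each device holds all T rows but only S/N columns.
    Output: each device holds T/N rows with all S columns.
    """
    n = len(spatial_shards)
    step = len(spatial_shards[0]) // n
    out = []
    for dst in range(n):
        rows = []
        for r in range(dst * step, (dst + 1) * step):
            full_row = []
            for src in range(n):
                full_row.extend(spatial_shards[src][r])  # piece received from device src
            rows.append(full_row)
        out.append(rows)
    return out


T, S, N = 4, 6, 2  # frames, spatial positions, simulated devices
grid = [[t * S + s for s in range(S)] for t in range(T)]

# Stage 1: shard the spatial axis, so every device sees all frames of its slice.
spatial_shards = shard(grid, N, dim=1)
# Stage 2: one regrouping step, so every device sees complete frames.
temporal_shards = all_to_all_switch(spatial_shards)
```

After the switch, `temporal_shards` matches what sharding the original grid along the temporal axis would have produced, which is the property a real all-to-all would need to preserve.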
docs/pab.md ADDED
@@ -0,0 +1,121 @@
1
+ # Pyramid Attention Broadcast(PAB)
2
+
3
+ [[paper](https://arxiv.org/abs/2408.12588)][[blog](https://arxiv.org/abs/2403.10266)]
4
+
5
+ Pyramid Attention Broadcast(PAB)(#pyramid-attention-broadcastpab)
6
+ - [Pyramid Attention Broadcast(PAB)](#pyramid-attention-broadcastpab)
7
+ - [Insights](#insights)
8
+ - [Pyramid Attention Broadcast (PAB) Mechanism](#pyramid-attention-broadcast-pab-mechanism)
9
+ - [Experimental Results](#experimental-results)
10
+ - [Usage](#usage)
11
+ - [Supported Models](#supported-models)
12
+ - [Configuration for PAB](#configuration-for-pab)
13
+ - [Parameters](#parameters)
14
+ - [Example Configuration](#example-configuration)
15
+
16
+
17
+ We introduce Pyramid Attention Broadcast (PAB), the first approach that achieves real-time DiT-based video generation. By mitigating redundant attention computation, PAB achieves up to 21.6 FPS with 10.6x acceleration, without sacrificing quality across popular DiT-based video generation models including Open-Sora, Open-Sora-Plan, and Latte. Notably, as a training-free approach, PAB can empower any future DiT-based video generation model with real-time capabilities.
18
+
19
+ ## Insights
20
+
21
+ ![method](../assets/figures/pab_motivation.png)
22
+
23
+ Our study reveals two key insights about the three **attention mechanisms** within video diffusion transformers:
24
+ - First, attention differences across time steps exhibit a U-shaped pattern, with significant variations occurring during the first and last 15% of steps, while the middle 70% of steps show very stable, minor differences.
25
+ - Second, within the stable middle segment, the variability differs among attention types:
26
+ - **Spatial attention** varies the most, involving high-frequency elements like edges and textures;
27
+ - **Temporal attention** exhibits mid-frequency variations related to movements and dynamics in videos;
28
+ - **Cross-modal attention** is the most stable, linking text with video content, analogous to low-frequency signals reflecting textual semantics.
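
To make the first insight concrete, the stable middle segment can be located for any number of sampling steps. The sketch below is illustrative only: `stable_window` and `edge_fraction` are hypothetical names, and the 15% edge is simply the figure quoted above, not a value read from the code.

```python
def stable_window(num_steps: int, edge_fraction: float = 0.15) -> range:
    """Return the indices of the middle steps where attention changes little.

    edge_fraction=0.15 mirrors the "first and last 15%" observation above.
    """
    edge = round(num_steps * edge_fraction)
    return range(edge, num_steps - edge)


window = stable_window(100)
print(window.start, window.stop, len(window))
```

For 100 steps this selects steps 15 through 84, i.e. the middle 70% in which broadcasting is a candidate.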
29
+
30
+ ## Pyramid Attention Broadcast (PAB) Mechanism
31
+
32
+ ![method](../assets/figures/pab_method.png)
33
+
34
+ Building on these insights, we propose a **pyramid attention broadcast (PAB)** mechanism to minimize unnecessary computations and optimize the utility of each attention module, as shown in the figure above.
35
+
36
+ In the middle segment, we broadcast one step's attention outputs to several subsequent steps, thereby significantly reducing the computational cost of the attention modules.
37
+
38
+ To make broadcasting more efficient while keeping its impact on output quality minimal, we set different broadcast ranges for different attention types according to their stability.
39
+ **The smaller the variation in attention, the broader the potential broadcast range.**
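
The broadcast idea can be sketched as a small cache that reuses one attention output for the following steps. This is a minimal illustration under stated assumptions, not the implementation in `videosys/core/pab_mgr.py`: the class `BroadcastCache`, its fields, and the convention that a gap of `g` means "recompute once every `g` steps" are all hypothetical.

```python
from typing import Callable, Tuple


class BroadcastCache:
    """Reuse one attention output for the next `gap - 1` steps inside a timestep window."""

    def __init__(self, gap: int, threshold: Tuple[int, int]):
        self.gap = gap
        self.low, self.high = threshold
        self.cached = None
        self.age = 0  # number of steps the cached output has been reused
        self.computed = 0  # number of real attention computations

    def __call__(self, timestep: int, compute: Callable[[], object]) -> object:
        in_window = self.low <= timestep <= self.high
        if in_window and self.cached is not None and self.age < self.gap - 1:
            self.age += 1
            return self.cached  # broadcast: skip the attention computation
        out = compute()
        self.computed += 1
        # Keep the output for reuse only while inside the stable window.
        self.cached = out if in_window else None
        self.age = 0
        return out


timesteps = list(range(900, -1, -100))  # 900, 800, ..., 0
cache = BroadcastCache(gap=3, threshold=(100, 800))
outputs = [cache(t, lambda t=t: f"attn@{t}") for t in timesteps]
print(outputs)
print(cache.computed)
```

A more stable attention type would be given a larger `gap`, so its cached output is broadcast over more steps.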
40
+
41
+
42
+ ## Experimental Results
43
+ Here are the results of our experiments; more results are shown at https://oahzxl.github.io/PAB:
44
+
45
+ ![pab_vis](../assets/figures/pab_vis.png)
46
+
47
+
48
+ ## Usage
49
+
50
+ ### Supported Models
51
+
52
+ PAB currently supports Open-Sora, Open-Sora-Plan, and Latte.
53
+
54
+ ### Configuration for PAB
55
+
56
+ To efficiently use the Pyramid Attention Broadcast (PAB) mechanism, configure the following parameters to control the broadcasting for different attention types. This helps reduce computational costs by skipping certain steps based on attention stability.
57
+
58
+ #### Parameters
59
+
60
+ - **spatial_broadcast**: Enable or disable broadcasting for spatial attention.
61
+ - Type: `True` or `False`
62
+
63
+ - **spatial_threshold**: Set the range of diffusion steps within which spatial attention broadcasting is applied.
64
+ - Format: `[min_value, max_value]`
65
+
66
+ - **spatial_gap**: Number of steps to skip during broadcasting for spatial attention.
67
+ - Type: Integer
68
+
69
+ - **temporal_broadcast**: Enable or disable broadcasting for temporal attention.
70
+ - Type: `True` or `False`
71
+
72
+ - **temporal_threshold**: Set the range of diffusion steps within which temporal attention broadcasting is applied.
73
+ - Format: `[min_value, max_value]`
74
+
75
+ - **temporal_gap**: Number of steps to skip during broadcasting for temporal attention.
76
+ - Type: Integer
77
+
78
+ - **cross_broadcast**: Enable or disable broadcasting for cross-modal attention.
79
+ - Type: `True` or `False`
80
+
81
+ - **cross_threshold**: Set the range of diffusion steps within which cross-modal attention broadcasting is applied.
82
+ - Format: `[min_value, max_value]`
83
+
84
+ - **cross_gap**: Number of steps to skip during broadcasting for cross-modal attention.
85
+ - Type: Integer
86
+
87
+ #### Example Configuration
88
+
89
+ ```yaml
90
+ spatial_broadcast: True
91
+ spatial_threshold: [100, 800]
92
+ spatial_gap: 2
93
+
94
+ temporal_broadcast: True
95
+ temporal_threshold: [100, 800]
96
+ temporal_gap: 3
97
+
98
+ cross_broadcast: True
99
+ cross_threshold: [100, 900]
100
+ cross_gap: 5
101
+ ```
102
+
103
+ Explanation:
104
+
105
+ - **Spatial Attention**:
106
+ - Broadcasting enabled (`spatial_broadcast: True`)
107
+ - Applied within the threshold range of 100 to 800
108
+ - Skips every 2 steps (`spatial_gap: 2`)
109
+ - Active within the first 28 steps (`spatial_block: [0, 28]`, a setting that is not part of the example configuration above)
110
+
111
+ - **Temporal Attention**:
112
+ - Broadcasting enabled (`temporal_broadcast: True`)
113
+ - Applied within the threshold range of 100 to 800
114
+ - Skips every 3 steps (`temporal_gap: 3`)
115
+
116
+ - **Cross-Modal Attention**:
117
+ - Broadcasting enabled (`cross_broadcast: True`)
118
+ - Applied within the threshold range of 100 to 900
119
+ - Skips every 5 steps (`cross_gap: 5`)
120
+
121
+ Adjust these settings based on your specific needs to optimize the performance of each attention mechanism.
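
To get a feel for how these settings trade computation for reuse, the following sketch counts how many steps would still recompute each attention type under the example configuration. It rests on assumptions: the 20-step timestep schedule is invented for illustration, `recomputed_steps` is a hypothetical helper, and a gap of `g` is read as "recompute once every `g` steps inside the threshold range".

```python
PAB_EXAMPLE = {
    "spatial": {"threshold": (100, 800), "gap": 2},
    "temporal": {"threshold": (100, 800), "gap": 3},
    "cross": {"threshold": (100, 900), "gap": 5},
}


def recomputed_steps(timesteps, threshold, gap):
    """Return the timesteps at which attention is recomputed rather than reused."""
    low, high = threshold
    kept = []
    since_last = gap  # forces a computation on the first step inside the window
    for t in timesteps:
        in_window = low <= t <= high
        if in_window and since_last < gap:
            since_last += 1  # reuse the cached output
            continue
        kept.append(t)
        since_last = 1 if in_window else gap
    return kept


timesteps = list(range(975, 0, -50))  # 20 illustrative timesteps: 975, 925, ..., 25
counts = {}
for name, cfg in PAB_EXAMPLE.items():
    counts[name] = len(recomputed_steps(timesteps, cfg["threshold"], cfg["gap"]))
    print(f"{name}: recomputed {counts[name]} of {len(timesteps)} steps")
```

Under these assumptions, cross-modal attention is recomputed least often, which matches the pyramid ordering described earlier.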
eval/pab/commom_metrics/README.md ADDED
@@ -0,0 +1,6 @@
1
+ Common metrics
2
+
3
+ Includes LPIPS, PSNR, and SSIM.
4
+
5
+ The code is adapted from [common_metrics_on_video_quality](https://github.com/JunyaoHu/common_metrics_on_video_quality).
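
As a quick reference for what the PSNR metric measures, here is a dependency-free sketch of the same formula that `img_psnr` in `calculate_psnr.py` applies to images in [0, 1]. The flat-list inputs and the sample values are assumptions made for illustration; the real code operates on NumPy arrays.

```python
import math


def psnr(img1, img2):
    """PSNR in dB for two equally sized images given as flat lists of values in [0, 1]."""
    mse = sum((a - b) ** 2 for a, b in zip(img1, img2)) / len(img1)
    if mse < 1e-10:
        return 100.0  # same cap that img_psnr uses for (near-)identical images
    return 20 * math.log10(1 / math.sqrt(mse))


reference = [0.0, 0.5, 1.0, 0.25]
noisy = [0.1, 0.4, 0.9, 0.35]
print(f"{psnr(reference, noisy):.2f} dB")
print(psnr(reference, reference))
```

Every pixel here is off by 0.1, so the mean squared error is about 0.01 and the PSNR is about 20 dB; higher values mean the two images are closer.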
eval/pab/commom_metrics/__init__.py ADDED
File without changes
eval/pab/commom_metrics/calculate_lpips.py ADDED
@@ -0,0 +1,97 @@
1
+ import lpips
2
+ import numpy as np
3
+ import torch
4
+
5
+ spatial = True # Return a spatial map of perceptual distance.
6
+
7
+ # Linearly calibrated models (LPIPS)
8
+ loss_fn = lpips.LPIPS(net="alex", spatial=spatial) # Can also set net = 'squeeze' or 'vgg'
9
+ # loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False) # Can also set net = 'squeeze' or 'vgg'
10
+
11
+
12
+ def trans(x):
13
+ # if greyscale images add channel
14
+ if x.shape[-3] == 1:
15
+ x = x.repeat(1, 1, 3, 1, 1)
16
+
17
+ # value range [0, 1] -> [-1, 1]
18
+ x = x * 2 - 1
19
+
20
+ return x
21
+
22
+
23
+ def calculate_lpips(videos1, videos2, device):
+     # Inputs are RGB (or grayscale) videos with values in [0, 1];
+     # trans() rescales them to the [-1, 1] range that LPIPS expects.
+     # videos [batch_size, timestamps, channel, h, w]
+     assert videos1.shape == videos2.shape
+
+     # support grayscale input, if grayscale -> channel*3
+     # value range [0, 1] -> [-1, 1]
+     videos1 = trans(videos1)
+     videos2 = trans(videos2)
+
+     # move the LPIPS network to the target device once, not once per frame
+     loss_fn.to(device)
+
+     lpips_results = []
+
+     for video_num in range(videos1.shape[0]):
+         # get a video
+         # video [timestamps, channel, h, w]
+         video1 = videos1[video_num]
+         video2 = videos2[video_num]
+
+         lpips_results_of_a_video = []
+         for clip_timestamp in range(len(video1)):
+             # get a frame and add a batch dimension: [1, channel, h, w]
+             img1 = video1[clip_timestamp].unsqueeze(0).to(device)
+             img2 = video2[clip_timestamp].unsqueeze(0).to(device)
+
+             # calculate lpips of a frame (mean over the spatial map)
+             lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist())
+         lpips_results.append(lpips_results_of_a_video)
+
+     lpips_results = np.array(lpips_results)
+
+     lpips = {}
+     lpips_std = {}
+
+     # per-frame mean and standard deviation across the batch
+     for clip_timestamp in range(len(video1)):
+         lpips[clip_timestamp] = np.mean(lpips_results[:, clip_timestamp])
+         lpips_std[clip_timestamp] = np.std(lpips_results[:, clip_timestamp])
+
+     result = {
+         "value": lpips,
+         "value_std": lpips_std,
+         "video_setting": video1.shape,
+         "video_setting_name": "time, channel, height, width",
+     }
+
+     return result
75
+
76
+
77
+ # test code / using example
78
+
79
+
80
+ def main():
81
+ NUMBER_OF_VIDEOS = 8
82
+ VIDEO_LENGTH = 50
83
+ CHANNEL = 3
84
+ SIZE = 64
85
+ videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
86
+ videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
87
+ device = torch.device("cuda")
88
+ # device = torch.device("cpu")
89
+
90
+ import json
91
+
92
+ result = calculate_lpips(videos1, videos2, device)
93
+ print(json.dumps(result, indent=4))
94
+
95
+
96
+ if __name__ == "__main__":
97
+ main()
eval/pab/commom_metrics/calculate_psnr.py ADDED
@@ -0,0 +1,90 @@
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
+ def img_psnr(img1, img2):
8
+ # [0,1]
9
+ # compute mse
10
+ # mse = np.mean((img1-img2)**2)
11
+ mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2)
12
+ # compute psnr
13
+ if mse < 1e-10:
14
+ return 100
15
+ psnr = 20 * math.log10(1 / math.sqrt(mse))
16
+ return psnr
17
+
18
+
19
+ def trans(x):
20
+ return x
21
+
22
+
23
+ def calculate_psnr(videos1, videos2):
24
+ # videos [batch_size, timestamps, channel, h, w]
25
+
26
+ assert videos1.shape == videos2.shape
27
+
28
+ videos1 = trans(videos1)
29
+ videos2 = trans(videos2)
30
+
31
+ psnr_results = []
32
+
33
+ for video_num in range(videos1.shape[0]):
34
+ # get a video
35
+ # video [timestamps, channel, h, w]
36
+ video1 = videos1[video_num]
37
+ video2 = videos2[video_num]
38
+
39
+ psnr_results_of_a_video = []
40
+ for clip_timestamp in range(len(video1)):
41
+ # get a img
42
+ # img [timestamps[x], channel, h, w]
43
+ # img [channel, h, w] numpy
44
+
45
+ img1 = video1[clip_timestamp].numpy()
46
+ img2 = video2[clip_timestamp].numpy()
47
+
48
+ # calculate psnr of a video
49
+ psnr_results_of_a_video.append(img_psnr(img1, img2))
50
+
51
+ psnr_results.append(psnr_results_of_a_video)
52
+
53
+ psnr_results = np.array(psnr_results)
54
+
55
+ psnr = {}
56
+ psnr_std = {}
57
+
58
+ for clip_timestamp in range(len(video1)):
59
+ psnr[clip_timestamp] = np.mean(psnr_results[:, clip_timestamp])
60
+ psnr_std[clip_timestamp] = np.std(psnr_results[:, clip_timestamp])
61
+
62
+ result = {
63
+ "value": psnr,
64
+ "value_std": psnr_std,
65
+ "video_setting": video1.shape,
66
+ "video_setting_name": "time, channel, height, width",
67
+ }
68
+
69
+ return result
70
+
71
+
72
+ # test code / using example
73
+
74
+
75
+ def main():
76
+ NUMBER_OF_VIDEOS = 8
77
+ VIDEO_LENGTH = 50
78
+ CHANNEL = 3
79
+ SIZE = 64
80
+ videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
81
+ videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
82
+
83
+ import json
84
+
85
+ result = calculate_psnr(videos1, videos2)
86
+ print(json.dumps(result, indent=4))
87
+
88
+
89
+ if __name__ == "__main__":
90
+ main()
eval/pab/commom_metrics/calculate_ssim.py ADDED
@@ -0,0 +1,116 @@
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+
5
+
6
+ def ssim(img1, img2):
7
+ C1 = 0.01**2
8
+ C2 = 0.03**2
9
+ img1 = img1.astype(np.float64)
10
+ img2 = img2.astype(np.float64)
11
+ kernel = cv2.getGaussianKernel(11, 1.5)
12
+ window = np.outer(kernel, kernel.transpose())
13
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
14
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
15
+ mu1_sq = mu1**2
16
+ mu2_sq = mu2**2
17
+ mu1_mu2 = mu1 * mu2
18
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
19
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
20
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
21
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
22
+ return ssim_map.mean()
23
+
24
+
25
+ def calculate_ssim_function(img1, img2):
+     # [0,1]
+     # ssim is the only metric extremely sensitive to gray being compared to b/w
+     if img1.shape != img2.shape:
+         raise ValueError("Input images must have the same dimensions.")
+     if img1.ndim == 2:
+         return ssim(img1, img2)
+     if img1.ndim == 3:
+         if img1.shape[0] == 3:
+             # average SSIM over the three color channels
+             ssims = []
+             for i in range(3):
+                 ssims.append(ssim(img1[i], img2[i]))
+             return np.array(ssims).mean()
+         if img1.shape[0] == 1:
+             return ssim(np.squeeze(img1), np.squeeze(img2))
+     # any other rank or channel count is unsupported; fail loudly instead of returning None
+     raise ValueError("Wrong input image dimensions.")
42
+
43
+
44
+ def trans(x):
45
+ return x
46
+
47
+
48
+ def calculate_ssim(videos1, videos2):
49
+ # videos [batch_size, timestamps, channel, h, w]
50
+
51
+ assert videos1.shape == videos2.shape
52
+
53
+ videos1 = trans(videos1)
54
+ videos2 = trans(videos2)
55
+
56
+ ssim_results = []
57
+
58
+ for video_num in range(videos1.shape[0]):
59
+ # get a video
60
+ # video [timestamps, channel, h, w]
61
+ video1 = videos1[video_num]
62
+ video2 = videos2[video_num]
63
+
64
+ ssim_results_of_a_video = []
65
+ for clip_timestamp in range(len(video1)):
66
+ # get a img
67
+ # img [timestamps[x], channel, h, w]
68
+ # img [channel, h, w] numpy
69
+
70
+ img1 = video1[clip_timestamp].numpy()
71
+ img2 = video2[clip_timestamp].numpy()
72
+
73
+ # calculate ssim of a video
74
+ ssim_results_of_a_video.append(calculate_ssim_function(img1, img2))
75
+
76
+ ssim_results.append(ssim_results_of_a_video)
77
+
78
+ ssim_results = np.array(ssim_results)
79
+
80
+ ssim = {}
81
+ ssim_std = {}
82
+
83
+ for clip_timestamp in range(len(video1)):
84
+ ssim[clip_timestamp] = np.mean(ssim_results[:, clip_timestamp])
85
+ ssim_std[clip_timestamp] = np.std(ssim_results[:, clip_timestamp])
86
+
87
+ result = {
88
+ "value": ssim,
89
+ "value_std": ssim_std,
90
+ "video_setting": video1.shape,
91
+ "video_setting_name": "time, channel, height, width",
92
+ }
93
+
94
+ return result
95
+
96
+
97
+ # test code / using example
98
+
99
+
100
+ def main():
101
+ NUMBER_OF_VIDEOS = 8
102
+ VIDEO_LENGTH = 50
103
+ CHANNEL = 3
104
+ SIZE = 64
105
+ videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
106
+ videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
107
+ # calculate_ssim runs on CPU with NumPy/OpenCV, so no device is needed here
108
+
109
+ import json
110
+
111
+ result = calculate_ssim(videos1, videos2)
112
+ print(json.dumps(result, indent=4))
113
+
114
+
115
+ if __name__ == "__main__":
116
+ main()
eval/pab/commom_metrics/eval.py ADDED
@@ -0,0 +1,160 @@
1
+ import argparse
2
+ import os
3
+
4
+ import imageio
5
+ import torch
6
+ import torchvision.transforms.functional as F
7
+ import tqdm
8
+ from calculate_lpips import calculate_lpips
9
+ from calculate_psnr import calculate_psnr
10
+ from calculate_ssim import calculate_ssim
11
+
12
+
13
+ def load_videos(directory, video_ids, file_extension):
14
+ videos = []
15
+ for video_id in video_ids:
16
+ video_path = os.path.join(directory, f"{video_id}.{file_extension}")
17
+ if os.path.exists(video_path):
18
+ video = load_video(video_path) # Define load_video based on how videos are stored
19
+ videos.append(video)
20
+ else:
21
+ raise ValueError(f"Video {video_id}.{file_extension} not found in {directory}")
22
+ return videos
23
+
24
+
25
+ def load_video(video_path):
26
+ """
27
+ Load a video from the given path and convert it to a PyTorch tensor.
28
+ """
29
+ # Read the video using imageio
30
+ reader = imageio.get_reader(video_path, "ffmpeg")
31
+
32
+ # Extract frames and convert to a list of tensors
33
+ frames = []
34
+ for frame in reader:
35
+ # Convert the frame to a tensor and permute the dimensions to match (C, H, W)
36
+ frame_tensor = torch.tensor(frame).cuda().permute(2, 0, 1)
37
+ frames.append(frame_tensor)
38
+
39
+ # Stack the list of tensors into a single tensor with shape (T, C, H, W)
40
+ video_tensor = torch.stack(frames)
41
+
42
+ return video_tensor
43
+
44
+
45
+ def resize_video(video, target_height, target_width):
46
+ resized_frames = []
47
+ for frame in video:
48
+ resized_frame = F.resize(frame, [target_height, target_width])
49
+ resized_frames.append(resized_frame)
50
+ return torch.stack(resized_frames)
51
+
52
+
53
+ def preprocess_eval_video(eval_video, generated_video_shape):
54
+ T_gen, _, H_gen, W_gen = generated_video_shape
55
+ T_eval, _, H_eval, W_eval = eval_video.shape
56
+
57
+ if T_eval < T_gen:
58
+ raise ValueError(f"Eval video time steps ({T_eval}) are less than generated video time steps ({T_gen}).")
59
+
60
+ if H_eval < H_gen or W_eval < W_gen:
61
+ # Resize the video maintaining the aspect ratio
62
+ resize_height = max(H_gen, int(H_gen * (H_eval / W_eval)))
63
+ resize_width = max(W_gen, int(W_gen * (W_eval / H_eval)))
64
+ eval_video = resize_video(eval_video, resize_height, resize_width)
65
+ # Recalculate the dimensions
66
+ T_eval, _, H_eval, W_eval = eval_video.shape
67
+
68
+ # Center crop
69
+ start_h = (H_eval - H_gen) // 2
70
+ start_w = (W_eval - W_gen) // 2
71
+ cropped_video = eval_video[:T_gen, :, start_h : start_h + H_gen, start_w : start_w + W_gen]
72
+
73
+ return cropped_video
74
+
75
+
76
+ def main(args):
77
+ device = "cuda"
78
+ gt_video_dir = args.gt_video_dir
79
+ generated_video_dir = args.generated_video_dir
80
+
81
+ video_ids = []
82
+ file_extension = "mp4"
83
+ for f in os.listdir(generated_video_dir):
84
+ if f.endswith(f".{file_extension}"):
85
+ video_ids.append(f.replace(f".{file_extension}", ""))
86
+ if not video_ids:
87
+ raise ValueError("No videos found in the generated video dataset. Exiting.")
88
+
89
+ print(f"Found {len(video_ids)} videos")
90
+ prompt_interval = 1
91
+ batch_size = 16
92
+ calculate_lpips_flag, calculate_psnr_flag, calculate_ssim_flag = True, True, True
93
+
94
+ lpips_results = []
95
+ psnr_results = []
96
+ ssim_results = []
97
+
98
+ total_len = len(video_ids) // batch_size + (1 if len(video_ids) % batch_size != 0 else 0)
99
+
100
+ for idx in tqdm.tqdm(range(total_len)):
101
+ gt_videos_tensor = []
102
+ generated_videos_tensor = []
103
+ for i in range(batch_size):
104
+ video_idx = idx * batch_size + i
105
+ if video_idx >= len(video_ids):
106
+ break
107
+ video_id = video_ids[video_idx]
108
+ generated_video = load_video(os.path.join(generated_video_dir, f"{video_id}.{file_extension}"))
109
+ generated_videos_tensor.append(generated_video)
110
+ eval_video = load_video(os.path.join(gt_video_dir, f"{video_id}.{file_extension}"))
111
+ gt_videos_tensor.append(eval_video)
112
+ gt_videos_tensor = (torch.stack(gt_videos_tensor) / 255.0).cpu()
113
+ generated_videos_tensor = (torch.stack(generated_videos_tensor) / 255.0).cpu()
114
+
115
+ if calculate_lpips_flag:
116
+ result = calculate_lpips(gt_videos_tensor, generated_videos_tensor, device=device)
117
+ result = result["value"].values()
118
+ result = sum(result) / len(result)
119
+ lpips_results.append(result)
120
+
121
+ if calculate_psnr_flag:
122
+ result = calculate_psnr(gt_videos_tensor, generated_videos_tensor)
123
+ result = result["value"].values()
124
+ result = sum(result) / len(result)
125
+ psnr_results.append(result)
126
+
127
+ if calculate_ssim_flag:
128
+ result = calculate_ssim(gt_videos_tensor, generated_videos_tensor)
129
+ result = result["value"].values()
130
+ result = sum(result) / len(result)
131
+ ssim_results.append(result)
132
+
133
+ if (idx + 1) % prompt_interval == 0:
134
+ out_str = ""
135
+ for results, name in zip([lpips_results, psnr_results, ssim_results], ["lpips", "psnr", "ssim"]):
136
+ result = sum(results) / len(results)
137
+ out_str += f"{name}: {result:.4f}, "
138
+ print(f"Processed {idx + 1} batches. {out_str[:-2]}")
139
+
140
+ out_str = ""
141
+ for results, name in zip([lpips_results, psnr_results, ssim_results], ["lpips", "psnr", "ssim"]):
142
+ result = sum(results) / len(results)
143
+ out_str += f"{name}: {result:.4f}, "
144
+ out_str = out_str[:-2]
145
+
146
+ # save
147
+ with open(f"./{os.path.basename(generated_video_dir)}.txt", "w+") as f:
148
+ f.write(out_str)
149
+
150
+ print(f"Processed all videos. {out_str}")
151
+
152
+
153
+ if __name__ == "__main__":
154
+ parser = argparse.ArgumentParser()
155
+ parser.add_argument("--gt_video_dir", type=str, required=True)
156
+ parser.add_argument("--generated_video_dir", type=str, required=True)
157
+
158
+ args = parser.parse_args()
159
+
160
+ main(args)
eval/pab/experiments/__init__.py ADDED
File without changes
eval/pab/experiments/attention_ablation.py ADDED
@@ -0,0 +1,60 @@
1
+ from utils import generate_func, read_prompt_list
2
+
3
+ import videosys
4
+ from videosys import OpenSoraConfig, OpenSoraPipeline
5
+ from videosys.models.open_sora import OpenSoraPABConfig
6
+
7
+
8
+ def attention_ablation_func(pab_kwargs, prompt_list, output_dir):
9
+ pab_config = OpenSoraPABConfig(**pab_kwargs)
10
+ config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
11
+ pipeline = OpenSoraPipeline(config)
12
+
13
+ generate_func(pipeline, prompt_list, output_dir)
14
+
15
+
16
+ def main(prompt_list):
17
+ # spatial
18
+ gap_list = [2, 3, 4, 5]
19
+ for gap in gap_list:
20
+ pab_kwargs = {
21
+ "spatial_broadcast": True,
22
+ "spatial_gap": gap,
23
+ "temporal_broadcast": False,
24
+ "cross_broadcast": False,
25
+ "mlp_skip": False,
26
+ }
27
+ output_dir = f"./samples/attention_ablation/spatial_g{gap}"
28
+ attention_ablation_func(pab_kwargs, prompt_list, output_dir)
29
+
30
+ # temporal
31
+ gap_list = [3, 4, 5, 6]
32
+ for gap in gap_list:
33
+ pab_kwargs = {
34
+ "spatial_broadcast": False,
35
+ "temporal_broadcast": True,
36
+ "temporal_gap": gap,
37
+ "cross_broadcast": False,
38
+ "mlp_skip": False,
39
+ }
40
+ output_dir = f"./samples/attention_ablation/temporal_g{gap}"
41
+ attention_ablation_func(pab_kwargs, prompt_list, output_dir)
42
+
43
+ # cross
44
+ gap_list = [5, 6, 7, 8]
45
+ for gap in gap_list:
46
+ pab_kwargs = {
47
+ "spatial_broadcast": False,
48
+ "temporal_broadcast": False,
49
+ "cross_broadcast": True,
50
+ "cross_gap": gap,
51
+ "mlp_skip": False,
52
+ }
53
+ output_dir = f"./samples/attention_ablation/cross_g{gap}"
54
+ attention_ablation_func(pab_kwargs, prompt_list, output_dir)
55
+
56
+
57
+ if __name__ == "__main__":
58
+ videosys.initialize(42)
59
+ prompt_list = read_prompt_list("vbench/VBench_full_info.json")
60
+ main(prompt_list)
eval/pab/experiments/components_ablation.py ADDED
@@ -0,0 +1,46 @@
1
+ from utils import generate_func, read_prompt_list
2
+
3
+ import videosys
4
+ from videosys import OpenSoraConfig, OpenSoraPipeline
5
+ from videosys.models.open_sora import OpenSoraPABConfig
6
+
7
+
8
+ def wo_spatial(prompt_list):
9
+ pab_config = OpenSoraPABConfig(spatial_broadcast=False)
10
+ config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
11
+ pipeline = OpenSoraPipeline(config)
12
+
13
+ generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_spatial")
14
+
15
+
16
+ def wo_temporal(prompt_list):
17
+ pab_config = OpenSoraPABConfig(temporal_broadcast=False)
18
+ config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
19
+ pipeline = OpenSoraPipeline(config)
20
+
21
+ generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_temporal")
22
+
23
+
24
+ def wo_cross(prompt_list):
25
+ pab_config = OpenSoraPABConfig(cross_broadcast=False)
26
+ config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
27
+ pipeline = OpenSoraPipeline(config)
28
+
29
+ generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_cross")
30
+
31
+
32
+ def wo_mlp(prompt_list):
33
+ pab_config = OpenSoraPABConfig(mlp_skip=False)
34
+ config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
35
+ pipeline = OpenSoraPipeline(config)
36
+
37
+ generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_mlp")
38
+
39
+
40
+ if __name__ == "__main__":
41
+ videosys.initialize(42)
42
+ prompt_list = read_prompt_list("./vbench/VBench_full_info.json")
43
+ wo_spatial(prompt_list)
44
+ wo_temporal(prompt_list)
45
+ wo_cross(prompt_list)
46
+ wo_mlp(prompt_list)
eval/pab/experiments/latte.py ADDED
@@ -0,0 +1,57 @@
1
+ from utils import generate_func, read_prompt_list
2
+
3
+ import videosys
4
+ from videosys import LatteConfig, LattePipeline
5
+ from videosys.models.latte import LattePABConfig
6
+
7
+
8
+ def eval_base(prompt_list):
9
+ config = LatteConfig()
10
+ pipeline = LattePipeline(config)
11
+
12
+ generate_func(pipeline, prompt_list, "./samples/latte_base", loop=5)
13
+
14
+
15
+ def eval_pab1(prompt_list):
16
+ pab_config = LattePABConfig(
17
+ spatial_gap=2,
18
+ temporal_gap=3,
19
+ cross_gap=6,
20
+ )
21
+ config = LatteConfig(enable_pab=True, pab_config=pab_config)
22
+ pipeline = LattePipeline(config)
23
+
24
+ generate_func(pipeline, prompt_list, "./samples/latte_pab1", loop=5)
25
+
26
+
27
+ def eval_pab2(prompt_list):
28
+ pab_config = LattePABConfig(
29
+ spatial_gap=3,
30
+ temporal_gap=4,
31
+ cross_gap=7,
32
+ )
33
+ config = LatteConfig(enable_pab=True, pab_config=pab_config)
34
+ pipeline = LattePipeline(config)
35
+
36
+ generate_func(pipeline, prompt_list, "./samples/latte_pab2", loop=5)
37
+
38
+
39
+ def eval_pab3(prompt_list):
40
+ pab_config = LattePABConfig(
41
+ spatial_gap=4,
42
+ temporal_gap=6,
43
+ cross_gap=9,
44
+ )
45
+ config = LatteConfig(enable_pab=True, pab_config=pab_config)
46
+ pipeline = LattePipeline(config)
47
+
48
+ generate_func(pipeline, prompt_list, "./samples/latte_pab3", loop=5)
49
+
50
+
51
+ if __name__ == "__main__":
52
+ videosys.initialize(42)
53
+ prompt_list = read_prompt_list("vbench/VBench_full_info.json")
54
+ eval_base(prompt_list)
55
+ eval_pab1(prompt_list)
56
+ eval_pab2(prompt_list)
57
+ eval_pab3(prompt_list)
eval/pab/experiments/opensora.py ADDED
@@ -0,0 +1,44 @@
1
+ from utils import generate_func, read_prompt_list
2
+
3
+ import videosys
4
+ from videosys import OpenSoraConfig, OpenSoraPipeline
5
+ from videosys.models.open_sora import OpenSoraPABConfig
6
+
7
+
8
+ def eval_base(prompt_list):
9
+ config = OpenSoraConfig()
10
+ pipeline = OpenSoraPipeline(config)
11
+
12
+ generate_func(pipeline, prompt_list, "./samples/opensora_base", loop=5)
13
+
14
+
15
+ def eval_pab1(prompt_list):
16
+ config = OpenSoraConfig(enable_pab=True)
17
+ pipeline = OpenSoraPipeline(config)
18
+
19
+ generate_func(pipeline, prompt_list, "./samples/opensora_pab1", loop=5)
20
+
21
+
22
+ def eval_pab2(prompt_list):
23
+ pab_config = OpenSoraPABConfig(spatial_gap=3, temporal_gap=5, cross_gap=7)
24
+ config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
25
+ pipeline = OpenSoraPipeline(config)
26
+
27
+ generate_func(pipeline, prompt_list, "./samples/opensora_pab2", loop=5)
28
+
29
+
30
+ def eval_pab3(prompt_list):
31
+ pab_config = OpenSoraPABConfig(spatial_gap=5, temporal_gap=7, cross_gap=9)
32
+ config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
33
+ pipeline = OpenSoraPipeline(config)
34
+
35
+ generate_func(pipeline, prompt_list, "./samples/opensora_pab3", loop=5)
36
+
37
+
38
+ if __name__ == "__main__":
39
+ videosys.initialize(42)
40
+ prompt_list = read_prompt_list("vbench/VBench_full_info.json")
41
+ eval_base(prompt_list)
42
+ eval_pab1(prompt_list)
43
+ eval_pab2(prompt_list)
44
+ eval_pab3(prompt_list)
eval/pab/experiments/opensora_plan.py ADDED
@@ -0,0 +1,57 @@
1
+ from utils import generate_func, read_prompt_list
2
+
3
+ import videosys
4
+ from videosys import OpenSoraPlanConfig, OpenSoraPlanPipeline
5
+ from videosys.models.open_sora_plan import OpenSoraPlanPABConfig
6
+
7
+
8
+ def eval_base(prompt_list):
9
+ config = OpenSoraPlanConfig()
10
+ pipeline = OpenSoraPlanPipeline(config)
11
+
12
+ generate_func(pipeline, prompt_list, "./samples/opensoraplan_base", loop=5)
13
+
14
+
15
+ def eval_pab1(prompt_list):
16
+ pab_config = OpenSoraPlanPABConfig(
17
+ spatial_gap=2,
18
+ temporal_gap=4,
19
+ cross_gap=6,
20
+ )
21
+ config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
22
+ pipeline = OpenSoraPlanPipeline(config)
23
+
24
+ generate_func(pipeline, prompt_list, "./samples/opensoraplan_pab1", loop=5)
25
+
26
+
27
+ def eval_pab2(prompt_list):
28
+ pab_config = OpenSoraPlanPABConfig(
29
+ spatial_gap=3,
30
+ temporal_gap=5,
31
+ cross_gap=7,
32
+ )
33
+ config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
34
+ pipeline = OpenSoraPlanPipeline(config)
35
+
36
+ generate_func(pipeline, prompt_list, "./samples/opensoraplan_pab2", loop=5)
37
+
38
+
39
+ def eval_pab3(prompt_list):
40
+ pab_config = OpenSoraPlanPABConfig(
41
+ spatial_gap=5,
42
+ temporal_gap=7,
43
+ cross_gap=9,
44
+ )
45
+ config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
46
+ pipeline = OpenSoraPlanPipeline(config)
47
+
48
+ generate_func(pipeline, prompt_list, "./samples/opensoraplan_pab3", loop=5)
49
+
50
+
51
+ if __name__ == "__main__":
52
+ videosys.initialize(42)
53
+ prompt_list = read_prompt_list("vbench/VBench_full_info.json")
54
+ eval_base(prompt_list)
55
+ eval_pab1(prompt_list)
56
+ eval_pab2(prompt_list)
57
+ eval_pab3(prompt_list)
eval/pab/experiments/utils.py ADDED
@@ -0,0 +1,22 @@
1
+ import json
2
+ import os
3
+
4
+ import tqdm
5
+
6
+ from videosys.utils.utils import set_seed
7
+
8
+
9
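+ def generate_func(pipeline, prompt_list, output_dir, loop: int = 5, kwargs: dict = None):
+     # copy the kwargs so neither the caller's dict nor a shared default is mutated
+     kwargs = dict(kwargs or {})
+     kwargs["verbose"] = False
+     for prompt in tqdm.tqdm(prompt_list):
+         # generate `loop` samples per prompt, one per seed, for reproducibility
+         for seed in range(loop):
+             set_seed(seed)
+             video = pipeline.generate(prompt, **kwargs).video[0]
+             pipeline.save_video(video, os.path.join(output_dir, f"{prompt}-{seed}.mp4"))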
16
+
17
+
18
+ def read_prompt_list(prompt_list_path):
19
+ with open(prompt_list_path, "r") as f:
20
+ prompt_list = json.load(f)
21
+ prompt_list = [prompt["prompt_en"] for prompt in prompt_list]
22
+ return prompt_list
eval/pab/vbench/VBench_full_info.json ADDED
The diff for this file is too large to render. See raw diff
 
eval/pab/vbench/cal_vbench.py ADDED
@@ -0,0 +1,154 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+
5
+ SEMANTIC_WEIGHT = 1
6
+ QUALITY_WEIGHT = 4
7
+
8
+ QUALITY_LIST = [
9
+ "subject consistency",
10
+ "background consistency",
11
+ "temporal flickering",
12
+ "motion smoothness",
13
+ "aesthetic quality",
14
+ "imaging quality",
15
+ "dynamic degree",
16
+ ]
17
+
18
+ SEMANTIC_LIST = [
19
+ "object class",
20
+ "multiple objects",
21
+ "human action",
22
+ "color",
23
+ "spatial relationship",
24
+ "scene",
25
+ "appearance style",
26
+ "temporal style",
27
+ "overall consistency",
28
+ ]
29
+
30
+ NORMALIZE_DIC = {
31
+ "subject consistency": {"Min": 0.1462, "Max": 1.0},
32
+ "background consistency": {"Min": 0.2615, "Max": 1.0},
33
+ "temporal flickering": {"Min": 0.6293, "Max": 1.0},
34
+ "motion smoothness": {"Min": 0.706, "Max": 0.9975},
35
+ "dynamic degree": {"Min": 0.0, "Max": 1.0},
36
+ "aesthetic quality": {"Min": 0.0, "Max": 1.0},
37
+ "imaging quality": {"Min": 0.0, "Max": 1.0},
38
+ "object class": {"Min": 0.0, "Max": 1.0},
39
+ "multiple objects": {"Min": 0.0, "Max": 1.0},
40
+ "human action": {"Min": 0.0, "Max": 1.0},
41
+ "color": {"Min": 0.0, "Max": 1.0},
42
+ "spatial relationship": {"Min": 0.0, "Max": 1.0},
43
+ "scene": {"Min": 0.0, "Max": 0.8222},
44
+ "appearance style": {"Min": 0.0009, "Max": 0.2855},
45
+ "temporal style": {"Min": 0.0, "Max": 0.364},
46
+ "overall consistency": {"Min": 0.0, "Max": 0.364},
47
+ }
48
+
49
+ DIM_WEIGHT = {
50
+ "subject consistency": 1,
51
+ "background consistency": 1,
52
+ "temporal flickering": 1,
53
+ "motion smoothness": 1,
54
+ "aesthetic quality": 1,
55
+ "imaging quality": 1,
56
+ "dynamic degree": 0.5,
57
+ "object class": 1,
58
+ "multiple objects": 1,
59
+ "human action": 1,
60
+ "color": 1,
61
+ "spatial relationship": 1,
62
+ "scene": 1,
63
+ "appearance style": 1,
64
+ "temporal style": 1,
65
+ "overall consistency": 1,
66
+ }
67
+
68
+ ordered_scaled_res = [
69
+ "total score",
70
+ "quality score",
71
+ "semantic score",
72
+ "subject consistency",
73
+ "background consistency",
74
+ "temporal flickering",
75
+ "motion smoothness",
76
+ "dynamic degree",
77
+ "aesthetic quality",
78
+ "imaging quality",
79
+ "object class",
80
+ "multiple objects",
81
+ "human action",
82
+ "color",
83
+ "spatial relationship",
84
+ "scene",
85
+ "appearance style",
86
+ "temporal style",
87
+ "overall consistency",
88
+ ]
89
+
90
+
91
+ def parse_args():
92
+ parser = argparse.ArgumentParser()
93
+ parser.add_argument("--score_dir", required=True, type=str)
94
+ args = parser.parse_args()
95
+ return args
96
+
97
+
98
+ if __name__ == "__main__":
99
+ args = parse_args()
100
+ res_postfix = "_eval_results.json"
101
+ info_postfix = "_full_info.json"
102
+ files = os.listdir(args.score_dir)
103
+ res_files = [x for x in files if res_postfix in x]
104
+ info_files = [x for x in files if info_postfix in x]
105
+ assert len(res_files) == len(info_files), f"got {len(res_files)} res files, but {len(info_files)} info files"
106
+
107
+ full_results = {}
108
+ for res_file in res_files:
109
+ # first check that the matching info file lists at least one video
110
+ info_file = res_file.split(res_postfix)[0] + info_postfix
111
+ with open(os.path.join(args.score_dir, info_file), "r", encoding="utf-8") as f:
112
+ info = json.load(f)
113
+ assert len(info[0]["video_list"]) > 0, f"Error: {info_file} has 0 video list"
114
+ # read results
115
+ with open(os.path.join(args.score_dir, res_file), "r", encoding="utf-8") as f:
116
+ data = json.load(f)
117
+ for key, val in data.items():
118
+ full_results[key] = format(val[0], ".4f")
119
+
120
+ scaled_results = {}
121
+ dims = set()
122
+ for key, val in full_results.items():
123
+ dim = key.replace("_", " ") if "_" in key else key
124
+ scaled_score = (float(val) - NORMALIZE_DIC[dim]["Min"]) / (
125
+ NORMALIZE_DIC[dim]["Max"] - NORMALIZE_DIC[dim]["Min"]
126
+ )
127
+ scaled_score *= DIM_WEIGHT[dim]
128
+ scaled_results[dim] = scaled_score
129
+ dims.add(dim)
130
+
131
+ assert len(dims) == len(NORMALIZE_DIC), f"{set(NORMALIZE_DIC.keys())-dims} not calculated yet"
132
+
133
+ quality_score = sum([scaled_results[i] for i in QUALITY_LIST]) / sum([DIM_WEIGHT[i] for i in QUALITY_LIST])
134
+ semantic_score = sum([scaled_results[i] for i in SEMANTIC_LIST]) / sum([DIM_WEIGHT[i] for i in SEMANTIC_LIST])
135
+ scaled_results["quality score"] = quality_score
136
+ scaled_results["semantic score"] = semantic_score
137
+ scaled_results["total score"] = (quality_score * QUALITY_WEIGHT + semantic_score * SEMANTIC_WEIGHT) / (
138
+ QUALITY_WEIGHT + SEMANTIC_WEIGHT
139
+ )
140
+
141
+ formated_scaled_results = {"items": []}
142
+ for key in ordered_scaled_res:
143
+ formated_score = format(scaled_results[key] * 100, ".2f") + "%"
144
+ formated_scaled_results["items"].append({key: formated_score})
145
+
146
+ output_file_path = os.path.join(args.score_dir, "all_results.json")
147
+ with open(output_file_path, "w") as outfile:
148
+ json.dump(full_results, outfile, indent=4, sort_keys=True)
149
+ print(f"results saved to: {output_file_path}")
150
+
151
+ scaled_file_path = os.path.join(args.score_dir, "scaled_results.json")
152
+ with open(scaled_file_path, "w") as outfile:
153
+ json.dump(formated_scaled_results, outfile, indent=4, sort_keys=True)
154
+ print(f"results saved to: {scaled_file_path}")
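The scoring in `cal_vbench.py` min-max normalizes each raw dimension score, multiplies it by the dimension weight, divides the sum by the total weight, and finally blends quality and semantic scores 4:1. The sketch below reproduces that arithmetic on two quality dimensions only. The raw scores are invented for illustration and are not measured results; the helper names `scale`, `weighted_mean`, and `total_score` are hypothetical and do not appear in the script.

```python
# Illustrative only: the raw scores below are invented, not measured results.
NORMALIZE = {
    "motion smoothness": {"Min": 0.706, "Max": 0.9975},
    "dynamic degree": {"Min": 0.0, "Max": 1.0},
}
WEIGHT = {"motion smoothness": 1, "dynamic degree": 0.5}


def scale(dim: str, raw: float) -> float:
    """Min-max normalize a raw score, then apply the per-dimension weight."""
    lo, hi = NORMALIZE[dim]["Min"], NORMALIZE[dim]["Max"]
    return (raw - lo) / (hi - lo) * WEIGHT[dim]


def weighted_mean(scaled: dict) -> float:
    """Divide the sum of weighted scores by the sum of the weights involved."""
    return sum(scaled.values()) / sum(WEIGHT[d] for d in scaled)


def total_score(quality: float, semantic: float, qw: float = 4, sw: float = 1) -> float:
    """Blend quality and semantic scores with the 4:1 weighting used by the script."""
    return (quality * qw + semantic * sw) / (qw + sw)


raw = {"motion smoothness": 0.9975, "dynamic degree": 0.5}
scaled = {dim: scale(dim, val) for dim, val in raw.items()}
quality = weighted_mean(scaled)
```

Because "dynamic degree" carries a weight of 0.5, a raw value of 0.5 contributes only 0.25 to the numerator while adding 0.5 to the denominator, which is why the divisor is the sum of weights rather than the number of dimensions.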
eval/pab/vbench/run_vbench.py ADDED
@@ -0,0 +1,52 @@
1
+ import argparse
2
+
3
+ import torch
4
+ from vbench import VBench
5
+
6
+ full_info_path = "./vbench/VBench_full_info.json"
7
+
8
+ dimensions = [
9
+ "subject_consistency",
10
+ "imaging_quality",
11
+ "background_consistency",
12
+ "motion_smoothness",
13
+ "overall_consistency",
14
+ "human_action",
15
+ "multiple_objects",
16
+ "spatial_relationship",
17
+ "object_class",
18
+ "color",
19
+ "aesthetic_quality",
20
+ "appearance_style",
21
+ "temporal_flickering",
22
+ "scene",
23
+ "temporal_style",
24
+ "dynamic_degree",
25
+ ]
26
+
27
+
28
+ def parse_args():
29
+ parser = argparse.ArgumentParser()
30
+ parser.add_argument("--video_path", required=True, type=str)
31
+ args = parser.parse_args()
32
+ return args
33
+
34
+
35
+ if __name__ == "__main__":
36
+ args = parse_args()
37
+ save_path = args.video_path.replace("/samples/", "/vbench_out/")
38
+
39
+ kwargs = {}
40
+ kwargs["imaging_quality_preprocessing_mode"] = "longer" # use VBench/evaluate.py default
41
+
42
+ for dimension in dimensions:
43
+ my_VBench = VBench(torch.device("cuda"), full_info_path, save_path)
44
+ my_VBench.evaluate(
45
+ videos_path=args.video_path,
46
+ name=dimension,
47
+ local=False,
48
+ read_frame=False,
49
+ dimension_list=[dimension],
50
+ mode="vbench_standard",
51
+ **kwargs,
52
+ )
examples/cogvideo/sample.py ADDED
@@ -0,0 +1,14 @@
1
+ from videosys import CogVideoConfig, VideoSysEngine
2
+
3
+
4
+ def run_base():
5
+ config = CogVideoConfig(world_size=1)
6
+ engine = VideoSysEngine(config)
7
+
8
+ prompt = "Sunset over the sea."
9
+ video = engine.generate(prompt).video[0]
10
+ engine.save_video(video, f"./outputs/{prompt}.mp4")
11
+
12
+
13
+ if __name__ == "__main__":
14
+ run_base()
examples/latte/sample.py ADDED
@@ -0,0 +1,24 @@
1
+ from videosys import LatteConfig, VideoSysEngine
2
+
3
+
4
+ def run_base():
5
+ config = LatteConfig(world_size=1)
6
+ engine = VideoSysEngine(config)
7
+
8
+ prompt = "Sunset over the sea."
9
+ video = engine.generate(prompt).video[0]
10
+ engine.save_video(video, f"./outputs/{prompt}.mp4")
11
+
12
+
13
+ def run_pab():
14
+ config = LatteConfig(world_size=1)
15
+ engine = VideoSysEngine(config)
16
+
17
+ prompt = "Sunset over the sea."
18
+ video = engine.generate(prompt).video[0]
19
+ engine.save_video(video, f"./outputs/{prompt}.mp4")
20
+
21
+
22
+ if __name__ == "__main__":
23
+ run_base()
24
+ # run_pab()
examples/open_sora/sample.py ADDED
@@ -0,0 +1,24 @@
1
+ from videosys import OpenSoraConfig, VideoSysEngine
2
+
3
+
4
+ def run_base():
5
+ config = OpenSoraConfig(world_size=1)
6
+ engine = VideoSysEngine(config)
7
+
8
+ prompt = "Sunset over the sea."
9
+ video = engine.generate(prompt).video[0]
10
+ engine.save_video(video, f"./outputs/{prompt}.mp4")
11
+
12
+
13
+ def run_pab():
14
+ config = OpenSoraConfig(world_size=1, enable_pab=True)
15
+ engine = VideoSysEngine(config)
16
+
17
+ prompt = "Sunset over the sea."
18
+ video = engine.generate(prompt).video[0]
19
+ engine.save_video(video, f"./outputs/{prompt}.mp4")
20
+
21
+
22
+ if __name__ == "__main__":
23
+ run_base()
24
+ run_pab()
examples/open_sora_plan/sample.py ADDED
@@ -0,0 +1,24 @@
1
+ from videosys import OpenSoraPlanConfig, VideoSysEngine
2
+
3
+
4
+ def run_base():
5
+ config = OpenSoraPlanConfig(world_size=1)
6
+ engine = VideoSysEngine(config)
7
+
8
+ prompt = "Sunset over the sea."
9
+ video = engine.generate(prompt).video[0]
10
+ engine.save_video(video, f"./outputs/{prompt}.mp4")
11
+
12
+
13
+ def run_pab():
14
+ config = OpenSoraPlanConfig(world_size=1)
15
+ engine = VideoSysEngine(config)
16
+
17
+ prompt = "Sunset over the sea."
18
+ video = engine.generate(prompt).video[0]
19
+ engine.save_video(video, f"./outputs/{prompt}.mp4")
20
+
21
+
22
+ if __name__ == "__main__":
23
+ run_base()
24
+ # run_pab()
requirements.txt ADDED
@@ -0,0 +1,25 @@
1
+ gradio
2
+ click
3
+ colossalai
4
+ contexttimer
5
+ diffusers==0.30.0
6
+ einops
7
+ fabric
8
+ ftfy
9
+ imageio
10
+ imageio-ffmpeg
11
+ matplotlib
12
+ ninja
13
+ numpy
14
+ omegaconf
15
+ packaging
16
+ psutil
17
+ pydantic
18
+ ray
19
+ rich
20
+ safetensors
21
+ timm
22
+ torch>=1.13
23
+ tqdm
24
+ transformers
25
+ openai
setup.py ADDED
@@ -0,0 +1,55 @@
1
+ from typing import List
2
+
3
+ from setuptools import find_packages, setup
4
+
5
+
6
+ def fetch_requirements(path) -> List[str]:
7
+ """
8
+ This function reads the requirements file.
9
+
10
+ Args:
11
+ path (str): the path to the requirements file.
12
+
13
+ Returns:
14
+ The lines in the requirements file.
15
+ """
16
+ with open(path, "r") as fd:
17
+ return [r.strip() for r in fd.readlines()]
18
+
19
+
20
+ def fetch_readme() -> str:
21
+ """
22
+ This function reads the README.md file in the current directory.
23
+
24
+ Returns:
25
+ The contents of the README file as a single string.
26
+ """
27
+ with open("README.md", encoding="utf-8") as f:
28
+ return f.read()
29
+
30
+
31
+ setup(
32
+ name="videosys",
33
+ version="2.0.0",
34
+ packages=find_packages(
35
+ exclude=(
36
+ "videos",
37
+ "tests",
38
+ "figure",
39
+ "*.egg-info",
40
+ )
41
+ ),
42
+ description="VideoSys",
43
+ long_description=fetch_readme(),
44
+ long_description_content_type="text/markdown",
45
+ license="Apache Software License 2.0",
46
+ install_requires=fetch_requirements("requirements.txt"),
47
+ python_requires=">=3.8",
48
+ classifiers=[
49
+ "Programming Language :: Python :: 3",
50
+ "License :: OSI Approved :: Apache Software License",
51
+ "Environment :: GPU :: NVIDIA CUDA",
52
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
53
+ "Topic :: System :: Distributed Computing",
54
+ ],
55
+ )
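`fetch_requirements` in `setup.py` returns every stripped line of the file, so a blank line or a comment in `requirements.txt` would be passed straight to `install_requires`. One possible variant that skips such lines is sketched below. The name `fetch_requirements_filtered` is hypothetical and not part of this commit, and the temporary file exists only to make the example self-contained.

```python
import os
import tempfile
from typing import List


def fetch_requirements_filtered(path: str) -> List[str]:
    """Read a requirements file, skipping blank lines and '#' comment lines."""
    with open(path, "r", encoding="utf-8") as fd:
        lines = (line.strip() for line in fd)
        return [line for line in lines if line and not line.startswith("#")]


# Small demonstration on a temporary file.
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
    tmp.write("gradio\n\n# pinned\ndiffusers==0.30.0\n")
reqs = fetch_requirements_filtered(tmp.name)
os.remove(tmp.name)
```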
tests/__init__.py ADDED
File without changes
videosys/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ from .core.engine import VideoSysEngine
2
+ from .core.parallel_mgr import initialize
3
+ from .models.cogvideo.pipeline import CogVideoConfig, CogVideoPipeline
4
+ from .models.latte.pipeline import LatteConfig, LattePipeline
5
+ from .models.open_sora.pipeline import OpenSoraConfig, OpenSoraPipeline
6
+ from .models.open_sora_plan.pipeline import OpenSoraPlanConfig, OpenSoraPlanPipeline
7
+
8
+ __all__ = [
9
+ "initialize",
10
+ "VideoSysEngine",
11
+ "LattePipeline",
12
+ "LatteConfig",
13
+ "OpenSoraPlanPipeline",
14
+ "OpenSoraPlanConfig",
15
+ "OpenSoraPipeline",
16
+ "OpenSoraConfig",
17
+ "CogVideoConfig",
18
+ "CogVideoPipeline",
19
+ ]
videosys/core/__init__.py ADDED
File without changes
videosys/core/comm.py ADDED
@@ -0,0 +1,420 @@
1
+ from typing import Any, Optional, Tuple
2
+
3
+ import torch
4
+ import torch.distributed as dist
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+ from torch import Tensor
8
+ from torch.distributed import ProcessGroup
9
+
10
+ from videosys.core.parallel_mgr import get_sequence_parallel_size
11
+
12
+ # ======================================================
13
+ # Model
14
+ # ======================================================
15
+
16
+
17
+ def model_sharding(model: torch.nn.Module):
18
+ global_rank = dist.get_rank()
19
+ world_size = dist.get_world_size()
20
+ for _, param in model.named_parameters():
21
+ padding_size = (world_size - param.numel() % world_size) % world_size
22
+ if padding_size > 0:
23
+ padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
24
+ else:
25
+ padding_param = param.data.view(-1)
26
+ splited_params = padding_param.split(padding_param.numel() // world_size)
27
+ splited_params = splited_params[global_rank]
28
+ param.data = splited_params
29
+
30
+
31
+ # ======================================================
32
+ # AllGather & ReduceScatter
33
+ # ======================================================
34
+
35
+
36
+ class AsyncAllGatherForTwo(torch.autograd.Function):
37
+ @staticmethod
38
+ def forward(
39
+ ctx: Any,
40
+ inputs: Tensor,
41
+ weight: Tensor,
42
+ bias: Tensor,
43
+ sp_rank: int,
44
+ sp_size: int,
45
+ group: Optional[ProcessGroup] = None,
46
+ ) -> Tuple[Tensor, Any]:
47
+ """
48
+ Returns:
49
+ qkv: Tensor of shape (b, sp * n, c), the local and remote QKV projections
50
+ concatenated in rank order (no communication handle is returned)
51
+ """
52
+ from torch.distributed._functional_collectives import all_gather_tensor
53
+
54
+ ctx.group = group
55
+ ctx.sp_rank = sp_rank
56
+ ctx.sp_size = sp_size
57
+
58
+ # all gather inputs
59
+ all_inputs = all_gather_tensor(inputs.unsqueeze(0), 0, group)
60
+ # compute local qkv
61
+ local_qkv = F.linear(inputs, weight, bias).unsqueeze(0)
62
+
63
+ # remote compute
64
+ remote_inputs = all_inputs[1 - sp_rank].view(list(local_qkv.shape[:-1]) + [-1])
65
+ # compute remote qkv
66
+ remote_qkv = F.linear(remote_inputs, weight, bias)
67
+
68
+ # concat local and remote qkv
69
+ if sp_rank == 0:
70
+ qkv = torch.cat([local_qkv, remote_qkv], dim=0)
71
+ else:
72
+ qkv = torch.cat([remote_qkv, local_qkv], dim=0)
73
+ qkv = rearrange(qkv, "sp b n c -> b (sp n) c")
74
+
75
+ ctx.save_for_backward(inputs, weight, remote_inputs)
76
+ return qkv
77
+
78
+ @staticmethod
79
+ def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
80
+ from torch.distributed._functional_collectives import reduce_scatter_tensor
81
+
82
+ group = ctx.group
83
+ sp_rank = ctx.sp_rank
84
+ sp_size = ctx.sp_size
85
+ inputs, weight, remote_inputs = ctx.saved_tensors
86
+
87
+ # split qkv_grad
88
+ qkv_grad = grad_outputs[0]
89
+ qkv_grad = rearrange(qkv_grad, "b (sp n) c -> sp b n c", sp=sp_size)
90
+ qkv_grad = torch.chunk(qkv_grad, 2, dim=0)
91
+ if sp_rank == 0:
92
+ local_qkv_grad, remote_qkv_grad = qkv_grad
93
+ else:
94
+ remote_qkv_grad, local_qkv_grad = qkv_grad
95
+
96
+ # compute remote grad
97
+ remote_inputs_grad = torch.matmul(remote_qkv_grad, weight).squeeze(0)
98
+ weight_grad = torch.matmul(remote_qkv_grad.transpose(-1, -2), remote_inputs).squeeze(0).sum(0)
99
+ bias_grad = remote_qkv_grad.squeeze(0).sum(0).sum(0)
100
+
101
+ # launch async reduce scatter
102
+ remote_inputs_grad_zero = torch.zeros_like(remote_inputs_grad)
103
+ if sp_rank == 0:
104
+ remote_inputs_grad = torch.cat([remote_inputs_grad_zero, remote_inputs_grad], dim=0)
105
+ else:
106
+ remote_inputs_grad = torch.cat([remote_inputs_grad, remote_inputs_grad_zero], dim=0)
107
+ remote_inputs_grad = reduce_scatter_tensor(remote_inputs_grad, "sum", 0, group)
108
+
109
+ # compute local grad and wait for reduce scatter
110
+ local_input_grad = torch.matmul(local_qkv_grad, weight).squeeze(0)
111
+ weight_grad += torch.matmul(local_qkv_grad.transpose(-1, -2), inputs).squeeze(0).sum(0)
112
+ bias_grad += local_qkv_grad.squeeze(0).sum(0).sum(0)
113
+
114
+ # sum remote and local grad
115
+ inputs_grad = remote_inputs_grad + local_input_grad
116
+ return inputs_grad, weight_grad, bias_grad, None, None, None
117
+
118
+
119
+ class AllGather(torch.autograd.Function):
120
+ @staticmethod
121
+ def forward(
122
+ ctx: Any,
123
+ inputs: Tensor,
124
+ group: Optional[ProcessGroup] = None,
125
+ overlap: bool = False,
126
+ ) -> Tuple[Tensor, Any]:
127
+ """
128
+ Returns:
129
+ outputs: Tensor
130
+ handle: Optional[Work], if overlap is True
131
+ """
132
+ assert ctx is not None or not overlap
133
+
134
+ if ctx is not None:
135
+ ctx.comm_grp = group
136
+
137
+ comm_size = dist.get_world_size(group)
138
+ if comm_size == 1:
139
+ return inputs.unsqueeze(0), None
140
+
141
+ buffer_shape = (comm_size,) + inputs.shape
142
+ outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
143
+ buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
144
+ if not overlap:
145
+ dist.all_gather(buffer_list, inputs, group=group)
146
+ return outputs, None
147
+ else:
148
+ handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
149
+ return outputs, handle
150
+
151
+ @staticmethod
152
+ def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
153
+ return (
154
+ ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
155
+ None,
156
+ None,
157
+ )
158
+
159
+
160
+ class ReduceScatter(torch.autograd.Function):
161
+ @staticmethod
162
+ def forward(
163
+ ctx: Any,
164
+ inputs: Tensor,
165
+ group: ProcessGroup,
166
+ overlap: bool = False,
167
+ ) -> Tuple[Tensor, Any]:
168
+ """
169
+ Returns:
170
+ outputs: Tensor
171
+ handle: Optional[Work], if overlap is True
172
+ """
173
+ assert ctx is not None or not overlap
174
+
175
+ if ctx is not None:
176
+ ctx.comm_grp = group
177
+
178
+ comm_size = dist.get_world_size(group)
179
+ if comm_size == 1:
180
+ return inputs.squeeze(0), None
181
+
182
+ if not inputs.is_contiguous():
183
+ inputs = inputs.contiguous()
184
+
185
+ output_shape = inputs.shape[1:]
186
+ outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
187
+ buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
188
+ if not overlap:
189
+ dist.reduce_scatter(outputs, buffer_list, group=group)
190
+ return outputs, None
191
+ else:
192
+ handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True)
193
+ return outputs, handle
194
+
195
+ @staticmethod
196
+ def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
197
+ # TODO: support async backward
198
+ return (
199
+ AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
200
+ None,
201
+ None,
202
+ )
203
+
204
+
205
+ # ======================================================
206
+ # AlltoAll
207
+ # ======================================================
208
+
209
+
210
+ def _all_to_all_func(input_, world_size, group, scatter_dim, gather_dim):
211
+ input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
212
+ output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
213
+ dist.all_to_all(output_list, input_list, group=group)
214
+ return torch.cat(output_list, dim=gather_dim).contiguous()
215
+
216
+
217
+ class _AllToAll(torch.autograd.Function):
218
+ """All-to-all communication.
219
+
220
+ Args:
221
+ input_: input matrix
222
+ process_group: communication group
223
+ scatter_dim: scatter dimension
224
+ gather_dim: gather dimension
225
+ """
226
+
227
+ @staticmethod
228
+ def forward(ctx, input_, process_group, scatter_dim, gather_dim):
229
+ ctx.process_group = process_group
230
+ ctx.scatter_dim = scatter_dim
231
+ ctx.gather_dim = gather_dim
232
+ world_size = dist.get_world_size(process_group)
233
+
234
+ return _all_to_all_func(input_, world_size, process_group, scatter_dim, gather_dim)
235
+
236
+ @staticmethod
237
+ def backward(ctx, *grad_output):
238
+ process_group = ctx.process_group
239
+ scatter_dim = ctx.gather_dim
240
+ gather_dim = ctx.scatter_dim
241
+ return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
242
+ return (return_grad, None, None, None)
243
+
244
+
245
+ def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
246
+ return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
247
+
248
+
249
+ # ======================================================
250
+ # Sequence Gather & Split
251
+ # ======================================================
252
+
253
+
254
+ def _split_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
255
+ # skip if only one rank involved
256
+ world_size = dist.get_world_size(pg)
257
+ rank = dist.get_rank(pg)
258
+ if world_size == 1:
259
+ return input_
260
+
261
+ if pad > 0:
262
+ pad_size = list(input_.shape)
263
+ pad_size[dim] = pad
264
+ input_ = torch.cat([input_, torch.zeros(pad_size, dtype=input_.dtype, device=input_.device)], dim=dim)
265
+
266
+ dim_size = input_.size(dim)
267
+ assert dim_size % world_size == 0, f"dim_size ({dim_size}) is not divisible by world_size ({world_size})"
268
+
269
+ tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
270
+ output = tensor_list[rank].contiguous()
271
+ return output
272
+
273
+
274
+ def _gather_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
275
+ # skip if only one rank involved
276
+ input_ = input_.contiguous()
277
+ world_size = dist.get_world_size(pg)
278
+ dist.get_rank(pg)
279
+
280
+ if world_size == 1:
281
+ return input_
282
+
283
+ # all gather
284
+ tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
285
+ assert input_.device.type == "cuda"
286
+ torch.distributed.all_gather(tensor_list, input_, group=pg)
287
+
288
+ # concat
289
+ output = torch.cat(tensor_list, dim=dim)
290
+
291
+ if pad > 0:
292
+ output = output.narrow(dim, 0, output.size(dim) - pad)
293
+
294
+ return output
295
+
296
+
297
+ class _GatherForwardSplitBackward(torch.autograd.Function):
298
+ """
299
+ Gather the input sequence.
300
+
301
+ Args:
302
+ input_: input matrix.
303
+ process_group: process group.
304
+ dim: dimension
305
+ """
306
+
307
+ @staticmethod
308
+ def symbolic(graph, input_):
309
+ return _gather_sequence_func(input_)
310
+
311
+ @staticmethod
312
+ def forward(ctx, input_, process_group, dim, grad_scale, pad):
313
+ ctx.process_group = process_group
314
+ ctx.dim = dim
315
+ ctx.grad_scale = grad_scale
316
+ ctx.pad = pad
317
+ return _gather_sequence_func(input_, process_group, dim, pad)
318
+
319
+ @staticmethod
320
+ def backward(ctx, grad_output):
321
+ if ctx.grad_scale == "up":
322
+ grad_output = grad_output * dist.get_world_size(ctx.process_group)
323
+ elif ctx.grad_scale == "down":
324
+ grad_output = grad_output / dist.get_world_size(ctx.process_group)
325
+
326
+ return _split_sequence_func(grad_output, ctx.process_group, ctx.dim, ctx.pad), None, None, None, None
327
+
328
+
329
+ class _SplitForwardGatherBackward(torch.autograd.Function):
330
+ """
331
+ Split sequence.
332
+
333
+ Args:
334
+ input_: input matrix.
335
+ process_group: parallel mode.
336
+ dim: dimension
337
+ """
338
+
339
+ @staticmethod
340
+ def symbolic(graph, input_):
341
+ return _split_sequence_func(input_)
342
+
343
+ @staticmethod
344
+ def forward(ctx, input_, process_group, dim, grad_scale, pad):
345
+ ctx.process_group = process_group
346
+ ctx.dim = dim
347
+ ctx.grad_scale = grad_scale
348
+ ctx.pad = pad
349
+ return _split_sequence_func(input_, process_group, dim, pad)
350
+
351
+ @staticmethod
352
+ def backward(ctx, grad_output):
353
+ if ctx.grad_scale == "up":
354
+ grad_output = grad_output * dist.get_world_size(ctx.process_group)
355
+ elif ctx.grad_scale == "down":
356
+ grad_output = grad_output / dist.get_world_size(ctx.process_group)
357
+ return _gather_sequence_func(grad_output, ctx.process_group, ctx.dim, ctx.pad), None, None, None, None
358
+
359
+
360
+ def split_sequence(input_, process_group, dim, grad_scale=1.0, pad=0):
361
+ return _SplitForwardGatherBackward.apply(input_, process_group, dim, grad_scale, pad)
362
+
363
+
364
+ def gather_sequence(input_, process_group, dim, grad_scale=1.0, pad=0):
365
+ return _GatherForwardSplitBackward.apply(input_, process_group, dim, grad_scale, pad)
366
+
367
+
368
+ # ==============================
369
+ # Pad
370
+ # ==============================
371
+
372
+ SPTIAL_PAD = 0
373
+ TEMPORAL_PAD = 0
374
+
375
+
376
+ def set_spatial_pad(dim_size: int):
377
+ sp_size = get_sequence_parallel_size()
378
+ pad = (sp_size - (dim_size % sp_size)) % sp_size
379
+ global SPTIAL_PAD
380
+ SPTIAL_PAD = pad
381
+
382
+
383
+ def get_spatial_pad() -> int:
384
+ return SPTIAL_PAD
385
+
386
+
387
+ def set_temporal_pad(dim_size: int):
388
+ sp_size = get_sequence_parallel_size()
389
+ pad = (sp_size - (dim_size % sp_size)) % sp_size
390
+ global TEMPORAL_PAD
391
+ TEMPORAL_PAD = pad
392
+
393
+
394
+ def get_temporal_pad() -> int:
395
+ return TEMPORAL_PAD
396
+
397
+
398
+ def all_to_all_with_pad(
399
+ input_: torch.Tensor,
400
+ process_group: dist.ProcessGroup,
401
+ scatter_dim: int = 2,
402
+ gather_dim: int = 1,
403
+ scatter_pad: int = 0,
404
+ gather_pad: int = 0,
405
+ ):
406
+ if scatter_pad > 0:
407
+ pad_shape = list(input_.shape)
408
+ pad_shape[scatter_dim] = scatter_pad
409
+ pad_tensor = torch.zeros(pad_shape, device=input_.device, dtype=input_.dtype)
410
+ input_ = torch.cat([input_, pad_tensor], dim=scatter_dim)
411
+
412
+ assert (
413
+ input_.shape[scatter_dim] % dist.get_world_size(process_group) == 0
414
+ ), f"Dimension to scatter ({input_.shape[scatter_dim]}) is not divisible by world size ({dist.get_world_size(process_group)})"
415
+ input_ = _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
416
+
417
+ if gather_pad > 0:
418
+ input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad)
419
+
420
+ return input_
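The padding helpers in `comm.py` compute `(sp_size - dim_size % sp_size) % sp_size`, append that many zeros before splitting across ranks, and `narrow` the padding away again after gathering. The sketch below walks through the same round trip with plain Python lists standing in for tensors and list slices standing in for ranks. The function names are hypothetical, and no `torch.distributed` call is made.

```python
def compute_pad(dim_size: int, sp_size: int) -> int:
    """Number of elements to append so dim_size becomes divisible by sp_size."""
    return (sp_size - dim_size % sp_size) % sp_size


def split_with_pad(seq: list, sp_size: int) -> list:
    """Pad with zeros, then cut into sp_size equal chunks (one per rank)."""
    pad = compute_pad(len(seq), sp_size)
    padded = seq + [0] * pad
    chunk = len(padded) // sp_size
    return [padded[i * chunk:(i + 1) * chunk] for i in range(sp_size)]


def gather_and_unpad(chunks: list, pad: int) -> list:
    """Concatenate the per-rank chunks and drop the trailing padding."""
    full = [x for chunk in chunks for x in chunk]
    return full[: len(full) - pad]


seq = list(range(1, 11))          # length 10
chunks = split_with_pad(seq, 4)   # pad is 2, so each rank holds 3 elements
restored = gather_and_unpad(chunks, compute_pad(len(seq), 4))
```

The outer `% sp_size` matters when the length is already divisible: without it, a length of 8 with 4 ranks would be padded by 4 instead of 0.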
videosys/core/engine.py ADDED
@@ -0,0 +1,132 @@
1
+ import os
2
+ from functools import partial
3
+ from typing import Any, Optional
4
+
5
+ import imageio
6
+ import torch
7
+
8
+ import videosys
9
+
10
+ from .mp_utils import ProcessWorkerWrapper, ResultHandler, WorkerMonitor, get_distributed_init_method, get_open_port
11
+
12
+
13
+ class VideoSysEngine:
14
+ """
15
+ this is partly inspired by vllm
16
+ """
17
+
18
+ def __init__(self, config):
19
+ self.config = config
20
+ self.parallel_worker_tasks = None
21
+ self._init_worker(config.pipeline_cls)
22
+
23
+ def _init_worker(self, pipeline_cls):
24
+ world_size = self.config.world_size
25
+
26
+ if "CUDA_VISIBLE_DEVICES" not in os.environ:
27
+ os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(world_size))
28
+
29
+ # Disable torch async compiling which won't work with daemonic processes
30
+ os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
31
+
32
+ # Set OMP_NUM_THREADS to 1 if it is not set explicitly, avoids CPU
33
+ # contention amongst the shards
34
+ if "OMP_NUM_THREADS" not in os.environ:
35
+ os.environ["OMP_NUM_THREADS"] = "1"
36
+
37
+ # NOTE: The two following lines need adaptation for multi-node
38
+ assert world_size <= torch.cuda.device_count()
39
+
40
+ # change addr for multi-node
41
+ distributed_init_method = get_distributed_init_method("127.0.0.1", get_open_port())
42
+
43
+ if world_size == 1:
44
+ self.workers = []
45
+ self.worker_monitor = None
46
+ else:
47
+ result_handler = ResultHandler()
48
+ self.workers = [
49
+ ProcessWorkerWrapper(
50
+ result_handler,
51
+ partial(
52
+ self._create_pipeline,
53
+ pipeline_cls=pipeline_cls,
54
+ rank=rank,
55
+ local_rank=rank,
56
+ distributed_init_method=distributed_init_method,
57
+ ),
58
+ )
59
+ for rank in range(1, world_size)
60
+ ]
61
+
62
+ self.worker_monitor = WorkerMonitor(self.workers, result_handler)
63
+ result_handler.start()
64
+ self.worker_monitor.start()
65
+
66
+ self.driver_worker = self._create_pipeline(
67
+ pipeline_cls=pipeline_cls, distributed_init_method=distributed_init_method
68
+ )
69
+
70
+ # TODO: add more options here for pipeline, or wrap all options into config
71
+ def _create_pipeline(self, pipeline_cls, rank=0, local_rank=0, distributed_init_method=None):
72
+ videosys.initialize(rank=rank, world_size=self.config.world_size, init_method=distributed_init_method, seed=42)
73
+
74
+ pipeline = pipeline_cls(self.config)
75
+ return pipeline
76
+
77
+ def _run_workers(
78
+ self,
79
+ method: str,
80
+ *args,
81
+ async_run_tensor_parallel_workers_only: bool = False,
82
+ max_concurrent_workers: Optional[int] = None,
83
+ **kwargs,
84
+ ) -> Any:
85
+ """Runs the given method on all workers."""
86
+
87
+ # Start the workers first.
88
+ worker_outputs = [worker.execute_method(method, *args, **kwargs) for worker in self.workers]
89
+
90
+ if async_run_tensor_parallel_workers_only:
91
+ # Just return futures
92
+ return worker_outputs
93
+
94
+ driver_worker_method = getattr(self.driver_worker, method)
95
+ driver_worker_output = driver_worker_method(*args, **kwargs)
96
+
97
+ # Get the results of the workers.
98
+ return [driver_worker_output] + [output.get() for output in worker_outputs]
99
+
100
+ def _driver_execute_model(self, *args, **kwargs):
101
+ return self.driver_worker.generate(*args, **kwargs)
102
+
103
+ def generate(self, *args, **kwargs):
104
+ return self._run_workers("generate", *args, **kwargs)[0]
105
+
106
+ def stop_remote_worker_execution_loop(self) -> None:
107
+ if self.parallel_worker_tasks is None:
108
+ return
109
+
110
+ parallel_worker_tasks = self.parallel_worker_tasks
111
+ self.parallel_worker_tasks = None
112
+ # Ensure that workers exit model loop cleanly
113
+ # (this will raise otherwise)
114
+ self._wait_for_tasks_completion(parallel_worker_tasks)
115
+
116
+ def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
117
+ """Wait for futures returned from _run_workers() with
118
+ async_run_tensor_parallel_workers_only to complete."""
119
+ for result in parallel_worker_tasks:
120
+ result.get()
121
+
122
+ def save_video(self, video, output_path):
123
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
124
+ imageio.mimwrite(output_path, video, fps=24)
125
+
126
+ def shutdown(self):
127
+ if (worker_monitor := getattr(self, "worker_monitor", None)) is not None:
128
+ worker_monitor.close()
129
+ torch.distributed.destroy_process_group()
130
+
131
+ def __del__(self):
132
+ self.shutdown()
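`VideoSysEngine._run_workers` submits the call to every remote worker first, then runs the same method on the driver in the calling process, and only afterwards blocks on the remote results, with the driver output placed first in the returned list. The sketch below imitates that ordering with threads in place of spawned processes. `ToyWorker` and `run_workers` are hypothetical stand-ins and are not part of videosys.

```python
from concurrent.futures import ThreadPoolExecutor


class ToyWorker:
    """Hypothetical stand-in for a pipeline replica; not part of videosys."""

    def __init__(self, rank: int):
        self.rank = rank

    def generate(self, prompt: str) -> str:
        return f"rank{self.rank}:{prompt}"


def run_workers(driver, workers, pool, method: str, *args, **kwargs):
    """Dispatch to the remote workers first, run the driver inline, then collect."""
    futures = [pool.submit(getattr(w, method), *args, **kwargs) for w in workers]
    driver_output = getattr(driver, method)(*args, **kwargs)
    return [driver_output] + [f.result() for f in futures]


driver = ToyWorker(0)
workers = [ToyWorker(rank) for rank in range(1, 3)]
with ThreadPoolExecutor(max_workers=2) as pool:
    outputs = run_workers(driver, workers, pool, "generate", "Sunset over the sea.")
```

Starting the remote calls before the driver call lets all ranks work concurrently; `generate` in the engine then returns only element 0, the driver's output.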
videosys/core/mp_utils.py ADDED
@@ -0,0 +1,270 @@
1
+ # adapted from vllm
2
+ # https://github.com/vllm-project/vllm/blob/main/vllm/executor/multiproc_worker_utils.py
3
+
4
+ import asyncio
5
+ import multiprocessing
6
+ import os
7
+ import socket
8
+ import sys
9
+ import threading
10
+ import traceback
11
+ import uuid
12
+ from dataclasses import dataclass
13
+ from multiprocessing import Queue
14
+ from multiprocessing.connection import wait
15
+ from typing import Any, Callable, Dict, Generic, List, Optional, TextIO, TypeVar, Union
16
+
17
+ from videosys.utils.logging import create_logger
18
+
19
+ T = TypeVar("T")
20
+ _TERMINATE = "TERMINATE" # sentinel
21
+ # ANSI color codes
22
+ CYAN = "\033[1;36m"
23
+ RESET = "\033[0;0m"
24
+ JOIN_TIMEOUT_S = 2
25
+
26
+ mp_method = "spawn" # fork cannot work
27
+ mp = multiprocessing.get_context(mp_method)
28
+
29
+ logger = create_logger()
30
+
31
+
32
+ def get_distributed_init_method(ip: str, port: int) -> str:
33
+ # Brackets are not permitted in ipv4 addresses,
34
+ # see https://github.com/python/cpython/issues/103848
35
+ return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
36
+
37
+
38
+ def get_open_port() -> int:
39
+ # try ipv4
40
+ try:
41
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
42
+ s.bind(("", 0))
43
+ return s.getsockname()[1]
44
+ except OSError:
45
+ # try ipv6
46
+ with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
47
+ s.bind(("", 0))
48
+ return s.getsockname()[1]
49
+
50
+
51
+ @dataclass
52
+ class Result(Generic[T]):
53
+ """Result of task dispatched to worker"""
54
+
55
+ task_id: uuid.UUID
56
+ value: Optional[T] = None
57
+ exception: Optional[BaseException] = None
58
+
59
+
60
+ class ResultFuture(threading.Event, Generic[T]):
61
+ """Synchronous future for non-async case"""
62
+
63
+ def __init__(self):
64
+ super().__init__()
65
+ self.result: Optional[Result[T]] = None
66
+
67
+ def set_result(self, result: Result[T]):
68
+ self.result = result
69
+ self.set()
70
+
71
+ def get(self) -> T:
72
+ self.wait()
73
+ assert self.result is not None
74
+ if self.result.exception is not None:
75
+ raise self.result.exception
76
+ return self.result.value # type: ignore[return-value]
77
+
78
+
79
+ def _set_future_result(future: Union[ResultFuture, asyncio.Future], result: Result):
80
+ if isinstance(future, ResultFuture):
81
+ future.set_result(result)
82
+ return
83
+ loop = future.get_loop()
84
+ if not loop.is_closed():
85
+ if result.exception is not None:
86
+ loop.call_soon_threadsafe(future.set_exception, result.exception)
87
+ else:
88
+ loop.call_soon_threadsafe(future.set_result, result.value)
89
+
90
+
91
+ class ResultHandler(threading.Thread):
92
+ """Handle results from all workers (in background thread)"""
93
+
94
+ def __init__(self) -> None:
95
+ super().__init__(daemon=True)
96
+ self.result_queue = mp.Queue()
97
+ self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
98
+
99
+ def run(self):
100
+ for result in iter(self.result_queue.get, _TERMINATE):
101
+ future = self.tasks.pop(result.task_id)
102
+ _set_future_result(future, result)
103
+ # Ensure that all waiters will receive an exception
104
+ for task_id, future in self.tasks.items():
105
+ _set_future_result(future, Result(task_id=task_id, exception=ChildProcessError("worker died")))
106
+
107
+ def close(self):
108
+ self.result_queue.put(_TERMINATE)
109
+
110
+
111
+ class WorkerMonitor(threading.Thread):
112
+ """Monitor worker status (in background thread)"""
113
+
114
+ def __init__(self, workers: List["ProcessWorkerWrapper"], result_handler: ResultHandler):
115
+ super().__init__(daemon=True)
116
+ self.workers = workers
117
+ self.result_handler = result_handler
118
+ self._close = False
119
+
120
+ def run(self) -> None:
121
+ # Blocks until any worker exits
122
+ dead_sentinels = wait([w.process.sentinel for w in self.workers])
123
+ if not self._close:
124
+ self._close = True
125
+
126
+ # Kill / cleanup all workers
127
+ for worker in self.workers:
128
+ process = worker.process
129
+ if process.sentinel in dead_sentinels:
130
+ process.join(JOIN_TIMEOUT_S)
131
+ if process.exitcode is not None and process.exitcode != 0:
132
+ logger.error("Worker %s pid %s died, exit code: %s", process.name, process.pid, process.exitcode)
133
+ # Cleanup any remaining workers
134
+ logger.info("Killing local worker processes")
135
+ for worker in self.workers:
136
+ worker.kill_worker()
137
+ # Must be done after worker task queues are all closed
138
+ self.result_handler.close()
139
+
140
+ for worker in self.workers:
141
+ worker.process.join(JOIN_TIMEOUT_S)
142
+
143
+ def close(self):
144
+ if self._close:
145
+ return
146
+ self._close = True
147
+ logger.info("Terminating local worker processes")
148
+ for worker in self.workers:
149
+ worker.terminate_worker()
150
+ # Must be done after worker task queues are all closed
151
+ self.result_handler.close()
152
+
153
+
154
+ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
155
+ """Prepend each output line with process-specific prefix"""
156
+
157
+ prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
158
+ file_write = file.write
159
+
160
+ def write_with_prefix(s: str):
161
+ if not s:
162
+ return
163
+ if file.start_new_line: # type: ignore[attr-defined]
164
+ file_write(prefix)
165
+ idx = 0
166
+ while (next_idx := s.find("\n", idx)) != -1:
167
+ next_idx += 1
168
+ file_write(s[idx:next_idx])
169
+ if next_idx == len(s):
170
+ file.start_new_line = True # type: ignore[attr-defined]
171
+ return
172
+ file_write(prefix)
173
+ idx = next_idx
174
+ file_write(s[idx:])
175
+ file.start_new_line = False # type: ignore[attr-defined]
176
+
177
+ file.start_new_line = True # type: ignore[attr-defined]
178
+ file.write = write_with_prefix # type: ignore[method-assign]
179
+
180
+
181
+ def _run_worker_process(
182
+ worker_factory: Callable[[], Any],
183
+ task_queue: Queue,
184
+ result_queue: Queue,
185
+ ) -> None:
186
+ """Worker process event loop"""
187
+
188
+ # Add process-specific prefix to stdout and stderr
189
+ process_name = mp.current_process().name
190
+ pid = os.getpid()
191
+ _add_prefix(sys.stdout, process_name, pid)
192
+ _add_prefix(sys.stderr, process_name, pid)
193
+
194
+ # Initialize worker
195
+ worker = worker_factory()
196
+ del worker_factory
197
+
198
+ # Accept tasks from the engine in task_queue
199
+ # and return task output in result_queue
200
+ logger.info("Worker ready; awaiting tasks")
201
+ try:
202
+ for items in iter(task_queue.get, _TERMINATE):
203
+ output = None
204
+ exception = None
205
+ task_id, method, args, kwargs = items
206
+ try:
207
+ executor = getattr(worker, method)
208
+ output = executor(*args, **kwargs)
209
+ except BaseException as e:
210
+ tb = traceback.format_exc()
211
+ logger.error("Exception in worker %s while processing method %s: %s, %s", process_name, method, e, tb)
212
+ exception = e
213
+ result_queue.put(Result(task_id=task_id, value=output, exception=exception))
214
+ except KeyboardInterrupt:
215
+ pass
216
+ except Exception:
217
+ logger.exception("Worker failed")
218
+
219
+ logger.info("Worker exiting")
220
+
221
+
222
+ class ProcessWorkerWrapper:
223
+ """Local process wrapper for handling single-node multi-GPU."""
224
+
225
+ def __init__(self, result_handler: ResultHandler, worker_factory: Callable[[], Any]) -> None:
226
+ self._task_queue = mp.Queue()
227
+ self.result_queue = result_handler.result_queue
228
+ self.tasks = result_handler.tasks
229
+ self.process = mp.Process( # type: ignore[attr-defined]
230
+ target=_run_worker_process,
231
+ name="VideoSysWorkerProcess",
232
+ kwargs=dict(
233
+ worker_factory=worker_factory,
234
+ task_queue=self._task_queue,
235
+ result_queue=self.result_queue,
236
+ ),
237
+ daemon=True,
238
+ )
239
+
240
+ self.process.start()
241
+
242
+ def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], method: str, args, kwargs):
243
+ task_id = uuid.uuid4()
244
+ self.tasks[task_id] = future
245
+ try:
246
+ self._task_queue.put((task_id, method, args, kwargs))
247
+ except BaseException as e:
248
+ del self.tasks[task_id]
249
+ raise ChildProcessError("worker died") from e
250
+
251
+ def execute_method(self, method: str, *args, **kwargs):
252
+ future: ResultFuture = ResultFuture()
253
+ self._enqueue_task(future, method, args, kwargs)
254
+ return future
255
+
256
+ async def execute_method_async(self, method: str, *args, **kwargs):
257
+ future = asyncio.get_running_loop().create_future()
258
+ self._enqueue_task(future, method, args, kwargs)
259
+ return await future
260
+
261
+ def terminate_worker(self):
262
+ try:
263
+ self._task_queue.put(_TERMINATE)
264
+ except ValueError:
265
+ self.process.kill()
266
+ self._task_queue.close()
267
+
268
+ def kill_worker(self):
269
+ self._task_queue.close()
270
+ self.process.kill()
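The task/result flow in mp_utils.py (a task queue per worker, one shared result queue, and a `threading.Event`-based future keyed by task id) can be tried without GPUs or spawned processes. The following is a minimal sketch under stated assumptions, not the VideoSys implementation: it swaps the spawned processes for threads and `multiprocessing.Queue` for `queue.Queue` so it runs anywhere, and `EchoWorker`, `worker_loop`, `result_loop`, and `run_demo` are hypothetical names introduced only for illustration.

```python
import queue
import threading
import uuid
from dataclasses import dataclass
from typing import Any, Dict, Optional

_TERMINATE = "TERMINATE"  # sentinel, same idea as in mp_utils.py


@dataclass
class Result:
    task_id: uuid.UUID
    value: Any = None
    exception: Optional[BaseException] = None


class ResultFuture(threading.Event):
    """Event that also carries a Result, mirroring the class in the diff."""

    def __init__(self) -> None:
        super().__init__()
        self.result: Optional[Result] = None

    def set_result(self, result: Result) -> None:
        self.result = result
        self.set()

    def get(self) -> Any:
        self.wait()
        assert self.result is not None
        if self.result.exception is not None:
            raise self.result.exception
        return self.result.value


class EchoWorker:
    """Hypothetical worker, used only for this illustration."""

    def double(self, x: int) -> int:
        return 2 * x


def worker_loop(worker: Any, task_queue: queue.Queue, result_queue: queue.Queue) -> None:
    # Same shape as _run_worker_process: consume tasks until the sentinel arrives.
    for task_id, method, args, kwargs in iter(task_queue.get, _TERMINATE):
        try:
            value, exc = getattr(worker, method)(*args, **kwargs), None
        except BaseException as e:  # forwarded to the caller, not swallowed
            value, exc = None, e
        result_queue.put(Result(task_id, value, exc))


def result_loop(result_queue: queue.Queue, tasks: Dict[uuid.UUID, ResultFuture]) -> None:
    # Same shape as ResultHandler.run: match each result to its pending future.
    for result in iter(result_queue.get, _TERMINATE):
        tasks.pop(result.task_id).set_result(result)


def run_demo():
    task_queue: queue.Queue = queue.Queue()
    result_queue: queue.Queue = queue.Queue()
    tasks: Dict[uuid.UUID, ResultFuture] = {}
    threads = [
        threading.Thread(target=worker_loop, args=(EchoWorker(), task_queue, result_queue), daemon=True),
        threading.Thread(target=result_loop, args=(result_queue, tasks), daemon=True),
    ]
    for t in threads:
        t.start()

    def submit(method: str, *args, **kwargs) -> ResultFuture:
        future, task_id = ResultFuture(), uuid.uuid4()
        tasks[task_id] = future  # register before enqueueing, as _enqueue_task does
        task_queue.put((task_id, method, args, kwargs))
        return future

    ok = submit("double", 21).get()
    try:
        submit("missing_method").get()
        error = None
    except AttributeError as e:  # the worker-side exception is re-raised here
        error = e

    task_queue.put(_TERMINATE)
    result_queue.put(_TERMINATE)
    for t in threads:
        t.join(timeout=2)
    return ok, error
```

Calling `run_demo()` submits one valid and one invalid method name; the second shows that an exception raised on the worker side surfaces from `future.get()` on the caller side, which is the behavior `ResultFuture.get` implements in the diff.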
videosys/core/pab_mgr.py ADDED
@@ -0,0 +1,364 @@
1
+ import random
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from videosys.utils.logging import logger
7
+
8
+ PAB_MANAGER = None
9
+
10
+
11
+ class PABConfig:
12
+ def __init__(
13
+ self,
14
+ steps: int,
15
+ cross_broadcast: bool,
16
+ cross_threshold: list,
17
+ cross_gap: int,
18
+ spatial_broadcast: bool,
19
+ spatial_threshold: list,
20
+ spatial_gap: int,
21
+ temporal_broadcast: bool,
22
+ temporal_threshold: list,
23
+ temporal_gap: int,
24
+ diffusion_skip: bool,
25
+ diffusion_timestep_respacing: list,
26
+ diffusion_skip_timestep: list,
27
+ mlp_skip: bool,
28
+ mlp_spatial_skip_config: dict,
29
+ mlp_temporal_skip_config: dict,
30
+ full_broadcast: bool = False,
31
+ full_threshold: list = None,
32
+ full_gap: int = 1,
33
+ ):
34
+ self.steps = steps
35
+
36
+ self.cross_broadcast = cross_broadcast
37
+ self.cross_threshold = cross_threshold
38
+ self.cross_gap = cross_gap
39
+
40
+ self.spatial_broadcast = spatial_broadcast
41
+ self.spatial_threshold = spatial_threshold
42
+ self.spatial_gap = spatial_gap
43
+
44
+ self.temporal_broadcast = temporal_broadcast
45
+ self.temporal_threshold = temporal_threshold
46
+ self.temporal_gap = temporal_gap
47
+
48
+ self.diffusion_skip = diffusion_skip
49
+ self.diffusion_timestep_respacing = diffusion_timestep_respacing
50
+ self.diffusion_skip_timestep = diffusion_skip_timestep
51
+
52
+ self.mlp_skip = mlp_skip
53
+ self.mlp_spatial_skip_config = mlp_spatial_skip_config
54
+ self.mlp_temporal_skip_config = mlp_temporal_skip_config
55
+
56
+ self.temporal_mlp_outputs = {}
57
+ self.spatial_mlp_outputs = {}
58
+
59
+ self.full_broadcast = full_broadcast
60
+ self.full_threshold = full_threshold
61
+ self.full_gap = full_gap
62
+
63
+
64
+ class PABManager:
65
+ def __init__(self, config: PABConfig):
66
+ self.config: PABConfig = config
67
+
68
+ init_prompt = f"Init PABManager. steps: {config.steps}."
69
+ init_prompt += f" spatial_broadcast: {config.spatial_broadcast}, spatial_threshold: {config.spatial_threshold}, spatial_gap: {config.spatial_gap}."
70
+ init_prompt += f" temporal_broadcast: {config.temporal_broadcast}, temporal_threshold: {config.temporal_threshold}, temporal_gap: {config.temporal_gap}."
71
+ init_prompt += f" cross_broadcast: {config.cross_broadcast}, cross_threshold: {config.cross_threshold}, cross_gap: {config.cross_gap}."
72
+ init_prompt += f" full_broadcast: {config.full_broadcast}, full_threshold: {config.full_threshold}, full_gap: {config.full_gap}."
73
+ logger.info(init_prompt)
74
+
75
+ def if_broadcast_cross(self, timestep: int, count: int):
76
+ if (
77
+ self.config.cross_broadcast
78
+ and (timestep is not None)
79
+ and (count % self.config.cross_gap != 0)
80
+ and (self.config.cross_threshold[0] < timestep < self.config.cross_threshold[1])
81
+ ):
82
+ flag = True
83
+ else:
84
+ flag = False
85
+ count = (count + 1) % self.config.steps
86
+ return flag, count
87
+
88
+ def if_broadcast_temporal(self, timestep: int, count: int):
89
+ if (
90
+ self.config.temporal_broadcast
91
+ and (timestep is not None)
92
+ and (count % self.config.temporal_gap != 0)
93
+ and (self.config.temporal_threshold[0] < timestep < self.config.temporal_threshold[1])
94
+ ):
95
+ flag = True
96
+ else:
97
+ flag = False
98
+ count = (count + 1) % self.config.steps
99
+ return flag, count
100
+
101
+ def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int):
102
+ if (
103
+ self.config.spatial_broadcast
104
+ and (timestep is not None)
105
+ and (count % self.config.spatial_gap != 0)
106
+ and (self.config.spatial_threshold[0] < timestep < self.config.spatial_threshold[1])
107
+ ):
108
+ flag = True
109
+ else:
110
+ flag = False
111
+ count = (count + 1) % self.config.steps
112
+ return flag, count
113
+
114
+ def if_broadcast_full(self, timestep: int, count: int, block_idx: int):
115
+ if (
116
+ self.config.full_broadcast
117
+ and (timestep is not None)
118
+ and (count % self.config.full_gap != 0)
119
+ and (self.config.full_threshold[0] < timestep < self.config.full_threshold[1])
120
+ ):
121
+ flag = True
122
+ else:
123
+ flag = False
124
+ count = (count + 1) % self.config.steps
125
+ return flag, count
126
+
127
+ @staticmethod
128
+ def _is_t_in_skip_config(all_timesteps, timestep, config):
129
+ is_t_in_skip_config, skip_range = False, None  # skip_range stays None if no config key is in all_timesteps
130
+ for key in config:
131
+ if key not in all_timesteps:
132
+ continue
133
+ index = all_timesteps.index(key)
134
+ skip_range = all_timesteps[index : index + 1 + int(config[key]["skip_count"])]
135
+ if timestep in skip_range:
136
+ is_t_in_skip_config = True
137
+ skip_range = [all_timesteps[index], all_timesteps[index + int(config[key]["skip_count"])]]
138
+ break
139
+ return is_t_in_skip_config, skip_range
140
+
141
+ def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
142
+ if not self.config.mlp_skip:
143
+ return False, None, False, None
144
+
145
+ if is_temporal:
146
+ cur_config = self.config.mlp_temporal_skip_config
147
+ else:
148
+ cur_config = self.config.mlp_spatial_skip_config
149
+
150
+ is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config)
151
+ next_flag = False
152
+ if (
153
+ self.config.mlp_skip
154
+ and (timestep is not None)
155
+ and (timestep in cur_config)
156
+ and (block_idx in cur_config[timestep]["block"])
157
+ ):
158
+ flag = False
159
+ next_flag = True
160
+ count = count + 1
161
+ elif (
162
+ self.config.mlp_skip
163
+ and (timestep is not None)
164
+ and (is_t_in_skip_config)
165
+ and (block_idx in cur_config[skip_range[0]]["block"])
166
+ ):
167
+ flag = True
168
+ count = 0
169
+ else:
170
+ flag = False
171
+
172
+ return flag, count, next_flag, skip_range
173
+
174
+ def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False):
175
+ if is_temporal:
176
+ self.config.temporal_mlp_outputs[(timestep, block_idx)] = ff_output
177
+ else:
178
+ self.config.spatial_mlp_outputs[(timestep, block_idx)] = ff_output
179
+
180
+ def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False):
181
+ skip_start_t = skip_range[0]
182
+ if is_temporal:
183
+ skip_output = (
184
+ self.config.temporal_mlp_outputs.get((skip_start_t, block_idx), None)
185
+ if self.config.temporal_mlp_outputs is not None
186
+ else None
187
+ )
188
+ else:
189
+ skip_output = (
190
+ self.config.spatial_mlp_outputs.get((skip_start_t, block_idx), None)
191
+ if self.config.spatial_mlp_outputs is not None
192
+ else None
193
+ )
194
+
195
+ if skip_output is not None:
196
+ if timestep == skip_range[-1]:
197
+ # TODO: save memory
198
+ if is_temporal:
199
+ del self.config.temporal_mlp_outputs[(skip_start_t, block_idx)]
200
+ else:
201
+ del self.config.spatial_mlp_outputs[(skip_start_t, block_idx)]
202
+ else:
203
+ raise ValueError(
204
+ f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}"
205
+ )
206
+
207
+ return skip_output
208
+
209
+ def get_spatial_mlp_outputs(self):
210
+ return self.config.spatial_mlp_outputs
211
+
212
+ def get_temporal_mlp_outputs(self):
213
+ return self.config.temporal_mlp_outputs
214
+
215
+
216
+ def set_pab_manager(config: PABConfig):
217
+ global PAB_MANAGER
218
+ PAB_MANAGER = PABManager(config)
219
+
220
+
221
+ def enable_pab():
222
+ if PAB_MANAGER is None:
223
+ return False
224
+ return (
225
+ PAB_MANAGER.config.cross_broadcast
226
+ or PAB_MANAGER.config.spatial_broadcast
227
+ or PAB_MANAGER.config.temporal_broadcast
228
+ )
229
+
230
+
231
+ def update_steps(steps: int):
232
+ if PAB_MANAGER is not None:
233
+ PAB_MANAGER.config.steps = steps
234
+
235
+
236
+ def if_broadcast_cross(timestep: int, count: int):
237
+ if not enable_pab():
238
+ return False, count
239
+ return PAB_MANAGER.if_broadcast_cross(timestep, count)
240
+
241
+
242
+ def if_broadcast_temporal(timestep: int, count: int):
243
+ if not enable_pab():
244
+ return False, count
245
+ return PAB_MANAGER.if_broadcast_temporal(timestep, count)
246
+
247
+
248
+ def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
249
+ if not enable_pab():
250
+ return False, count
251
+ return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)
252
+
253
+ def if_broadcast_full(timestep: int, count: int, block_idx: int):
254
+ if not enable_pab():
255
+ return False, count
256
+ return PAB_MANAGER.if_broadcast_full(timestep, count, block_idx)
257
+
258
+
259
+ def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
260
+ if not enable_pab():
261
+ return False, count, False, None  # same 4-tuple shape as PABManager.if_skip_mlp
262
+ return PAB_MANAGER.if_skip_mlp(timestep, count, block_idx, all_timesteps, is_temporal)
263
+
264
+
265
+ def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False):
266
+ return PAB_MANAGER.save_skip_output(timestep, block_idx, ff_output, is_temporal)
267
+
268
+
269
+ def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False):
270
+ return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal)
271
+
272
+
273
+ def get_diffusion_skip():
274
+ return enable_pab() and PAB_MANAGER.config.diffusion_skip
275
+
276
+
277
+ def get_diffusion_timestep_respacing():
278
+ return PAB_MANAGER.config.diffusion_timestep_respacing
279
+
280
+
281
+ def get_diffusion_skip_timestep():
282
+ return enable_pab() and PAB_MANAGER.config.diffusion_skip_timestep
283
+
284
+
285
+ def space_timesteps(time_steps, time_bins):
286
+ num_bins = len(time_bins)
287
+ bin_size = time_steps // num_bins
288
+
289
+ result = []
290
+
291
+ for i, bin_count in enumerate(time_bins):
292
+ start = i * bin_size
293
+ end = start + bin_size
294
+
295
+ bin_steps = np.linspace(start, end, bin_count, endpoint=False, dtype=int).tolist()
296
+ result.extend(bin_steps)
297
+
298
+ result_tensor = torch.tensor(result, dtype=torch.int32)
299
+ sorted_tensor = torch.sort(result_tensor, descending=True).values
300
+
301
+ return sorted_tensor
302
+
303
+
304
+ def skip_diffusion_timestep(timesteps, diffusion_skip_timestep):
305
+ if isinstance(timesteps, list):
306
+ # If timesteps is a list, we assume each element is a tensor
307
+ timesteps_np = [t.cpu().numpy() for t in timesteps]
308
+ device = timesteps[0].device
309
+ else:
310
+ # If timesteps is a tensor
311
+ timesteps_np = timesteps.cpu().numpy()
312
+ device = timesteps.device
313
+
314
+ num_bins = len(diffusion_skip_timestep)
315
+
316
+ if isinstance(timesteps_np, list):
317
+ bin_size = len(timesteps_np) // num_bins
318
+ new_timesteps = []
319
+
320
+ for i in range(num_bins):
321
+ bin_start = i * bin_size
322
+ bin_end = (i + 1) * bin_size if i != num_bins - 1 else len(timesteps_np)
323
+ bin_timesteps = timesteps_np[bin_start:bin_end]
324
+
325
+ if diffusion_skip_timestep[i] == 0:
326
+ # If the bin is marked with 0, keep all timesteps
327
+ new_timesteps.extend(bin_timesteps)
328
+ elif diffusion_skip_timestep[i] == 1:
329
+ # If the bin is marked with 1, omit the first timestep in the bin (bin_timesteps[1:])
330
+ new_timesteps.extend(bin_timesteps[1:])
331
+
332
+ new_timesteps_tensor = [torch.tensor(t, device=device) for t in new_timesteps]
333
+ else:
334
+ bin_size = len(timesteps_np) // num_bins
335
+ new_timesteps = []
336
+
337
+ for i in range(num_bins):
338
+ bin_start = i * bin_size
339
+ bin_end = (i + 1) * bin_size if i != num_bins - 1 else len(timesteps_np)
340
+ bin_timesteps = timesteps_np[bin_start:bin_end]
341
+
342
+ if diffusion_skip_timestep[i] == 0:
343
+ # If the bin is marked with 0, keep all timesteps
344
+ new_timesteps.extend(bin_timesteps)
345
+ elif diffusion_skip_timestep[i] == 1:
346
+ # If the bin is marked with 1, omit the first timestep in the bin (bin_timesteps[1:])
347
+ new_timesteps.extend(bin_timesteps[1:])
348
+ elif diffusion_skip_timestep[i] != 0:
349
+ # If the bin is marked with a non-zero value, randomly omit n timesteps
350
+ if len(bin_timesteps) > diffusion_skip_timestep[i]:
351
+ indices_to_remove = set(random.sample(range(len(bin_timesteps)), diffusion_skip_timestep[i]))
352
+ timesteps_to_keep = [
353
+ timestep for idx, timestep in enumerate(bin_timesteps) if idx not in indices_to_remove
354
+ ]
355
+ else:
356
+ timesteps_to_keep = bin_timesteps # if bin_timesteps has at most n elements, do not remove any
357
+ new_timesteps.extend(timesteps_to_keep)
358
+
359
+ new_timesteps_tensor = torch.tensor(new_timesteps, device=device)
360
+
361
+ if isinstance(timesteps, list):
362
+ return new_timesteps_tensor
363
+ else:
364
+ return new_timesteps_tensor
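The four `if_broadcast_*` methods in pab_mgr.py share one rule: a cached attention output is reused when broadcasting is enabled, the step counter is not a multiple of the gap, and the timestep lies strictly inside the threshold window. The sketch below restates that rule in plain Python so the resulting schedule can be inspected. It is an illustration only: `should_broadcast` and `trace` are hypothetical helper names, and the timesteps, gap, and threshold in the usage note are made-up values, not VideoSys defaults.

```python
from typing import List, Optional, Sequence, Tuple


def should_broadcast(
    enabled: bool,
    timestep: Optional[int],
    count: int,
    gap: int,
    threshold: Tuple[int, int],
    steps: int,
) -> Tuple[bool, int]:
    """Return (reuse_cached_output, next_count), following PABManager.if_broadcast_*."""
    low, high = threshold
    reuse = bool(
        enabled
        and timestep is not None
        and count % gap != 0          # every gap-th step recomputes
        and low < timestep < high     # strict window, as in the diff
    )
    # The counter wraps after `steps` calls so it restarts for the next video.
    return reuse, (count + 1) % steps


def trace(timesteps: Sequence[int], gap: int, threshold: Tuple[int, int]) -> List[bool]:
    """List, per denoising step, whether the cached output would be reused."""
    steps = len(timesteps)
    count, flags = 0, []
    for t in timesteps:
        reuse, count = should_broadcast(True, t, count, gap, threshold, steps)
        flags.append(reuse)
    return flags
```

For example, `trace(list(range(900, -1, -100)), gap=2, threshold=(100, 850))` marks steps at timesteps 800, 600, 400, and 200 as reuse steps; timestep 900 and timestep 0 fall outside the window, and even-numbered counts always recompute.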
videosys/core/parallel_mgr.py ADDED
@@ -0,0 +1,119 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.distributed as dist
5
+ from colossalai.cluster.process_group_mesh import ProcessGroupMesh
6
+ from torch.distributed import ProcessGroup
7
+
8
+ from videosys.utils.logging import init_dist_logger, logger
9
+ from videosys.utils.utils import set_seed
10
+
11
+ PARALLEL_MANAGER = None
12
+
13
+
14
+ class ParallelManager(ProcessGroupMesh):
15
+ def __init__(self, dp_size, cp_size, sp_size):
16
+ super().__init__(dp_size, cp_size, sp_size)
17
+ dp_axis, cp_axis, sp_axis = 0, 1, 2
18
+
19
+ self.dp_size = dp_size
20
+ self.dp_group: ProcessGroup = self.get_group_along_axis(dp_axis)
21
+ self.dp_rank = dist.get_rank(self.dp_group)
22
+
23
+ self.cp_size = cp_size
24
+ self.cp_group: ProcessGroup = self.get_group_along_axis(cp_axis)
25
+ self.cp_rank = dist.get_rank(self.cp_group)
26
+
27
+ self.sp_size = sp_size
28
+ self.sp_group: ProcessGroup = self.get_group_along_axis(sp_axis)
29
+ self.sp_rank = dist.get_rank(self.sp_group)
30
+ self.enable_sp = sp_size > 1
31
+
32
+ logger.info(f"Init parallel manager with dp_size: {dp_size}, cp_size: {cp_size}, sp_size: {sp_size}")
33
+
34
+
35
+ def set_parallel_manager(dp_size, cp_size, sp_size):
36
+ global PARALLEL_MANAGER
37
+ PARALLEL_MANAGER = ParallelManager(dp_size, cp_size, sp_size)
38
+
39
+
40
+ def get_data_parallel_group():
41
+ return PARALLEL_MANAGER.dp_group
42
+
43
+
44
+ def get_data_parallel_size():
45
+ return PARALLEL_MANAGER.dp_size
46
+
47
+
48
+ def get_data_parallel_rank():
49
+ return PARALLEL_MANAGER.dp_rank
50
+
51
+
52
+ def get_sequence_parallel_group():
53
+ return PARALLEL_MANAGER.sp_group
54
+
55
+
56
+ def get_sequence_parallel_size():
57
+ return PARALLEL_MANAGER.sp_size
58
+
59
+
60
+ def get_sequence_parallel_rank():
61
+ return PARALLEL_MANAGER.sp_rank
62
+
63
+
64
+ def get_cfg_parallel_group():
65
+ return PARALLEL_MANAGER.cp_group
66
+
67
+
68
+ def get_cfg_parallel_size():
69
+ return PARALLEL_MANAGER.cp_size
70
+
71
+
72
+ def enable_sequence_parallel():
73
+ if PARALLEL_MANAGER is None:
74
+ return False
75
+ return PARALLEL_MANAGER.enable_sp
76
+
77
+
78
+ def get_parallel_manager():
79
+ return PARALLEL_MANAGER
80
+
81
+
82
+ def initialize(
83
+ rank=0,
84
+ world_size=1,
85
+ init_method=None,
86
+ seed: Optional[int] = None,
87
+ sp_size: Optional[int] = None,
88
+ enable_cp: bool = True,
89
+ ):
90
+ if not dist.is_initialized():
91
+ try:
92
+ dist.destroy_process_group()
93
+ except Exception:
94
+ pass
95
+ dist.init_process_group(backend="nccl", init_method=init_method, world_size=world_size, rank=rank)
96
+ torch.cuda.set_device(rank)
97
+ init_dist_logger()
98
+ torch.backends.cuda.matmul.allow_tf32 = True
99
+ torch.backends.cudnn.allow_tf32 = True
100
+
101
+ # init sequence parallel
102
+ if sp_size is None:
103
+ sp_size = dist.get_world_size()
104
+ dp_size = 1
105
+ else:
106
+ assert dist.get_world_size() % sp_size == 0, f"world_size {dist.get_world_size()} must be divisible by sp_size"
107
+ dp_size = dist.get_world_size() // sp_size
108
+
109
+ # update cfg parallel
110
+ if enable_cp and sp_size % 2 == 0:
111
+ sp_size = sp_size // 2
112
+ cp_size = 2
113
+ else:
114
+ cp_size = 1
115
+
116
+ set_parallel_manager(dp_size, cp_size, sp_size)
117
+
118
+ if seed is not None:
119
+ set_seed(seed + get_data_parallel_rank())
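The size arithmetic at the end of `initialize` in parallel_mgr.py decides how the world size is split across data, CFG, and sequence parallelism. The sketch below isolates just that arithmetic so it can be checked without `torch.distributed`. It is a hedged restatement, not the library API: `plan_parallel_sizes` is a hypothetical name, and it raises `ValueError` where the original uses an `assert`.

```python
from typing import Optional, Tuple


def plan_parallel_sizes(
    world_size: int,
    sp_size: Optional[int] = None,
    enable_cp: bool = True,
) -> Tuple[int, int, int]:
    """Return (dp_size, cp_size, sp_size) following the logic in initialize()."""
    if sp_size is None:
        # By default every rank joins sequence parallelism.
        sp_size, dp_size = world_size, 1
    else:
        if world_size % sp_size != 0:
            raise ValueError(f"world_size {world_size} must be divisible by sp_size {sp_size}")
        dp_size = world_size // sp_size

    # CFG parallelism takes a factor of 2 out of the sequence-parallel group when possible.
    if enable_cp and sp_size % 2 == 0:
        sp_size //= 2
        cp_size = 2
    else:
        cp_size = 1
    return dp_size, cp_size, sp_size
```

With 8 ranks and the defaults this yields `(1, 2, 4)`; in every case the product of the three sizes equals the world size, which is what `ProcessGroupMesh` needs to lay out the ranks.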
videosys/core/pipeline.py ADDED
@@ -0,0 +1,34 @@
1
+ from abc import abstractmethod
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
6
+ from diffusers.utils import BaseOutput
7
+
8
+
9
+ class VideoSysPipeline(DiffusionPipeline):
10
+ def __init__(self):
11
+ super().__init__()
12
+
13
+ @staticmethod
14
+ def set_eval_and_device(device: torch.device, *modules):
15
+ for module in modules:
16
+ module.eval()
17
+ module.to(device)
18
+
19
+ @abstractmethod
20
+ def generate(self, *args, **kwargs):
21
+ pass
22
+
23
+ def __call__(self, *args, **kwargs):
24
+ """
25
+ In diffusers, it is a convention to call the pipeline object.
26
+ But in VideoSys, we will use the generate method for better prompt.
27
+ This is a wrapper for the generate method to support the diffusers usage.
28
+ """
29
+ return self.generate(*args, **kwargs)
30
+
31
+
32
+ @dataclass
33
+ class VideoSysPipelineOutput(BaseOutput):
34
+ video: torch.Tensor
videosys/core/shardformer/__init__.py ADDED
File without changes
videosys/core/shardformer/t5/__init__.py ADDED
File without changes
videosys/core/shardformer/t5/modeling.py ADDED
@@ -0,0 +1,39 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class T5LayerNorm(nn.Module):
6
+ def __init__(self, hidden_size, eps=1e-6):
7
+ """
8
+ Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
9
+ """
10
+ super().__init__()
11
+ self.weight = nn.Parameter(torch.ones(hidden_size))
12
+ self.variance_epsilon = eps
13
+
14
+ def forward(self, hidden_states):
15
+ # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
16
+ # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
17
+ # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
18
+ # half-precision inputs is done in fp32
19
+
20
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
21
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
22
+
23
+ # convert into half-precision if necessary
24
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
25
+ hidden_states = hidden_states.to(self.weight.dtype)
26
+
27
+ return self.weight * hidden_states
28
+
29
+ @staticmethod
30
+ def from_native_module(module, *args, **kwargs):
31
+ assert module.__class__.__name__ == "FusedRMSNorm", (
32
+ "Recovering T5LayerNorm requires the original layer to be apex's Fused RMS Norm. "
33
+ "Apex's fused norm is automatically used by Hugging Face Transformers https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L265C5-L265C48"
34
+ )
35
+
36
+ layer_norm = T5LayerNorm(module.normalized_shape, eps=module.eps)
37
+ layer_norm.weight.data.copy_(module.weight.data)
38
+ layer_norm = layer_norm.to(module.weight.device)
39
+ return layer_norm
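The `forward` method of `T5LayerNorm` above scales each vector by the reciprocal root of its mean square, with no mean subtraction and no bias. A dependency-free sketch of the same arithmetic on a single vector follows; `rms_norm` is a hypothetical helper written for illustration, and it operates on Python lists rather than tensors.

```python
import math
from typing import List, Sequence


def rms_norm(x: Sequence[float], weight: Sequence[float], eps: float = 1e-6) -> List[float]:
    """Scale x by 1 / sqrt(mean(x**2) + eps), then multiply elementwise by weight."""
    mean_sq = sum(v * v for v in x) / len(x)   # "variance" without mean subtraction
    scale = 1.0 / math.sqrt(mean_sq + eps)     # plays the role of torch.rsqrt
    return [w * v * scale for v, w in zip(x, weight)]
```

With unit weights and `eps=0`, the output always has a mean square of 1, which is the property the normalization is meant to enforce.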
videosys/core/shardformer/t5/policy.py ADDED
@@ -0,0 +1,68 @@
1
+ from colossalai.shardformer.modeling.jit import get_jit_fused_dropout_add_func
2
+ from colossalai.shardformer.modeling.t5 import get_jit_fused_T5_layer_ff_forward, get_T5_layer_self_attention_forward
3
+ from colossalai.shardformer.policies.base_policy import Policy, SubModuleReplacementDescription
4
+
5
+
6
+ class T5EncoderPolicy(Policy):
7
+ def config_sanity_check(self):
8
+ assert not self.shard_config.enable_tensor_parallelism
9
+ assert not self.shard_config.enable_flash_attention
10
+
11
+ def preprocess(self):
12
+ return self.model
13
+
14
+ def module_policy(self):
15
+ from transformers.models.t5.modeling_t5 import T5LayerFF, T5LayerSelfAttention, T5Stack
16
+
17
+ policy = {}
18
+
19
+ # check whether apex is installed
20
+ try:
21
+ from apex.normalization import FusedRMSNorm # noqa
22
+ from videosys.core.shardformer.t5.modeling import T5LayerNorm
23
+
24
+ # replace apex's FusedRMSNorm (picked up by HF Transformers) with the plain T5LayerNorm, which is faster
25
+ self.append_or_create_submodule_replacement(
26
+ description=SubModuleReplacementDescription(
27
+ suffix="layer_norm",
28
+ target_module=T5LayerNorm,
29
+ ),
30
+ policy=policy,
31
+ target_key=T5LayerFF,
32
+ )
33
+ self.append_or_create_submodule_replacement(
34
+ description=SubModuleReplacementDescription(suffix="layer_norm", target_module=T5LayerNorm),
35
+ policy=policy,
36
+ target_key=T5LayerSelfAttention,
37
+ )
38
+ self.append_or_create_submodule_replacement(
39
+ description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=T5LayerNorm),
40
+ policy=policy,
41
+ target_key=T5Stack,
42
+ )
43
+ except (ImportError, ModuleNotFoundError):
44
+ pass
45
+
46
+ # use jit operator
47
+ if self.shard_config.enable_jit_fused:
48
+ self.append_or_create_method_replacement(
49
+ description={
50
+ "forward": get_jit_fused_T5_layer_ff_forward(),
51
+ "dropout_add": get_jit_fused_dropout_add_func(),
52
+ },
53
+ policy=policy,
54
+ target_key=T5LayerFF,
55
+ )
56
+ self.append_or_create_method_replacement(
57
+ description={
58
+ "forward": get_T5_layer_self_attention_forward(),
59
+ "dropout_add": get_jit_fused_dropout_add_func(),
60
+ },
61
+ policy=policy,
62
+ target_key=T5LayerSelfAttention,
63
+ )
64
+
65
+ return policy
66
+
67
+ def postprocess(self):
68
+ return self.model
videosys/datasets/dataloader.py ADDED
@@ -0,0 +1,94 @@
1
+ import random
2
+ from typing import Iterator, Optional
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torch.utils.data import DataLoader, Dataset
7
+ from torch.utils.data.distributed import DistributedSampler
8
+
9
+ from videosys.core.parallel_mgr import ParallelManager
10
+
11
+
12
+ class StatefulDistributedSampler(DistributedSampler):
13
+ def __init__(
14
+ self,
15
+ dataset: Dataset,
16
+ num_replicas: Optional[int] = None,
17
+ rank: Optional[int] = None,
18
+ shuffle: bool = True,
19
+ seed: int = 0,
20
+ drop_last: bool = False,
21
+ ) -> None:
22
+ super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
23
+ self.start_index: int = 0
24
+
25
+ def __iter__(self) -> Iterator:
26
+ iterator = super().__iter__()
27
+ indices = list(iterator)
28
+ indices = indices[self.start_index :]
29
+ return iter(indices)
30
+
31
+ def __len__(self) -> int:
32
+ return self.num_samples - self.start_index
33
+
34
+ def set_start_index(self, start_index: int) -> None:
35
+ self.start_index = start_index
36
+
37
+
38
+ def prepare_dataloader(
39
+ dataset,
40
+ batch_size,
41
+ shuffle=False,
42
+ seed=1024,
43
+ drop_last=False,
44
+ pin_memory=False,
45
+ num_workers=0,
46
+ pg_manager: Optional[ParallelManager] = None,
47
+ **kwargs,
48
+ ):
49
+ r"""
50
+ Prepare a dataloader for distributed training. The dataloader will be wrapped by
51
+ `torch.utils.data.DataLoader` and `StatefulDistributedSampler`.
52
+
53
+
54
+ Args:
55
+ dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
56
+ shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
57
+ seed (int, optional): Random worker seed for sampling, defaults to 1024.
58
+ pg_manager (ParallelManager): Parallel manager that supplies the data-parallel size and rank for the sampler. The signature defaults to None, but the function body requires it.
59
+ drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
60
+ is not divisible by the batch size. If False and the size of dataset is not divisible by
61
+ the batch size, then the last batch will be smaller, defaults to False.
62
+ pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
63
+ num_workers (int, optional): Number of worker subprocesses used for data loading. Defaults to 0.
64
+ kwargs (dict): Optional parameters for ``torch.utils.data.DataLoader``; more details can be found in
65
+ `DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
66
+
67
+ Returns:
68
+ :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
69
+ """
70
+ _kwargs = kwargs.copy()
71
+ sampler = StatefulDistributedSampler(
72
+ dataset,
73
+ num_replicas=pg_manager.size(pg_manager.dp_axis),
74
+ rank=pg_manager.coordinate(pg_manager.dp_axis),
75
+ shuffle=shuffle,
76
+ )
77
+
78
+ # Deterministic dataloader
79
+ def seed_worker(worker_id):
80
+ worker_seed = seed
81
+ np.random.seed(worker_seed)
82
+ torch.manual_seed(worker_seed)
83
+ random.seed(worker_seed)
84
+
85
+ return DataLoader(
86
+ dataset,
87
+ batch_size=batch_size,
88
+ sampler=sampler,
89
+ worker_init_fn=seed_worker,
90
+ drop_last=drop_last,
91
+ pin_memory=pin_memory,
92
+ num_workers=num_workers,
93
+ **_kwargs,
94
+ )
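The resume behaviour of `StatefulDistributedSampler` can be illustrated without PyTorch. The sketch below is a pure-Python stand-in, not the class from this commit: `ResumableShardSampler` and its constructor parameters are hypothetical names, and the rank sharding is a simplified version of what `DistributedSampler` does. It suggests that `start_index` counts samples inside one rank's shard, so a caller resuming from a batch count would presumably multiply by the batch size first.

```python
import random
from typing import Iterator, List


class ResumableShardSampler:
    """Hypothetical pure-Python stand-in for a rank-sharded sampler with a resume offset."""

    def __init__(self, dataset_len: int, num_replicas: int, rank: int,
                 shuffle: bool = False, seed: int = 0) -> None:
        self.dataset_len = dataset_len
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0
        self.start_index = 0
        # Samples per rank, rounded up so every rank gets the same amount of work.
        self.num_samples = -(-dataset_len // num_replicas)

    def _shard(self) -> List[int]:
        indices = list(range(self.dataset_len))
        if self.shuffle:
            # Seeded shuffle so every rank sees the same permutation.
            random.Random(self.seed + self.epoch).shuffle(indices)
        total = self.num_samples * self.num_replicas
        indices += indices[: total - len(indices)]  # pad by wrapping around
        return indices[self.rank : total : self.num_replicas]

    def __iter__(self) -> Iterator[int]:
        # Skip the samples this rank already consumed before the restart.
        return iter(self._shard()[self.start_index :])

    def __len__(self) -> int:
        return self.num_samples - self.start_index

    def set_start_index(self, start_index: int) -> None:
        self.start_index = start_index


sampler = ResumableShardSampler(dataset_len=10, num_replicas=2, rank=1)
full_epoch = list(sampler)   # rank 1 sees every second index
sampler.set_start_index(3)   # pretend 3 samples were consumed before a restart
resumed = list(sampler)
```

Nothing in the class from the diff resets `start_index`, so a training loop would likely need to set it back to 0 once the resumed epoch finishes.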
videosys/datasets/image_transform.py ADDED
@@ -0,0 +1,42 @@
1
+ # Adapted from DiT
2
+
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ # --------------------------------------------------------
6
+ # References:
7
+ # DiT: https://github.com/facebookresearch/DiT
8
+ # --------------------------------------------------------
9
+
10
+
11
+ import numpy as np
12
+ import torchvision.transforms as transforms
13
+ from PIL import Image
14
+
15
+
16
+ def center_crop_arr(pil_image, image_size):
17
+ """
18
+ Center cropping implementation from ADM.
19
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
20
+ """
21
+ while min(*pil_image.size) >= 2 * image_size:
22
+ pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
23
+
24
+ scale = image_size / min(*pil_image.size)
25
+ pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
26
+
27
+ arr = np.array(pil_image)
28
+ crop_y = (arr.shape[0] - image_size) // 2
29
+ crop_x = (arr.shape[1] - image_size) // 2
30
+ return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
31
+
32
+
33
+ def get_transforms_image(image_size=256):
34
+ transform = transforms.Compose(
35
+ [
36
+ transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
37
+ transforms.RandomHorizontalFlip(),
38
+ transforms.ToTensor(),
39
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
40
+ ]
41
+ )
42
+ return transform
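The three steps of `center_crop_arr` (repeated halving, rescaling the short side, center crop) can be traced with plain integer arithmetic. The helper below is a hypothetical sketch named `center_crop_plan`; it mirrors the size calculations only and does not touch PIL or any pixel data.

```python
def center_crop_plan(width, height, image_size):
    """Return the resized (width, height) and the crop box, following center_crop_arr's arithmetic."""
    # Step 1: halve both sides while the short side is at least twice the target.
    while min(width, height) >= 2 * image_size:
        width, height = width // 2, height // 2

    # Step 2: rescale so that the short side equals image_size.
    scale = image_size / min(width, height)
    width, height = round(width * scale), round(height * scale)

    # Step 3: take the central image_size x image_size window.
    crop_y = (height - image_size) // 2
    crop_x = (width - image_size) // 2
    box = (crop_x, crop_y, crop_x + image_size, crop_y + image_size)
    return (width, height), box


# A 1920x1080 frame is halved twice (to 480x270), rescaled to 455x256,
# then cropped horizontally around the center.
resized, box = center_crop_plan(1920, 1080, 256)
```

The halving loop in the original uses `Image.BOX`, which presumably keeps large downscales cheap and free of aliasing before the final bicubic resize.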
videosys/datasets/video_transform.py ADDED
@@ -0,0 +1,441 @@
1
+ # Adapted from OpenSora and Latte
2
+
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ # --------------------------------------------------------
6
+ # References:
7
+ # OpenSora: https://github.com/hpcaitech/Open-Sora
8
+ # Latte: https://github.com/Vchitect/Latte
9
+ # --------------------------------------------------------
10
+
11
+ import numbers
12
+ import random
13
+
14
+ import numpy as np
15
+ import torch
16
+ from PIL import Image
17
+
18
+
19
+ def _is_tensor_video_clip(clip):
20
+ if not torch.is_tensor(clip):
21
+ raise TypeError("clip should be Tensor. Got %s" % type(clip))
22
+
23
+ if not clip.ndimension() == 4:
24
+ raise ValueError("clip should be 4D. Got %dD" % clip.dim())
25
+
26
+ return True
27
+
28
+
29
+ def center_crop_arr(pil_image, image_size):
30
+ """
31
+ Center cropping implementation from ADM.
32
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
33
+ """
34
+ while min(*pil_image.size) >= 2 * image_size:
35
+ pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
36
+
37
+ scale = image_size / min(*pil_image.size)
38
+ pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
39
+
40
+ arr = np.array(pil_image)
41
+ crop_y = (arr.shape[0] - image_size) // 2
42
+ crop_x = (arr.shape[1] - image_size) // 2
43
+ return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
44
+
45
+
46
+ def crop(clip, i, j, h, w):
47
+ """
48
+ Args:
49
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
50
+ """
51
+ if len(clip.size()) != 4:
52
+ raise ValueError("clip should be a 4D tensor")
53
+ return clip[..., i : i + h, j : j + w]
54
+
55
+
56
+ def resize(clip, target_size, interpolation_mode):
57
+ if len(target_size) != 2:
58
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
59
+ return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
60
+
61
+
62
+ def resize_scale(clip, target_size, interpolation_mode):
63
+ if len(target_size) != 2:
64
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
65
+ H, W = clip.size(-2), clip.size(-1)
66
+ scale_ = target_size[0] / min(H, W)
67
+ return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
68
+
69
+
70
+ def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
71
+ """
72
+ Do spatial cropping and resizing to the video clip
73
+ Args:
74
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
75
+ i (int): Row index of the upper-left corner of the crop.
76
+ j (int): Column index of the upper-left corner of the crop.
77
+ h (int): Height of the cropped region.
78
+ w (int): Width of the cropped region.
79
+ size (tuple(int, int)): height and width of resized clip
80
+ Returns:
81
+ clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
82
+ """
83
+ if not _is_tensor_video_clip(clip):
84
+ raise ValueError("clip should be a 4D torch.tensor")
85
+ clip = crop(clip, i, j, h, w)
86
+ clip = resize(clip, size, interpolation_mode)
87
+ return clip
88
+
89
+
90
+ def center_crop(clip, crop_size):
91
+ if not _is_tensor_video_clip(clip):
92
+ raise ValueError("clip should be a 4D torch.tensor")
93
+ h, w = clip.size(-2), clip.size(-1)
94
+ th, tw = crop_size
95
+ if h < th or w < tw:
96
+ raise ValueError("height and width must be no smaller than crop_size")
97
+
98
+ i = int(round((h - th) / 2.0))
99
+ j = int(round((w - tw) / 2.0))
100
+ return crop(clip, i, j, th, tw)
101
+
102
+
103
+ def center_crop_using_short_edge(clip):
104
+ if not _is_tensor_video_clip(clip):
105
+ raise ValueError("clip should be a 4D torch.tensor")
106
+ h, w = clip.size(-2), clip.size(-1)
107
+ if h < w:
108
+ th, tw = h, h
109
+ i = 0
110
+ j = int(round((w - tw) / 2.0))
111
+ else:
112
+ th, tw = w, w
113
+ i = int(round((h - th) / 2.0))
114
+ j = 0
115
+ return crop(clip, i, j, th, tw)
116
+
117
+
118
+ def random_shift_crop(clip):
119
+ """
120
+ Slide along the long edge, with the short edge as crop size
121
+ """
122
+ if not _is_tensor_video_clip(clip):
123
+ raise ValueError("clip should be a 4D torch.tensor")
124
+ h, w = clip.size(-2), clip.size(-1)
125
+
126
+ if h <= w:
127
+ short_edge = h
128
+ else:
129
+ short_edge = w
130
+
131
+ th, tw = short_edge, short_edge
132
+
133
+ i = torch.randint(0, h - th + 1, size=(1,)).item()
134
+ j = torch.randint(0, w - tw + 1, size=(1,)).item()
135
+ return crop(clip, i, j, th, tw)
136
+
137
+
138
+ def to_tensor(clip):
139
+ """
140
+ Convert tensor data type from uint8 to float and divide values by 255.0.
141
+ The dimension order of the clip tensor is left unchanged.
142
+ Args:
143
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
144
+ Return:
145
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
146
+ """
147
+ _is_tensor_video_clip(clip)
148
+ if not clip.dtype == torch.uint8:
149
+ raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
150
+ # return clip.float().permute(3, 0, 1, 2) / 255.0
151
+ return clip.float() / 255.0
152
+
153
+
154
+ def normalize(clip, mean, std, inplace=False):
155
+ """
156
+ Args:
157
+ clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
158
+ mean (tuple): pixel RGB mean. Size is (3)
159
+ std (tuple): pixel standard deviation. Size is (3)
160
+ Returns:
161
+ normalized clip (torch.tensor): Size is (T, C, H, W)
162
+ """
163
+ if not _is_tensor_video_clip(clip):
164
+ raise ValueError("clip should be a 4D torch.tensor")
165
+ if not inplace:
166
+ clip = clip.clone()
167
+ mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
168
+ # print(mean)
169
+ std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
170
+ clip.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
171
+ return clip
172
+
173
+
174
+ def hflip(clip):
175
+ """
176
+ Args:
177
+ clip (torch.tensor): Video clip to be flipped. Size is (T, C, H, W)
178
+ Returns:
179
+ flipped clip (torch.tensor): Size is (T, C, H, W)
180
+ """
181
+ if not _is_tensor_video_clip(clip):
182
+ raise ValueError("clip should be a 4D torch.tensor")
183
+ return clip.flip(-1)
184
+
185
+
186
+ class RandomCropVideo:
187
+ def __init__(self, size):
188
+ if isinstance(size, numbers.Number):
189
+ self.size = (int(size), int(size))
190
+ else:
191
+ self.size = size
192
+
193
+ def __call__(self, clip):
194
+ """
195
+ Args:
196
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
197
+ Returns:
198
+ torch.tensor: randomly cropped video clip.
199
+ size is (T, C, OH, OW)
200
+ """
201
+ i, j, h, w = self.get_params(clip)
202
+ return crop(clip, i, j, h, w)
203
+
204
+ def get_params(self, clip):
205
+ h, w = clip.shape[-2:]
206
+ th, tw = self.size
207
+
208
+ if h < th or w < tw:
209
+ raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
210
+
211
+ if w == tw and h == th:
212
+ return 0, 0, h, w
213
+
214
+ i = torch.randint(0, h - th + 1, size=(1,)).item()
215
+ j = torch.randint(0, w - tw + 1, size=(1,)).item()
216
+
217
+ return i, j, th, tw
218
+
219
+ def __repr__(self) -> str:
220
+ return f"{self.__class__.__name__}(size={self.size})"
221
+
222
+
223
+ class CenterCropResizeVideo:
224
+ """
225
+ Center crop the video to a square whose side equals the short edge,
226
+ then resize it to the specified size
227
+ """
228
+
229
+ def __init__(
230
+ self,
231
+ size,
232
+ interpolation_mode="bilinear",
233
+ ):
234
+ if isinstance(size, tuple):
235
+ if len(size) != 2:
236
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
237
+ self.size = size
238
+ else:
239
+ self.size = (size, size)
240
+
241
+ self.interpolation_mode = interpolation_mode
242
+
243
+ def __call__(self, clip):
244
+ """
245
+ Args:
246
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
247
+ Returns:
248
+ torch.tensor: scale resized / center cropped video clip.
249
+ size is (T, C, crop_size, crop_size)
250
+ """
251
+ clip_center_crop = center_crop_using_short_edge(clip)
252
+ clip_center_crop_resize = resize(
253
+ clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode
254
+ )
255
+ return clip_center_crop_resize
256
+
257
+ def __repr__(self) -> str:
258
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
259
+
260
+
261
+ class UCFCenterCropVideo:
262
+ """
263
+ Scale the clip, preserving aspect ratio, so that the short edge matches the specified size,
264
+ then center crop
265
+ """
266
+
267
+ def __init__(
268
+ self,
269
+ size,
270
+ interpolation_mode="bilinear",
271
+ ):
272
+ if isinstance(size, tuple):
273
+ if len(size) != 2:
274
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
275
+ self.size = size
276
+ else:
277
+ self.size = (size, size)
278
+
279
+ self.interpolation_mode = interpolation_mode
280
+
281
+ def __call__(self, clip):
282
+ """
283
+ Args:
284
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
285
+ Returns:
286
+ torch.tensor: scale resized / center cropped video clip.
287
+ size is (T, C, crop_size, crop_size)
288
+ """
289
+ clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
290
+ clip_center_crop = center_crop(clip_resize, self.size)
291
+ return clip_center_crop
292
+
293
+ def __repr__(self) -> str:
294
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
295
+
296
+
297
+ class KineticsRandomCropResizeVideo:
298
+ """
299
+ Slide along the long edge, using the short edge as the crop size, then resize to the desired size.
300
+ """
301
+
302
+ def __init__(
303
+ self,
304
+ size,
305
+ interpolation_mode="bilinear",
306
+ ):
307
+ if isinstance(size, tuple):
308
+ if len(size) != 2:
309
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
310
+ self.size = size
311
+ else:
312
+ self.size = (size, size)
313
+
314
+ self.interpolation_mode = interpolation_mode
315
+
316
+ def __call__(self, clip):
317
+ clip_random_crop = random_shift_crop(clip)
318
+ clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
319
+ return clip_resize
320
+
321
+
322
+ class CenterCropVideo:
323
+ def __init__(
324
+ self,
325
+ size,
326
+ interpolation_mode="bilinear",
327
+ ):
328
+ if isinstance(size, tuple):
329
+ if len(size) != 2:
330
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
331
+ self.size = size
332
+ else:
333
+ self.size = (size, size)
334
+
335
+ self.interpolation_mode = interpolation_mode
336
+
337
+ def __call__(self, clip):
338
+ """
339
+ Args:
340
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
341
+ Returns:
342
+ torch.tensor: center cropped video clip.
343
+ size is (T, C, crop_size, crop_size)
344
+ """
345
+ clip_center_crop = center_crop(clip, self.size)
346
+ return clip_center_crop
347
+
348
+ def __repr__(self) -> str:
349
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
350
+
351
+
352
+ class NormalizeVideo:
353
+ """
354
+ Normalize the video clip by mean subtraction and division by standard deviation
355
+ Args:
356
+ mean (3-tuple): pixel RGB mean
357
+ std (3-tuple): pixel RGB standard deviation
358
+ inplace (boolean): whether do in-place normalization
359
+ """
360
+
361
+ def __init__(self, mean, std, inplace=False):
362
+ self.mean = mean
363
+ self.std = std
364
+ self.inplace = inplace
365
+
366
+ def __call__(self, clip):
367
+ """
368
+ Args:
369
+ clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
370
+ """
371
+ return normalize(clip, self.mean, self.std, self.inplace)
372
+
373
+ def __repr__(self) -> str:
374
+ return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
375
+
376
+
377
+ class ToTensorVideo:
378
+ """
379
+ Convert tensor data type from uint8 to float and divide values by 255.0.
380
+ The dimension order of the clip tensor is left unchanged.
381
+ """
382
+
383
+ def __init__(self):
384
+ pass
385
+
386
+ def __call__(self, clip):
387
+ """
388
+ Args:
389
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
390
+ Return:
391
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
392
+ """
393
+ return to_tensor(clip)
394
+
395
+ def __repr__(self) -> str:
396
+ return self.__class__.__name__
397
+
398
+
399
+ class RandomHorizontalFlipVideo:
400
+ """
401
+ Flip the video clip along the horizontal direction with a given probability
402
+ Args:
403
+ p (float): probability of the clip being flipped. Default value is 0.5
404
+ """
405
+
406
+ def __init__(self, p=0.5):
407
+ self.p = p
408
+
409
+ def __call__(self, clip):
410
+ """
411
+ Args:
412
+ clip (torch.tensor): Size is (T, C, H, W)
413
+ Return:
414
+ clip (torch.tensor): Size is (T, C, H, W)
415
+ """
416
+ if random.random() < self.p:
417
+ clip = hflip(clip)
418
+ return clip
419
+
420
+ def __repr__(self) -> str:
421
+ return f"{self.__class__.__name__}(p={self.p})"
422
+
423
+
424
+ # ------------------------------------------------------------
425
+ # --------------------- Sampling ---------------------------
426
+ # ------------------------------------------------------------
427
+ class TemporalRandomCrop(object):
428
+ """Temporally crop the given frame indices at a random location.
429
+
430
+ Args:
431
+ size (int): Number of consecutive frames to select for the model.
432
+ """
433
+
434
+ def __init__(self, size):
435
+ self.size = size
436
+
437
+ def __call__(self, total_frames):
438
+ rand_end = max(0, total_frames - self.size - 1)
439
+ begin_index = random.randint(0, rand_end)
440
+ end_index = min(begin_index + self.size, total_frames)
441
+ return begin_index, end_index
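To see what windows `TemporalRandomCrop` produces, the snippet below repeats the class from the diff so that it runs on its own, then samples it many times. The fixed seed and the variable names (`crop16`, `windows`, `short_clip`) are illustrative assumptions, not part of the commit.

```python
import random


class TemporalRandomCrop:
    """Copy of the class above, repeated so the snippet is self-contained."""

    def __init__(self, size):
        self.size = size

    def __call__(self, total_frames):
        rand_end = max(0, total_frames - self.size - 1)
        begin_index = random.randint(0, rand_end)
        end_index = min(begin_index + self.size, total_frames)
        return begin_index, end_index


random.seed(0)  # fixed only to make the demonstration repeatable
crop16 = TemporalRandomCrop(size=16)

# Clip longer than the window: every sample is a full 16-frame window.
windows = [crop16(total_frames=100) for _ in range(1000)]

# Clip shorter than the window: the whole clip is returned.
short_clip = crop16(total_frames=10)
```

One consequence of the `- 1` in `rand_end`: for a 100-frame clip the largest possible `end_index` is 99, so if the pair is used as a half-open slice `[begin:end]`, the final frame is never selected.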
videosys/diffusion/__init__.py ADDED
@@ -0,0 +1,41 @@
1
+ # Modified from OpenAI's diffusion repos and Meta DiT
2
+ # DiT: https://github.com/facebookresearch/DiT/tree/main
3
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
4
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
5
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
6
+
7
+ from . import gaussian_diffusion as gd
8
+ from .respace import SpacedDiffusion, space_timesteps
9
+
10
+
11
+ def create_diffusion(
12
+ timestep_respacing,
13
+ noise_schedule="linear",
14
+ use_kl=False,
15
+ sigma_small=False,
16
+ predict_xstart=False,
17
+ learn_sigma=True,
18
+ rescale_learned_sigmas=False,
19
+ diffusion_steps=1000,
20
+ ):
21
+ betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
22
+ if use_kl:
23
+ loss_type = gd.LossType.RESCALED_KL
24
+ elif rescale_learned_sigmas:
25
+ loss_type = gd.LossType.RESCALED_MSE
26
+ else:
27
+ loss_type = gd.LossType.MSE
28
+ if timestep_respacing is None or timestep_respacing == "":
29
+ timestep_respacing = [diffusion_steps]
30
+ return SpacedDiffusion(
31
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
32
+ betas=betas,
33
+ model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X),
34
+ model_var_type=(
35
+ (gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL)
36
+ if not learn_sigma
37
+ else gd.ModelVarType.LEARNED_RANGE
38
+ ),
39
+ loss_type=loss_type
40
+ # rescale_timesteps=rescale_timesteps,
41
+ )
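The flag handling in `create_diffusion` reduces to a small decision table. The sketch below restates it with plain strings in place of the `gd.ModelMeanType`, `gd.ModelVarType`, and `gd.LossType` enums; `select_diffusion_types` is a hypothetical helper written for illustration and does not exist in the package.

```python
def select_diffusion_types(use_kl=False, sigma_small=False, predict_xstart=False,
                           learn_sigma=True, rescale_learned_sigmas=False):
    """Return (mean_type, var_type, loss_type) names chosen by the same rules as create_diffusion."""
    # Loss: use_kl takes priority over rescale_learned_sigmas.
    if use_kl:
        loss_type = "RESCALED_KL"
    elif rescale_learned_sigmas:
        loss_type = "RESCALED_MSE"
    else:
        loss_type = "MSE"

    # Mean: the model predicts either the noise or the clean sample.
    mean_type = "START_X" if predict_xstart else "EPSILON"

    # Variance: sigma_small only matters when the variance is not learned.
    if learn_sigma:
        var_type = "LEARNED_RANGE"
    else:
        var_type = "FIXED_SMALL" if sigma_small else "FIXED_LARGE"

    return mean_type, var_type, loss_type


defaults = select_diffusion_types()
fixed_small = select_diffusion_types(learn_sigma=False, sigma_small=True)
kl_wins = select_diffusion_types(use_kl=True, rescale_learned_sigmas=True)
```

With the default arguments, the function in the diff therefore builds an epsilon-predicting model with a learned variance range and a plain MSE loss.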
videosys/diffusion/diffusion_utils.py ADDED
@@ -0,0 +1,79 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import numpy as np
7
+ import torch as th
8
+
9
+
10
+ def normal_kl(mean1, logvar1, mean2, logvar2):
11
+ """
12
+ Compute the KL divergence between two gaussians.
13
+ Shapes are automatically broadcast, so batches can be compared to
14
+ scalars, among other use cases.
15
+ """
16
+ tensor = None
17
+ for obj in (mean1, logvar1, mean2, logvar2):
18
+ if isinstance(obj, th.Tensor):
19
+ tensor = obj
20
+ break
21
+ assert tensor is not None, "at least one argument must be a Tensor"
22
+
23
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
24
+ # Tensors, but it does not work for th.exp().
25
+ logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)]
26
+
27
+ return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2))
28
+
29
+
30
+ def approx_standard_normal_cdf(x):
31
+ """
32
+ A fast approximation of the cumulative distribution function of the
33
+ standard normal.
34
+ """
35
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
36
+
37
+
38
+ def continuous_gaussian_log_likelihood(x, *, means, log_scales):
39
+ """
40
+ Compute the log-likelihood of a continuous Gaussian distribution.
41
+ :param x: the targets
42
+ :param means: the Gaussian mean Tensor.
43
+ :param log_scales: the Gaussian log stddev Tensor.
44
+ :return: a tensor like x of log probabilities (in nats).
45
+ """
46
+ centered_x = x - means
47
+ inv_stdv = th.exp(-log_scales)
48
+ normalized_x = centered_x * inv_stdv
49
+ log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
50
+ return log_probs
51
+
52
+
53
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
54
+ """
55
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
56
+ given image.
57
+ :param x: the target images. It is assumed that this was uint8 values,
58
+ rescaled to the range [-1, 1].
59
+ :param means: the Gaussian mean Tensor.
60
+ :param log_scales: the Gaussian log stddev Tensor.
61
+ :return: a tensor like x of log probabilities (in nats).
62
+ """
63
+ assert x.shape == means.shape == log_scales.shape
64
+ centered_x = x - means
65
+ inv_stdv = th.exp(-log_scales)
66
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
67
+ cdf_plus = approx_standard_normal_cdf(plus_in)
68
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
69
+ cdf_min = approx_standard_normal_cdf(min_in)
70
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
71
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
72
+ cdf_delta = cdf_plus - cdf_min
73
+ log_probs = th.where(
74
+ x < -0.999,
75
+ log_cdf_plus,
76
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
77
+ )
78
+ assert log_probs.shape == x.shape
79
+ return log_probs
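Two formulas in this file are easy to sanity-check on scalars: the closed-form KL divergence in `normal_kl` and the tanh-based CDF in `approx_standard_normal_cdf`. The sketch below re-implements both with the standard `math` module; the function names (`normal_kl_scalar`, `approx_cdf`, `exact_cdf`) are illustrative assumptions, and the tensor versions in the diff remain the reference.

```python
import math


def normal_kl_scalar(mean1, logvar1, mean2, logvar2):
    """Scalar version of the KL(N1 || N2) expression used in normal_kl."""
    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + math.exp(logvar1 - logvar2)
        + (mean1 - mean2) ** 2 * math.exp(-logvar2)
    )


def approx_cdf(x):
    """Scalar version of the tanh approximation to the standard normal CDF."""
    return 0.5 * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def exact_cdf(x):
    """Standard normal CDF expressed through the error function."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


# Identical Gaussians have zero divergence.
kl_same = normal_kl_scalar(0.0, 0.0, 0.0, 0.0)

# Two unit-variance Gaussians one unit apart: KL = (mean difference)^2 / 2.
kl_shifted = normal_kl_scalar(0.0, 0.0, 1.0, 0.0)

# The approximation stays close to the exact CDF at a few test points.
max_cdf_error = max(abs(approx_cdf(x) - exact_cdf(x)) for x in (-2.0, -1.0, 0.0, 1.0, 2.0))
```

The log-variance parameterization means the arguments are `log(sigma^2)`, so a unit variance is passed as `0.0`.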
videosys/diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,829 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+
7
+ import enum
8
+ import math
9
+
10
+ import numpy as np
11
+ import torch as th
12
+
13
+ from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
14
+
15
+
16
+ def mean_flat(tensor):
17
+ """
18
+ Take the mean over all non-batch dimensions.
19
+ """
20
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
21
+
22
+
23
+ class ModelMeanType(enum.Enum):
24
+ """
25
+ Which type of output the model predicts.
26
+ """
27
+
28
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
29
+ START_X = enum.auto() # the model predicts x_0
30
+ EPSILON = enum.auto() # the model predicts epsilon
31
+
32
+
33
+ class ModelVarType(enum.Enum):
34
+ """
35
+ What is used as the model's output variance.
36
+ The LEARNED_RANGE option has been added to allow the model to predict
37
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
38
+ """
39
+
40
+ LEARNED = enum.auto()
41
+ FIXED_SMALL = enum.auto()
42
+ FIXED_LARGE = enum.auto()
43
+ LEARNED_RANGE = enum.auto()
44
+
45
+
46
+ class LossType(enum.Enum):
47
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
48
+ RESCALED_MSE = enum.auto() # use raw MSE loss (with RESCALED_KL when learning variances)
49
+ KL = enum.auto() # use the variational lower-bound
50
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
51
+
52
+ def is_vb(self):
53
+ return self == LossType.KL or self == LossType.RESCALED_KL
54
+
55
+
56
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
57
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
58
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
59
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
60
+ return betas
61
+
62
+
63
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
64
+ """
65
+ This is the deprecated API for creating beta schedules.
66
+ See get_named_beta_schedule() for the new library of schedules.
67
+ """
68
+ if beta_schedule == "quad":
69
+ betas = (
70
+ np.linspace(
71
+ beta_start**0.5,
72
+ beta_end**0.5,
73
+ num_diffusion_timesteps,
74
+ dtype=np.float64,
75
+ )
76
+ ** 2
77
+ )
78
+ elif beta_schedule == "linear":
79
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
80
+ elif beta_schedule == "warmup10":
81
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
82
+ elif beta_schedule == "warmup50":
83
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
84
+ elif beta_schedule == "const":
85
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
86
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
87
+ betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
88
+ else:
89
+ raise NotImplementedError(beta_schedule)
90
+ assert betas.shape == (num_diffusion_timesteps,)
91
+ return betas
92
+
93
+
94
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
95
+ """
96
+ Get a pre-defined beta schedule for the given name.
97
+ The beta schedule library consists of beta schedules which remain similar
98
+ in the limit of num_diffusion_timesteps.
99
+ Beta schedules may be added, but should not be removed or changed once
100
+ they are committed to maintain backwards compatibility.
101
+ """
102
+ if schedule_name == "linear":
103
+ # Linear schedule from Ho et al, extended to work for any number of
104
+ # diffusion steps.
105
+ scale = 1000 / num_diffusion_timesteps
106
+ return get_beta_schedule(
107
+ "linear",
108
+ beta_start=scale * 0.0001,
109
+ beta_end=scale * 0.02,
110
+ num_diffusion_timesteps=num_diffusion_timesteps,
111
+ )
112
+ elif schedule_name == "squaredcos_cap_v2":
113
+ return betas_for_alpha_bar(
114
+ num_diffusion_timesteps,
115
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
116
+ )
117
+ else:
118
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
119
+
120
+
121
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
122
+ """
123
+ Create a beta schedule that discretizes the given alpha_t_bar function,
124
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
125
+ :param num_diffusion_timesteps: the number of betas to produce.
126
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
127
+ produces the cumulative product of (1-beta) up to that
128
+ part of the diffusion process.
129
+ :param max_beta: the maximum beta to use; use values lower than 1 to
130
+ prevent singularities.
131
+ """
132
+ betas = []
133
+ for i in range(num_diffusion_timesteps):
134
+ t1 = i / num_diffusion_timesteps
135
+ t2 = (i + 1) / num_diffusion_timesteps
136
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
137
+ return np.array(betas)
138
+
139
+
140
+ class GaussianDiffusion:
141
+ """
142
+ Utilities for training and sampling diffusion models.
143
+ Originally ported from this codebase:
144
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
145
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
146
+ starting at T and going to 1.
147
+ """
148
+
149
+ def __init__(self, *, betas, model_mean_type, model_var_type, loss_type):
150
+ self.model_mean_type = model_mean_type
151
+ self.model_var_type = model_var_type
152
+ self.loss_type = loss_type
153
+
154
+ # Use float64 for accuracy.
155
+ betas = np.array(betas, dtype=np.float64)
156
+ self.betas = betas
157
+ assert len(betas.shape) == 1, "betas must be 1-D"
158
+ assert (betas > 0).all() and (betas <= 1).all()
159
+
160
+ self.num_timesteps = int(betas.shape[0])
161
+
162
+ alphas = 1.0 - betas
163
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
164
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
165
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
166
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
167
+
168
+ # calculations for diffusion q(x_t | x_{t-1}) and others
169
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
170
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
171
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
172
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
173
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
174
+
175
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
176
+ self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
177
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
178
+ self.posterior_log_variance_clipped = (
179
+ np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))
180
+ if len(self.posterior_variance) > 1
181
+ else np.array([])
182
+ )
183
+
184
+ self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
185
+ self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
186
+
187
+ def q_mean_variance(self, x_start, t):
188
+ """
189
+ Get the distribution q(x_t | x_0).
190
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
191
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
192
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
193
+ """
194
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
195
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
196
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
197
+ return mean, variance, log_variance
198
+
199
+ def q_sample(self, x_start, t, noise=None):
200
+ """
201
+ Diffuse the data for a given number of diffusion steps.
202
+ In other words, sample from q(x_t | x_0).
203
+ :param x_start: the initial data batch.
204
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
205
+ :param noise: if specified, the split-out normal noise.
206
+ :return: A noisy version of x_start.
207
+ """
208
+ if noise is None:
209
+ noise = th.randn_like(x_start)
210
+ assert noise.shape == x_start.shape
211
+ return (
212
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
213
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
214
+ )
215
+
216
+ def q_posterior_mean_variance(self, x_start, x_t, t):
217
+ """
218
+ Compute the mean and variance of the diffusion posterior:
219
+ q(x_{t-1} | x_t, x_0)
220
+ """
221
+ assert x_start.shape == x_t.shape
222
+ posterior_mean = (
223
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
224
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
225
+ )
226
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
227
+ posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
228
+ assert (
229
+ posterior_mean.shape[0]
230
+ == posterior_variance.shape[0]
231
+ == posterior_log_variance_clipped.shape[0]
232
+ == x_start.shape[0]
233
+ )
234
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
235
+
236
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
237
+ """
238
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
239
+ the initial x, x_0.
240
+ :param model: the model, which takes a signal and a batch of timesteps
241
+ as input.
242
+ :param x: the [N x C x ...] tensor at time t.
243
+ :param t: a 1-D Tensor of timesteps.
244
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
245
+ :param denoised_fn: if not None, a function which applies to the
246
+ x_start prediction before it is used to sample. Applies before
247
+ clip_denoised.
248
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
249
+ pass to the model. This can be used for conditioning.
250
+ :return: a dict with the following keys:
251
+ - 'mean': the model mean output.
252
+ - 'variance': the model variance output.
253
+ - 'log_variance': the log of 'variance'.
254
+ - 'pred_xstart': the prediction for x_0.
255
+ """
256
+ if model_kwargs is None:
257
+ model_kwargs = {}
258
+
259
+ B, C = x.shape[:2]
260
+ assert t.shape == (B,)
261
+ model_output = model(x, t, **model_kwargs)
262
+ if isinstance(model_output, tuple):
263
+ model_output, extra = model_output
264
+ else:
265
+ extra = None
266
+
267
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
268
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
269
+ model_output, model_var_values = th.split(model_output, C, dim=1)
270
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
271
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
272
+ # The model_var_values is [-1, 1] for [min_var, max_var].
273
+ frac = (model_var_values + 1) / 2
274
+ model_log_variance = frac * max_log + (1 - frac) * min_log
275
+ model_variance = th.exp(model_log_variance)
276
+ else:
277
+ model_variance, model_log_variance = {
278
+ # for fixedlarge, we set the initial (log-)variance like so
279
+ # to get a better decoder log likelihood.
280
+ ModelVarType.FIXED_LARGE: (
281
+ np.append(self.posterior_variance[1], self.betas[1:]),
282
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
283
+ ),
284
+ ModelVarType.FIXED_SMALL: (
285
+ self.posterior_variance,
286
+ self.posterior_log_variance_clipped,
287
+ ),
288
+ }[self.model_var_type]
289
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
290
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
291
+
292
+ def process_xstart(x):
293
+ if denoised_fn is not None:
294
+ x = denoised_fn(x)
295
+ if clip_denoised:
296
+ return x.clamp(-1, 1)
297
+ return x
298
+
299
+ if self.model_mean_type == ModelMeanType.START_X:
300
+ pred_xstart = process_xstart(model_output)
301
+ else:
302
+ pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output))
303
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
304
+
305
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
306
+ return {
307
+ "mean": model_mean,
308
+ "variance": model_variance,
309
+ "log_variance": model_log_variance,
310
+ "pred_xstart": pred_xstart,
311
+ "extra": extra,
312
+ }
313
+
314
+ def _predict_xstart_from_eps(self, x_t, t, eps):
315
+ assert x_t.shape == eps.shape
316
+ return (
317
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
318
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
319
+ )
320
+
321
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
322
+ return (
323
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
324
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
325
+
326
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
327
+ """
328
+ Compute the mean for the previous step, given a function cond_fn that
329
+ computes the gradient of a conditional log probability with respect to
330
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
331
+ condition on y.
332
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
333
+ """
334
+ gradient = cond_fn(x, t, **(model_kwargs or {}))
335
+ new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
336
+ return new_mean
337
+
338
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
339
+ """
340
+ Compute what the p_mean_variance output would have been, should the
341
+ model's score function be conditioned by cond_fn.
342
+ See condition_mean() for details on cond_fn.
343
+ Unlike condition_mean(), this instead uses the conditioning strategy
344
+ from Song et al (2020).
345
+ """
346
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
347
+
348
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
349
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **(model_kwargs or {}))
350
+
351
+ out = p_mean_var.copy()
352
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
353
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
354
+ return out
355
+
356
+ def p_sample(
357
+ self,
358
+ model,
359
+ x,
360
+ t,
361
+ clip_denoised=True,
362
+ denoised_fn=None,
363
+ cond_fn=None,
364
+ model_kwargs=None,
365
+ ):
366
+ """
367
+ Sample x_{t-1} from the model at the given timestep.
368
+ :param model: the model to sample from.
369
+ :param x: the current tensor at x_t.
370
+ :param t: the value of t, starting at 0 for the first diffusion step.
371
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
372
+ :param denoised_fn: if not None, a function which applies to the
373
+ x_start prediction before it is used to sample.
374
+ :param cond_fn: if not None, this is a gradient function that acts
375
+ similarly to the model.
376
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
377
+ pass to the model. This can be used for conditioning.
378
+ :return: a dict containing the following keys:
379
+ - 'sample': a random sample from the model.
380
+ - 'pred_xstart': a prediction of x_0.
381
+ """
382
+ out = self.p_mean_variance(
383
+ model,
384
+ x,
385
+ t,
386
+ clip_denoised=clip_denoised,
387
+ denoised_fn=denoised_fn,
388
+ model_kwargs=model_kwargs,
389
+ )
390
+ noise = th.randn_like(x)
391
+ nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
392
+ if cond_fn is not None:
393
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
394
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
395
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
396
+
397
+ def p_sample_loop(
398
+ self,
399
+ model,
400
+ shape,
401
+ noise=None,
402
+ clip_denoised=True,
403
+ denoised_fn=None,
404
+ cond_fn=None,
405
+ model_kwargs=None,
406
+ device=None,
407
+ progress=False,
408
+ ):
409
+ """
410
+ Generate samples from the model.
411
+ :param model: the model module.
412
+ :param shape: the shape of the samples, (N, C, H, W).
413
+ :param noise: if specified, the noise from the encoder to sample.
414
+ Should be of the same shape as `shape`.
415
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
416
+ :param denoised_fn: if not None, a function which applies to the
417
+ x_start prediction before it is used to sample.
418
+ :param cond_fn: if not None, this is a gradient function that acts
419
+ similarly to the model.
420
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
421
+ pass to the model. This can be used for conditioning.
422
+ :param device: if specified, the device to create the samples on.
423
+ If not specified, use a model parameter's device.
424
+ :param progress: if True, show a tqdm progress bar.
425
+ :return: a non-differentiable batch of samples.
426
+ """
427
+ final = None
428
+ for sample in self.p_sample_loop_progressive(
429
+ model,
430
+ shape,
431
+ noise=noise,
432
+ clip_denoised=clip_denoised,
433
+ denoised_fn=denoised_fn,
434
+ cond_fn=cond_fn,
435
+ model_kwargs=model_kwargs,
436
+ device=device,
437
+ progress=progress,
438
+ ):
439
+ final = sample
440
+ return final["sample"]
441
+
442
+ def p_sample_loop_progressive(
443
+ self,
444
+ model,
445
+ shape,
446
+ noise=None,
447
+ clip_denoised=True,
448
+ denoised_fn=None,
449
+ cond_fn=None,
450
+ model_kwargs=None,
451
+ device=None,
452
+ progress=False,
453
+ ):
454
+ """
455
+ Generate samples from the model and yield intermediate samples from
456
+ each timestep of diffusion.
457
+ Arguments are the same as p_sample_loop().
458
+ Returns a generator over dicts, where each dict is the return value of
459
+ p_sample().
460
+ """
461
+ if device is None:
462
+ device = next(model.parameters()).device
463
+ assert isinstance(shape, (tuple, list))
464
+ if noise is not None:
465
+ img = noise
466
+ else:
467
+ img = th.randn(*shape, device=device)
468
+ indices = list(range(self.num_timesteps))[::-1]
469
+
470
+ if progress:
471
+ # Lazy import so that we don't depend on tqdm.
472
+ from tqdm.auto import tqdm
473
+
474
+ indices = tqdm(indices)
475
+
476
+ for i in indices:
477
+ t = th.tensor([i] * shape[0], device=device)
478
+ with th.no_grad():
479
+ out = self.p_sample(
480
+ model,
481
+ img,
482
+ t,
483
+ clip_denoised=clip_denoised,
484
+ denoised_fn=denoised_fn,
485
+ cond_fn=cond_fn,
486
+ model_kwargs=model_kwargs,
487
+ )
488
+ yield out
489
+ img = out["sample"]
490
+
491
+ def ddim_sample(
492
+ self,
493
+ model,
494
+ x,
495
+ t,
496
+ clip_denoised=True,
497
+ denoised_fn=None,
498
+ cond_fn=None,
499
+ model_kwargs=None,
500
+ eta=0.0,
501
+ ):
502
+ """
503
+ Sample x_{t-1} from the model using DDIM.
504
+ Same usage as p_sample().
505
+ """
506
+ out = self.p_mean_variance(
507
+ model,
508
+ x,
509
+ t,
510
+ clip_denoised=clip_denoised,
511
+ denoised_fn=denoised_fn,
512
+ model_kwargs=model_kwargs,
513
+ )
514
+ if cond_fn is not None:
515
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
516
+
517
+ # Usually our model outputs epsilon, but we re-derive it
518
+ # in case we used x_start or x_prev prediction.
519
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
520
+
521
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
522
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
523
+ sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev)
524
+ # Equation 12.
525
+ noise = th.randn_like(x)
526
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
527
+ nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
528
+ sample = mean_pred + nonzero_mask * sigma * noise
529
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
530
+
531
+ def ddim_reverse_sample(
532
+ self,
533
+ model,
534
+ x,
535
+ t,
536
+ clip_denoised=True,
537
+ denoised_fn=None,
538
+ cond_fn=None,
539
+ model_kwargs=None,
540
+ eta=0.0,
541
+ ):
542
+ """
543
+ Sample x_{t+1} from the model using DDIM reverse ODE.
544
+ """
545
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
546
+ out = self.p_mean_variance(
547
+ model,
548
+ x,
549
+ t,
550
+ clip_denoised=clip_denoised,
551
+ denoised_fn=denoised_fn,
552
+ model_kwargs=model_kwargs,
553
+ )
554
+ if cond_fn is not None:
555
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
556
+ # Usually our model outputs epsilon, but we re-derive it
557
+ # in case we used x_start or x_prev prediction.
558
+ eps = (
559
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"]
560
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
561
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
562
+
563
+ # Equation 12. reversed
564
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
565
+
566
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
567
+
568
+ def ddim_sample_loop(
569
+ self,
570
+ model,
571
+ shape,
572
+ noise=None,
573
+ clip_denoised=True,
574
+ denoised_fn=None,
575
+ cond_fn=None,
576
+ model_kwargs=None,
577
+ device=None,
578
+ progress=False,
579
+ eta=0.0,
580
+ ):
581
+ """
582
+ Generate samples from the model using DDIM.
583
+ Same usage as p_sample_loop().
584
+ """
585
+ final = None
586
+ for sample in self.ddim_sample_loop_progressive(
587
+ model,
588
+ shape,
589
+ noise=noise,
590
+ clip_denoised=clip_denoised,
591
+ denoised_fn=denoised_fn,
592
+ cond_fn=cond_fn,
593
+ model_kwargs=model_kwargs,
594
+ device=device,
595
+ progress=progress,
596
+ eta=eta,
597
+ ):
598
+ final = sample
599
+ return final["sample"]
600
+
601
+ def ddim_sample_loop_progressive(
602
+ self,
603
+ model,
604
+ shape,
605
+ noise=None,
606
+ clip_denoised=True,
607
+ denoised_fn=None,
608
+ cond_fn=None,
609
+ model_kwargs=None,
610
+ device=None,
611
+ progress=False,
612
+ eta=0.0,
613
+ ):
614
+ """
615
+ Use DDIM to sample from the model and yield intermediate samples from
616
+ each timestep of DDIM.
617
+ Same usage as p_sample_loop_progressive().
618
+ """
619
+ if device is None:
620
+ device = next(model.parameters()).device
621
+ assert isinstance(shape, (tuple, list))
622
+ if noise is not None:
623
+ img = noise
624
+ else:
625
+ img = th.randn(*shape, device=device)
626
+ indices = list(range(self.num_timesteps))[::-1]
627
+
628
+ if progress:
629
+ # Lazy import so that we don't depend on tqdm.
630
+ from tqdm.auto import tqdm
631
+
632
+ indices = tqdm(indices)
633
+
634
+ for i in indices:
635
+ t = th.tensor([i] * shape[0], device=device)
636
+ with th.no_grad():
637
+ out = self.ddim_sample(
638
+ model,
639
+ img,
640
+ t,
641
+ clip_denoised=clip_denoised,
642
+ denoised_fn=denoised_fn,
643
+ cond_fn=cond_fn,
644
+ model_kwargs=model_kwargs,
645
+ eta=eta,
646
+ )
647
+ yield out
648
+ img = out["sample"]
649
+
650
+ def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None):
651
+ """
652
+ Get a term for the variational lower-bound.
653
+ The resulting units are bits (rather than nats, as one might expect).
654
+ This allows for comparison to other papers.
655
+ :return: a dict with the following keys:
656
+ - 'output': a shape [N] tensor of NLLs or KLs.
657
+ - 'pred_xstart': the x_0 predictions.
658
+ """
659
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)
660
+ out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs)
661
+ kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"])
662
+ kl = mean_flat(kl) / np.log(2.0)
663
+
664
+ decoder_nll = -discretized_gaussian_log_likelihood(
665
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
666
+ )
667
+ assert decoder_nll.shape == x_start.shape
668
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
669
+
670
+ # At the first timestep return the decoder NLL,
671
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
672
+ output = th.where((t == 0), decoder_nll, kl)
673
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
674
+
675
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
676
+ """
677
+ Compute training losses for a single timestep.
678
+ :param model: the model to evaluate loss on.
679
+ :param x_start: the [N x C x ...] tensor of inputs.
680
+ :param t: a batch of timestep indices.
681
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
682
+ pass to the model. This can be used for conditioning.
683
+ :param noise: if specified, the specific Gaussian noise to try to remove.
684
+ :return: a dict with the key "loss" containing a tensor of shape [N].
685
+ Some mean or variance settings may also have other keys.
686
+ """
687
+ if model_kwargs is None:
688
+ model_kwargs = {}
689
+ if noise is None:
690
+ noise = th.randn_like(x_start)
691
+ x_t = self.q_sample(x_start, t, noise=noise)
692
+
693
+ terms = {}
694
+
695
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
696
+ terms["loss"] = self._vb_terms_bpd(
697
+ model=model,
698
+ x_start=x_start,
699
+ x_t=x_t,
700
+ t=t,
701
+ clip_denoised=False,
702
+ model_kwargs=model_kwargs,
703
+ )["output"]
704
+ if self.loss_type == LossType.RESCALED_KL:
705
+ terms["loss"] *= self.num_timesteps
706
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
707
+ model_output = model(x_t, t, **model_kwargs)
708
+
709
+ if self.model_var_type in [
710
+ ModelVarType.LEARNED,
711
+ ModelVarType.LEARNED_RANGE,
712
+ ]:
713
+ B, C = x_t.shape[:2]
714
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
715
+ model_output, model_var_values = th.split(model_output, C, dim=1)
716
+ # Learn the variance using the variational bound, but don't let
717
+ # it affect our mean prediction.
718
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
719
+ terms["vb"] = self._vb_terms_bpd(
720
+ model=lambda *args, r=frozen_out: r,
721
+ x_start=x_start,
722
+ x_t=x_t,
723
+ t=t,
724
+ clip_denoised=False,
725
+ )["output"]
726
+ if self.loss_type == LossType.RESCALED_MSE:
727
+ # Divide by 1000 for equivalence with initial implementation.
728
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
729
+ terms["vb"] *= self.num_timesteps / 1000.0
730
+
731
+ target = {
732
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0],
733
+ ModelMeanType.START_X: x_start,
734
+ ModelMeanType.EPSILON: noise,
735
+ }[self.model_mean_type]
736
+ assert model_output.shape == target.shape == x_start.shape
737
+ terms["mse"] = mean_flat((target - model_output) ** 2)
738
+ if "vb" in terms:
739
+ terms["loss"] = terms["mse"] + terms["vb"]
740
+ else:
741
+ terms["loss"] = terms["mse"]
742
+ else:
743
+ raise NotImplementedError(self.loss_type)
744
+
745
+ return terms
746
+
747
+ def _prior_bpd(self, x_start):
748
+ """
749
+ Get the prior KL term for the variational lower-bound, measured in
750
+ bits-per-dim.
751
+ This term can't be optimized, as it only depends on the encoder.
752
+ :param x_start: the [N x C x ...] tensor of inputs.
753
+ :return: a batch of [N] KL values (in bits), one per batch element.
754
+ """
755
+ batch_size = x_start.shape[0]
756
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
757
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
758
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
759
+ return mean_flat(kl_prior) / np.log(2.0)
760
+
761
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
762
+ """
763
+ Compute the entire variational lower-bound, measured in bits-per-dim,
764
+ as well as other related quantities.
765
+ :param model: the model to evaluate loss on.
766
+ :param x_start: the [N x C x ...] tensor of inputs.
767
+ :param clip_denoised: if True, clip denoised samples.
768
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
769
+ pass to the model. This can be used for conditioning.
770
+ :return: a dict containing the following keys:
771
+ - total_bpd: the total variational lower-bound, per batch element.
772
+ - prior_bpd: the prior term in the lower-bound.
773
+ - vb: an [N x T] tensor of terms in the lower-bound.
774
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
775
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
776
+ """
777
+ device = x_start.device
778
+ batch_size = x_start.shape[0]
779
+
780
+ vb = []
781
+ xstart_mse = []
782
+ mse = []
783
+ for t in list(range(self.num_timesteps))[::-1]:
784
+ t_batch = th.tensor([t] * batch_size, device=device)
785
+ noise = th.randn_like(x_start)
786
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
787
+ # Calculate VLB term at the current timestep
788
+ with th.no_grad():
789
+ out = self._vb_terms_bpd(
790
+ model,
791
+ x_start=x_start,
792
+ x_t=x_t,
793
+ t=t_batch,
794
+ clip_denoised=clip_denoised,
795
+ model_kwargs=model_kwargs,
796
+ )
797
+ vb.append(out["output"])
798
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
799
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
800
+ mse.append(mean_flat((eps - noise) ** 2))
801
+
802
+ vb = th.stack(vb, dim=1)
803
+ xstart_mse = th.stack(xstart_mse, dim=1)
804
+ mse = th.stack(mse, dim=1)
805
+
806
+ prior_bpd = self._prior_bpd(x_start)
807
+ total_bpd = vb.sum(dim=1) + prior_bpd
808
+ return {
809
+ "total_bpd": total_bpd,
810
+ "prior_bpd": prior_bpd,
811
+ "vb": vb,
812
+ "xstart_mse": xstart_mse,
813
+ "mse": mse,
814
+ }
815
+
816
+
817
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
818
+ """
819
+ Extract values from a 1-D numpy array for a batch of indices.
820
+ :param arr: the 1-D numpy array.
821
+ :param timesteps: a tensor of indices into the array to extract.
822
+ :param broadcast_shape: a larger shape of K dimensions with the batch
823
+ dimension equal to the length of timesteps.
824
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
825
+ """
826
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
827
+ while len(res.shape) < len(broadcast_shape):
828
+ res = res[..., None]
829
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
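Review note on `betas_for_alpha_bar` in `gaussian_diffusion.py` above: the per-step betas are chosen so that the running product of `(1 - beta)` telescopes back to the supplied `alpha_bar` function, except where `max_beta` clips a step. The following is a minimal pure-Python sketch of that property, not the committed implementation. It uses plain lists instead of NumPy, and the names `cosine_alpha_bar` and `cumulative_alphas` are hypothetical helpers introduced only for illustration.

```python
import math


def cosine_alpha_bar(t):
    # Same expression as the "squaredcos_cap_v2" lambda in the diff.
    return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2


def betas_for_alpha_bar(num_steps, alpha_bar, max_beta=0.999):
    # Mirrors the loop in the diff, returning a list instead of np.array.
    betas = []
    for i in range(num_steps):
        t1 = i / num_steps
        t2 = (i + 1) / num_steps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return betas


def cumulative_alphas(betas):
    # Hypothetical stand-in for np.cumprod(1.0 - betas).
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


betas = betas_for_alpha_bar(10, cosine_alpha_bar)
alphas_cumprod = cumulative_alphas(betas)

# Before the clipped final step, the running product equals
# alpha_bar(t) / alpha_bar(0) up to floating-point error.
expected = cosine_alpha_bar(0.9) / cosine_alpha_bar(0.0)
```

In this 10-step setting only the final step is clipped, because `alpha_bar(1.0)` is numerically zero for the cosine schedule. This is the singularity that `max_beta` guards against.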
videosys/diffusion/respace.py ADDED
@@ -0,0 +1,119 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import numpy as np
7
+ import torch as th
8
+
9
+ from .gaussian_diffusion import GaussianDiffusion
10
+
11
+
12
+ def space_timesteps(num_timesteps, section_counts):
13
+ """
14
+ Create a list of timesteps to use from an original diffusion process,
15
+ given the number of timesteps we want to take from equally-sized portions
16
+ of the original process.
17
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
18
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
19
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
20
+ If the stride is a string starting with "ddim", then the fixed striding
21
+ from the DDIM paper is used, and only one section is allowed.
22
+ :param num_timesteps: the number of diffusion steps in the original
23
+ process to divide up.
24
+ :param section_counts: either a list of numbers, or a string containing
25
+ comma-separated numbers, indicating the step count
26
+ per section. As a special case, use "ddimN" where N
27
+ is a number of steps to use the striding from the
28
+ DDIM paper.
29
+ :return: a set of diffusion steps from the original process to use.
30
+ """
31
+ if isinstance(section_counts, str):
32
+ if section_counts.startswith("ddim"):
33
+ desired_count = int(section_counts[len("ddim") :])
34
+ for i in range(1, num_timesteps):
35
+ if len(range(0, num_timesteps, i)) == desired_count:
36
+ return set(range(0, num_timesteps, i))
37
+ raise ValueError(f"cannot create exactly {desired_count} steps with an integer stride")
38
+ section_counts = [int(x) for x in section_counts.split(",")]
39
+ size_per = num_timesteps // len(section_counts)
40
+ extra = num_timesteps % len(section_counts)
41
+ start_idx = 0
42
+ all_steps = []
43
+ for i, section_count in enumerate(section_counts):
44
+ size = size_per + (1 if i < extra else 0)
45
+ if size < section_count:
46
+ raise ValueError(f"cannot divide section of {size} steps into {section_count}")
47
+ if section_count <= 1:
48
+ frac_stride = 1
49
+ else:
50
+ frac_stride = (size - 1) / (section_count - 1)
51
+ cur_idx = 0.0
52
+ taken_steps = []
53
+ for _ in range(section_count):
54
+ taken_steps.append(start_idx + round(cur_idx))
55
+ cur_idx += frac_stride
56
+ all_steps += taken_steps
57
+ start_idx += size
58
+ return set(all_steps)
59
+
60
+
61
+ class SpacedDiffusion(GaussianDiffusion):
62
+ """
63
+ A diffusion process which can skip steps in a base diffusion process.
64
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
65
+ original diffusion process to retain.
66
+ :param kwargs: the kwargs to create the base diffusion process.
67
+ """
68
+
69
+ def __init__(self, use_timesteps, **kwargs):
70
+ self.use_timesteps = set(use_timesteps)
71
+ self.timestep_map = []
72
+ self.original_num_steps = len(kwargs["betas"])
73
+
74
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
75
+ last_alpha_cumprod = 1.0
76
+ new_betas = []
77
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
78
+ if i in self.use_timesteps:
79
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
80
+ last_alpha_cumprod = alpha_cumprod
81
+ self.timestep_map.append(i)
82
+ kwargs["betas"] = np.array(new_betas)
83
+ super().__init__(**kwargs)
84
+
85
+ def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs
86
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
87
+
88
+ def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs
89
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
90
+
91
+ def condition_mean(self, cond_fn, *args, **kwargs):
92
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
93
+
94
+ def condition_score(self, cond_fn, *args, **kwargs):
95
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
96
+
97
+ def _wrap_model(self, model):
98
+ if isinstance(model, _WrappedModel):
99
+ return model
100
+ return _WrappedModel(model, self.timestep_map, self.original_num_steps)
101
+
102
+ def _scale_timesteps(self, t):
103
+ # Scaling is done by the wrapped model.
104
+ return t
105
+
106
+
107
+ class _WrappedModel:
108
+ def __init__(self, model, timestep_map, original_num_steps):
109
+ self.model = model
110
+ self.timestep_map = timestep_map
111
+ # self.rescale_timesteps = rescale_timesteps
112
+ self.original_num_steps = original_num_steps
113
+
114
+ def __call__(self, x, ts, **kwargs):
115
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
116
+ new_ts = map_tensor[ts]
117
+ # if self.rescale_timesteps:
118
+ # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
119
+ return self.model(x, new_ts, **kwargs)
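Review note on `respace.py` above: `space_timesteps` picks which original steps to keep, and `SpacedDiffusion.__init__` recomputes betas so that the cumulative product at each retained step matches the base process. The following pure-Python sketch walks through both pieces under stated assumptions. It handles only the list form of `section_counts`, not the `"ddimN"` string form, and the names `space_timesteps_sketch` and `respace_betas` are hypothetical and not part of the commit.

```python
import math


def space_timesteps_sketch(num_timesteps, section_counts):
    # Split the process into equal sections and take evenly strided
    # steps from each one, accumulating the stride as the diff does.
    size_per, extra = divmod(num_timesteps, len(section_counts))
    start_idx, all_steps = 0, []
    for i, count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < count:
            raise ValueError(f"cannot divide section of {size} steps into {count}")
        frac_stride = 1 if count <= 1 else (size - 1) / (count - 1)
        cur_idx = 0.0
        for _ in range(count):
            all_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        start_idx += size
    return set(all_steps)


def respace_betas(betas, use_timesteps):
    # For each retained step, choose the beta that carries the running
    # alpha product from the previous retained step to this one.
    new_betas, timestep_map = [], []
    alpha_cumprod, last_alpha_cumprod = 1.0, 1.0
    for i, beta in enumerate(betas):
        alpha_cumprod *= 1.0 - beta
        if i in use_timesteps:
            new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
            last_alpha_cumprod = alpha_cumprod
            timestep_map.append(i)
    return new_betas, timestep_map


# The docstring's example: 300 steps with section counts [10, 15, 20].
T = 300
steps = space_timesteps_sketch(T, [10, 15, 20])

# Linear schedule scaled as in get_named_beta_schedule("linear", ...).
scale = 1000 / T
start, end = scale * 0.0001, scale * 0.02
base_betas = [start + (end - start) * i / (T - 1) for i in range(T)]

new_betas, timestep_map = respace_betas(base_betas, steps)
base_final = math.prod(1.0 - b for b in base_betas)
spaced_final = math.prod(1.0 - b for b in new_betas)
```

The `timestep_map` list is what `_WrappedModel.__call__` indexes to translate a respaced step index back to the original timestep before calling the model.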
videosys/diffusion/timestep_sampler.py ADDED
@@ -0,0 +1,143 @@
+ # Modified from OpenAI's diffusion repos
+ #     GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+ #     ADM:   https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+ #     IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+ from abc import ABC, abstractmethod
+
+ import numpy as np
+ import torch as th
+ import torch.distributed as dist
+
+
+ def create_named_schedule_sampler(name, diffusion):
+     """
+     Create a ScheduleSampler from a library of pre-defined samplers.
+     :param name: the name of the sampler.
+     :param diffusion: the diffusion object to sample for.
+     """
+     if name == "uniform":
+         return UniformSampler(diffusion)
+     elif name == "loss-second-moment":
+         return LossSecondMomentResampler(diffusion)
+     else:
+         raise NotImplementedError(f"unknown schedule sampler: {name}")
+
+
+ class ScheduleSampler(ABC):
+     """
+     A distribution over timesteps in the diffusion process, intended to reduce
+     variance of the objective.
+     By default, samplers perform unbiased importance sampling, in which the
+     objective's mean is unchanged.
+     However, subclasses may override sample() to change how the resampled
+     terms are reweighted, allowing for actual changes in the objective.
+     """
+
+     @abstractmethod
+     def weights(self):
+         """
+         Get a numpy array of weights, one per diffusion step.
+         The weights needn't be normalized, but must be positive.
+         """
+
+     def sample(self, batch_size, device):
+         """
+         Importance-sample timesteps for a batch.
+         :param batch_size: the number of timesteps.
+         :param device: the torch device to save to.
+         :return: a tuple (timesteps, weights):
+                  - timesteps: a tensor of timestep indices.
+                  - weights: a tensor of weights to scale the resulting losses.
+         """
+         w = self.weights()
+         p = w / np.sum(w)
+         indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
+         indices = th.from_numpy(indices_np).long().to(device)
+         weights_np = 1 / (len(p) * p[indices_np])
+         weights = th.from_numpy(weights_np).float().to(device)
+         return indices, weights
+
+
+ class UniformSampler(ScheduleSampler):
+     def __init__(self, diffusion):
+         self.diffusion = diffusion
+         self._weights = np.ones([diffusion.num_timesteps])
+
+     def weights(self):
+         return self._weights
+
+
+ class LossAwareSampler(ScheduleSampler):
+     def update_with_local_losses(self, local_ts, local_losses):
+         """
+         Update the reweighting using losses from a model.
+         Call this method from each rank with a batch of timesteps and the
+         corresponding losses for each of those timesteps.
+         This method will perform synchronization to make sure all of the ranks
+         maintain the exact same reweighting.
+         :param local_ts: an integer Tensor of timesteps.
+         :param local_losses: a 1D Tensor of losses.
+         """
+         batch_sizes = [th.tensor([0], dtype=th.int32, device=local_ts.device) for _ in range(dist.get_world_size())]
+         dist.all_gather(
+             batch_sizes,
+             th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
+         )
+
+         # Pad all_gather batches to be the maximum batch size.
+         batch_sizes = [x.item() for x in batch_sizes]
+         max_bs = max(batch_sizes)
+
+         timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
+         loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
+         dist.all_gather(timestep_batches, local_ts)
+         dist.all_gather(loss_batches, local_losses)
+         timesteps = [x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]]
+         losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
+         self.update_with_all_losses(timesteps, losses)
+
+     @abstractmethod
+     def update_with_all_losses(self, ts, losses):
+         """
+         Update the reweighting using losses from a model.
+         Sub-classes should override this method to update the reweighting
+         using losses from the model.
+         This method directly updates the reweighting without synchronizing
+         between workers. It is called by update_with_local_losses from all
+         ranks with identical arguments. Thus, it should have deterministic
+         behavior to maintain state across workers.
+         :param ts: a list of int timesteps.
+         :param losses: a list of float losses, one per timestep.
+         """
+
+
+ class LossSecondMomentResampler(LossAwareSampler):
+     def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
+         self.diffusion = diffusion
+         self.history_per_term = history_per_term
+         self.uniform_prob = uniform_prob
+         self._loss_history = np.zeros([diffusion.num_timesteps, history_per_term], dtype=np.float64)
+         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
+
+     def weights(self):
+         if not self._warmed_up():
+             return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
+         weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
+         weights /= np.sum(weights)
+         weights *= 1 - self.uniform_prob
+         weights += self.uniform_prob / len(weights)
+         return weights
+
+     def update_with_all_losses(self, ts, losses):
+         for t, loss in zip(ts, losses):
+             if self._loss_counts[t] == self.history_per_term:
+                 # Shift out the oldest loss term.
+                 self._loss_history[t, :-1] = self._loss_history[t, 1:]
+                 self._loss_history[t, -1] = loss
+             else:
+                 self._loss_history[t, self._loss_counts[t]] = loss
+                 self._loss_counts[t] += 1
+
+     def _warmed_up(self):
+         return (self._loss_counts == self.history_per_term).all()
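The two pieces of arithmetic in this file can be checked without torch or a diffusion object. The sketch below restates the post-warm-up computation of `LossSecondMomentResampler.weights()` and the `1 / (N * p_t)` loss scale from `ScheduleSampler.sample()`, then evaluates the expected reweighted loss exactly rather than by sampling. The function names, the three-timestep history, and the per-step losses are hypothetical values chosen for illustration.

```python
import numpy as np


def second_moment_weights(loss_history, uniform_prob=0.001):
    """Root-mean-square loss per timestep, mixed with a small uniform floor."""
    weights = np.sqrt(np.mean(loss_history ** 2, axis=-1))
    weights /= np.sum(weights)
    weights *= 1 - uniform_prob
    weights += uniform_prob / len(weights)
    return weights


def importance_weights(p):
    """Loss scale applied to a sample drawn at timestep t: 1 / (N * p_t)."""
    return 1.0 / (len(p) * p)


# Hypothetical history: 3 timesteps, 2 recorded losses each.
history = np.array([[1.0, 1.0], [2.0, 2.0], [1.0, 1.0]])
w = second_moment_weights(history)
p = w / np.sum(w)

# Hypothetical true per-timestep losses.
per_step_loss = np.array([1.0, 2.0, 1.0])

# Exact expectation of the reweighted loss under p.
expected = np.sum(p * importance_weights(p) * per_step_loss)
print(p, expected, per_step_loss.mean())
```

The timestep with the larger recorded losses is sampled more often, yet the expected reweighted loss equals the plain mean over timesteps, which is what the `ScheduleSampler` docstring means by unbiased importance sampling. The `uniform_prob` term keeps every probability strictly positive, so the division in `importance_weights` is always defined.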
videosys/models/__init__.py ADDED
File without changes