diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..f93d2aef27bd882b91956831666123ddcc24f907 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +t2v_enhanced/checkpoints/streaming_t2v.ckpt filter=lfs diff=lfs merge=lfs -text +__assets__/github/teaser/teaser_final.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..34f7adda41546ac5ca3cc4f8ca826b606ec78e4c --- /dev/null +++ b/.gitignore @@ -0,0 +1,17 @@ +__pycache__/ +*.py[cod] + +.mlflow/ +/logs +/experiments +/t2v_enhanced/.mlflow +/t2v_enhanced/logs +/t2v_enhanced/slurm_logs +/t2v_enhanced/results + +t2v_enhanced/.mlflow +t2v_enhanced/logs +t2v_enhanced/slurm_logs +t2v_enhanced/lightning_logs +t2v_enhanced/results +t2v_enhanced/gradio_output \ No newline at end of file diff --git a/README.md b/README.md index ca0ffc743b2627965b3d98ddcb1b33451a564a96..7d5a33111dd20dd137d9559ca73ddd29f8139fc2 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,108 @@ sdk: gradio sdk_version: 4.25.0 app_file: app.py pinned: false -short_description: 'StreamingT2V: Consistent, Dynamic, and Extendable Long Video' +short_description: Consistent, Dynamic, and Extendable Long Video Generation fr --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference \ No newline at end of file + + +# StreamingT2V + +This repository is the official implementation of [StreamingT2V](https://streamingt2v.github.io/). + + +**[StreamingT2V: Consistent, Dynamic, and Extendable Long Video Generation from Text](https://arxiv.org/abs/2403.14773)** +
+Roberto Henschel, +Levon Khachatryan, +Daniil Hayrapetyan, +Hayk Poghosyan, +Vahram Tadevosyan, +Zhangyang Wang, Shant Navasardyan, Humphrey Shi +
+ +[arXiv preprint](https://arxiv.org/abs/2403.14773) | [Video](https://twitter.com/i/status/1770909673463390414) | [Project page](https://streamingt2v.github.io/) + + +

+ +
+
+StreamingT2V is an advanced autoregressive technique that enables the creation of long videos featuring rich motion dynamics without any stagnation. It ensures temporal consistency throughout the video, aligns closely with the descriptive text, and maintains high frame-level image quality. Our demonstrations include successful examples of videos up to 1200 frames, spanning 2 minutes, and can be extended for even longer durations. Importantly, the effectiveness of StreamingT2V is not limited by the specific Text2Video model used, indicating that improvements in base models could yield even higher-quality videos. +

+ +## News + +* [03/21/2024] Paper [StreamingT2V](https://arxiv.org/abs/2403.14773) released! +* [04/03/2024] Code and [model](https://huggingface.co/PAIR/StreamingT2V) released! + + +## Setup + + + +1. Clone this repository and enter: + +``` shell +git clone https://github.com/Picsart-AI-Research/StreamingT2V.git +cd StreamingT2V/ +``` +2. Install requirements using Python 3.10 and CUDA >= 11.6 +``` shell +conda create -n st2v python=3.10 +conda activate st2v +pip install -r requirements.txt +``` +3. (Optional) Install FFmpeg if it's missing on your system +``` shell +conda install conda-forge::ffmpeg +``` +4. Download the weights from [HF](https://huggingface.co/PAIR/StreamingT2V) and put them into the `t2v_enhanced/checkpoints` directory. + +--- + + +## Inference + + + +### For Text-to-Video + +``` shell +cd StreamingT2V/ +python inference.py --prompt="A cat running on the street" +``` +To use other base models add the `--base_model=AnimateDiff` argument. Use `python inference.py --help` for more options. + +### For Image-to-Video + +``` shell +cd StreamingT2V/ +python inference.py --image=../examples/underwater.png --base_model=SVD +``` + + + +## Results +Detailed results can be found in the [Project page](https://streamingt2v.github.io/). + +## License +Our code is published under the CreativeML Open RAIL-M license. + +We include [ModelscopeT2V](https://github.com/modelscope/modelscope), [AnimateDiff](https://github.com/guoyww/AnimateDiff), [DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter) in the demo for research purposes and to demonstrate the flexibility of the StreamingT2V framework to include different T2V/I2V models. For commercial usage of such components, please refer to their original license. + + + + +## BibTeX +If you use our work in your research, please cite our publication: +``` +@article{henschel2024streamingt2v, + title={StreamingT2V: Consistent, Dynamic, and Extendable Long Video Generation from Text}, + author={Henschel, Roberto and Khachatryan, Levon and Hayrapetyan, Daniil and Poghosyan, Hayk and Tadevosyan, Vahram and Wang, Zhangyang and Navasardyan, Shant and Shi, Humphrey}, + journal={arXiv preprint arXiv:2403.14773}, + year={2024} +} +``` + + diff --git a/__assets__/github/teaser/teaser_final.png b/__assets__/github/teaser/teaser_final.png new file mode 100644 index 0000000000000000000000000000000000000000..4f80b1aadca80175b4c600ce1b536e60770e81e9 --- /dev/null +++ b/__assets__/github/teaser/teaser_final.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd4343418202d8aad2f08a65096482eb17527b784562a4e116da432aa22a30c5 +size 2656720 diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..a699bc5b3c2e987102ca93e0ee28d601e0a93d02 --- /dev/null +++ b/app.py @@ -0,0 +1,7 @@ +import gradio as gr + +def greet(name): + return "Hello " + name + "!!" + +iface = gr.Interface(fn=greet, inputs="text", outputs="text") +iface.launch() \ No newline at end of file diff --git a/examples/underwater.png b/examples/underwater.png new file mode 100644 index 0000000000000000000000000000000000000000..f95482afcf2220170b7481c4ccc64c11f70f8659 Binary files /dev/null and b/examples/underwater.png differ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..1137246a1e0be4200746cc362aee2d6d949c6149 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,39 @@ +[tool.poetry] +name = "t2v-enhanced" +version = "0.1.0" +description = "" +authors = ["Your Name "] +readme = "README.md" +packages = [{include = "t2v_enhanced"}] + +[tool.poetry.dependencies] +python = "^3.9" +torch = "^2.0.0" +omegaconf = "^2.3.0" +hydra-core = "^1.3.2" +pytorch-lightning = {extras = ["extra"], version = "^2.0.9"} +transformers = "^4.28.1" +torchmetrics = {extras = ["image"], version = "^0.11.4"} +mlflow = {extras = ["extras"], version = "^2.3.0"} +torchvision = "^0.15.1" +av = "^10.0.0" +rich = "^13.3.4" +albumentations = "^1.3.0" +datasets = "^2.12.0" +xformers = "^0.0.19" +kornia = "^0.7.0" +decord = "^0.6.0" +gdown = "^4.7.1" +pygifsicle = "^1.0.7" +ftfy = "^6.1.1" +regex = "^2023.6.3" +clip = {git = "https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33"} + + +[tool.poetry.group.dev.dependencies] +yapf = "^0.33.0" +autopep8 = "^2.0.2" + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3a27e732fa30e9e26ed2d37ad94d1a7f583f6940 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,37 @@ +accelerate==0.28.0 +bitsandbytes==0.43.0 +transformers==4.39.3 +diffusers==0.27.2 +albumentations==1.3.0 +av==10.0.0 +boto3==1.26.115 +clip @ git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33 +decord==0.6.0 +einops==0.7.0 +fastapi==0.103.2 +Flask==2.3.2 +gdown==4.7.1 +gradio==4.25.0 +gradio_client==0.15.0 +huggingface-hub==0.21.4 +jupyterlab==3.6.3 +omegaconf==2.3.0 +pandas==2.0.0 +pytorch-lightning==2.0.9 +scikit-image==0.20.0 +scikit-learn==1.2.2 +scipy==1.9.1 +seaborn==0.12.2 +-e . +torch==2.0.0 +torchdata==0.6.0 +torchvision==0.15.1 +tqdm==4.65.0 +xformers==0.0.19 +open-clip-torch==2.24.0 +jsonargparse==4.20.1 +fairscale==0.4.13 +rotary-embedding-torch==0.5.3 +easydict==1.13 +torchsde==0.2.6 +imageio[ffmpeg]==2.25.0 \ No newline at end of file diff --git a/t2v_enhanced/__init__.py b/t2v_enhanced/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bbc9d0cf1ed584275ad53788fd6e8a4ff82e14cf --- /dev/null +++ b/t2v_enhanced/__init__.py @@ -0,0 +1,4 @@ +from pathlib import Path + + +WORK_DIR = Path(__file__).resolve().parent diff --git a/t2v_enhanced/checkpoints/streaming_t2v.ckpt b/t2v_enhanced/checkpoints/streaming_t2v.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..1f8b996208697b90bdd8759ce840aef633c8a81b --- /dev/null +++ b/t2v_enhanced/checkpoints/streaming_t2v.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:821f94e00bb9e25b0b03ab5a37ac01d31a24df4573f2b7809c34c54c9712aa5c +size 25568849525 diff --git a/t2v_enhanced/configs/inference/inference_long_video.yaml b/t2v_enhanced/configs/inference/inference_long_video.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4b1f356428b1a2565846fb4f067eb2e3b2652886 --- /dev/null +++ b/t2v_enhanced/configs/inference/inference_long_video.yaml @@ -0,0 +1,37 @@ +trainer: + devices: '1' + num_nodes: 1 +model: + inference_params: + class_path: t2v_enhanced.model.pl_module_params_controlnet.InferenceParams + init_args: + num_inference_steps: 50 # number of inference steps + frame_rate: 3 + eta: 1.0 # eta used for DDIM sampler + guidance_scale: 7.5 # classifier free guidance scale + conditioning_type: fixed + start_from_real_input: false + n_autoregressive_generations: 6 # how many autoregressive generations + scheduler_cls: '' # we can load other models + unet_params: + class_path: t2v_enhanced.model.pl_module_params_controlnet.UNetParams + init_args: + use_standard_attention_processor: False + opt_params: + class_path: t2v_enhanced.model.pl_module_params_controlnet.OptimizerParams + init_args: + noise_generator: + class_path: t2v_enhanced.model.video_noise_generator.NoiseGenerator + init_args: + mode: vanilla # can be 'vanilla','mixed_noise', 'consistI2V' or 'mixed_noise_consistI2V' + alpha: 1.0 + shared_noise_across_chunks: True # if true, shared noise between all chunks of a video + forward_steps: 850 # number of DDPM forward steps + radius: [2,2,2] # radius for time, width and height +n_predictions: 300 +data: + class_path: t2v_enhanced.model.datasets.prompt_reader.PromptReader + init_args: + prompt_cfg: + type: file + content: /home/roberto.henschel/T2V-Enhanced/repo/training_code/t2v_enhanced/evaluation_prompts/prompts_long_eval.txt diff --git a/t2v_enhanced/configs/text_to_video/config.yaml b/t2v_enhanced/configs/text_to_video/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..042fc9a61cf539c6ca3987aa25d0a324baf0f0ab --- /dev/null +++ b/t2v_enhanced/configs/text_to_video/config.yaml @@ -0,0 +1,227 @@ +# pytorch_lightning==2.0.9 +seed_everything: 33 +trainer: + accelerator: auto + strategy: auto + devices: '8' + num_nodes: 1 + precision: 16-mixed + logger: null + callbacks: + - class_path: pytorch_lightning.callbacks.RichModelSummary + init_args: + max_depth: 1 + - class_path: pytorch_lightning.callbacks.RichProgressBar + init_args: + refresh_rate: 1 + leave: false + theme: + description: white + progress_bar: '#6206E0' + progress_bar_finished: '#6206E0' + progress_bar_pulse: '#6206E0' + batch_progress: white + time: grey54 + processing_speed: grey70 + metrics: white + console_kwargs: null + fast_dev_run: false + max_epochs: 5000 + min_epochs: null + max_steps: 2020000 + min_steps: null + max_time: null + limit_train_batches: null + limit_val_batches: 512 + limit_test_batches: null + limit_predict_batches: null + overfit_batches: 0.0 + val_check_interval: 8000 + check_val_every_n_epoch: 1 + num_sanity_val_steps: null + log_every_n_steps: 10 + enable_checkpointing: null + enable_progress_bar: null + enable_model_summary: null + accumulate_grad_batches: 8 + gradient_clip_val: 1 + gradient_clip_algorithm: norm + deterministic: null + benchmark: null + inference_mode: true + use_distributed_sampler: true + profiler: null + detect_anomaly: false + barebones: false + plugins: null + sync_batchnorm: false + reload_dataloaders_every_n_epochs: 0 + default_root_dir: null +model: + inference_params: + class_path: t2v_enhanced.model.pl_module_params_controlnet.InferenceParams + init_args: + width: 256 + height: 256 + video_length: 16 + guidance_scale: 7.5 + use_dec_scaling: true + frame_rate: 8 + num_inference_steps: 50 + eta: 1.0 + n_autoregressive_generations: 1 + mode: long_video + start_from_real_input: true + eval_loss_metrics: false + scheduler_cls: '' + negative_prompt: '' + conditioning_from_all_past: false + validation_samples: 80 + conditioning_type: last_chunk + result_formats: + - eval_gif + - gif + - mp4 + concat_video: true + opt_params: + class_path: t2v_enhanced.model.pl_module_params_controlnet.OptimizerParams + init_args: + learning_rate: 5.0e-05 + layers_config: + class_path: t2v_enhanced.model.requires_grad_setter.LayerConfig + init_args: + gradient_setup: + - - false + - - vae + - - false + - - text_encoder + - - false + - - image_encoder + - - true + - - resampler + - - true + - - unet + - - true + - - base_model + - - false + - - base_model + - transformer_in + - - false + - - base_model + - temp_attentions + - - false + - - base_model + - temp_convs + layers_config_base: null + use_warmup: false + warmup_steps: 10000 + warmup_start_factor: 1.0e-05 + learning_rate_spatial: 0.0 + use_8_bit_adam: false + noise_generator: null + noise_decomposition: null + perceptual_loss: false + noise_offset: 0.0 + split_opt_by_node: false + reset_prediction_type_to_eps: false + train_val_sampler_may_differ: true + measure_similarity: false + similarity_loss: false + similarity_loss_weight: 1.0 + loss_conditional_weight: 0.0 + loss_conditional_weight_convex: false + loss_conditional_change_after_step: 0 + mask_conditional_frames: false + sample_from_noise: true + mask_alternating: false + uncondition_freq: -1 + no_text_condition_control: false + inject_image_into_input: false + inject_at_T: false + resampling_steps: 1 + control_freq_in_resample: 1 + resample_to_T: false + adaptive_loss_reweight: false + load_resampler_from_ckpt: '' + skip_controlnet_branch: false + use_fps_conditioning: false + num_frame_embeddings_range: 16 + start_frame_training: 16 + start_frame_ctrl: 16 + load_trained_base_model_and_resampler_from_ckpt: '' + load_trained_controlnet_from_ckpt: '' + unet_params: + class_path: t2v_enhanced.model.pl_module_params_controlnet.UNetParams + init_args: + conditioning_embedding_out_channels: + - 32 + - 96 + - 256 + - 512 + ckpt_spatial_layers: '' + pipeline_repo: damo-vilab/text-to-video-ms-1.7b + unet_from_diffusers: true + spatial_latent_input: false + num_frame_conditioning: 1 + pipeline_class: t2v_enhanced.model.model.controlnet.pipeline_text_to_video_w_controlnet_synth.TextToVideoSDPipeline + frame_expansion: none + downsample_controlnet_cond: true + num_frames: 16 + pre_transformer_in_cond: false + num_tranformers: 1 + zero_conv_3d: false + merging_mode: addition + compute_only_conditioned_frames: false + condition_encoder: '' + zero_conv_mode: Identity + clean_model: true + merging_mode_base: attention_cross_attention + attention_mask_params: null + attention_mask_params_base: null + modelscope_input_format: true + temporal_self_attention_only_on_conditioning: false + temporal_self_attention_mask_included_itself: false + use_post_merger_zero_conv: false + weight_control_sample: 1.0 + use_controlnet_mask: false + random_mask_shift: false + random_mask: false + use_resampler: true + unet_from_pipe: false + unet_operates_on_2d: false + image_encoder: CLIP + use_standard_attention_processor: false + num_frames_before_chunk: 0 + resampler_type: single_frame + resampler_cls: t2v_enhanced.model.diffusers_conditional.models.controlnet.image_embedder.ImgEmbContextResampler + resampler_merging_layers: 4 + image_encoder_obj: + class_path: t2v_enhanced.model.diffusers_conditional.models.controlnet.image_embedder.FrozenOpenCLIPImageEmbedder + init_args: + arch: ViT-H-14 + version: laion2b_s32b_b79k + device: cuda + max_length: 77 + freeze: true + antialias: true + ucg_rate: 0.0 + unsqueeze_dim: false + repeat_to_max_len: false + num_image_crops: 0 + output_tokens: false + cfg_text_image: false + aggregation: last_out + resampler_random_shift: true + img_cond_alpha_per_frame: false + num_control_input_frames: 8 + use_image_encoder_normalization: false + use_of: false + ema_param: -1.0 + concat: false + use_image_tokens_main: true + use_image_tokens_ctrl: false +result_fol: results +exp_name: my_exp_name +run_name: my_run_name +scale_lr: false +matmul_precision: high diff --git a/t2v_enhanced/gradio_demo.py b/t2v_enhanced/gradio_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..e66a4400b5d2038d8abdda631a179766cf0b08bf --- /dev/null +++ b/t2v_enhanced/gradio_demo.py @@ -0,0 +1,189 @@ +# General +import os +from os.path import join as opj +import argparse +import datetime +from pathlib import Path +import torch +import gradio as gr +import tempfile +import yaml +from t2v_enhanced.model.video_ldm import VideoLDM + +# Utilities +from inference_utils import * +from model_init import * +from model_func import * + + +on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR" +parser = argparse.ArgumentParser() +parser.add_argument('--public_access', action='store_true', default=True) +parser.add_argument('--where_to_log', type=str, default="gradio_output") +parser.add_argument('--device', type=str, default="cuda") +args = parser.parse_args() + + +Path(args.where_to_log).mkdir(parents=True, exist_ok=True) +result_fol = Path(args.where_to_log).absolute() +device = args.device + + +# -------------------------- +# ----- Configurations ----- +# -------------------------- +ckpt_file_streaming_t2v = Path("checkpoints/streaming_t2v.ckpt").absolute() +cfg_v2v = {'downscale': 1, 'upscale_size': (1280, 720), 'model_id': 'damo/Video-to-Video', 'pad': True} + + +# -------------------------- +# ----- Initialization ----- +# -------------------------- +ms_model = init_modelscope(device) +# zs_model = init_zeroscope(device) +stream_cli, stream_model = init_streamingt2v_model(ckpt_file_streaming_t2v, result_fol) +msxl_model = init_v2v_model(cfg_v2v) + +inference_generator = torch.Generator(device="cuda") + + +# ------------------------- +# ----- Functionality ----- +# ------------------------- +def generate(prompt, num_frames, image, model_name_stage1, model_name_stage2, n_prompt, seed, t, image_guidance, where_to_log=result_fol): + now = datetime.datetime.now() + name = prompt[:100].replace(" ", "_") + "_" + str(now.time()).replace(":", "_").replace(".", "_") + + if num_frames == [] or num_frames is None: + num_frames = 56 + else: + num_frames = int(num_frames.split(" ")[0]) + + n_autoreg_gen = num_frames/8-8 + + inference_generator.manual_seed(seed) + short_video = ms_short_gen(prompt, ms_model, inference_generator, t, device) + stream_long_gen(prompt, short_video, n_autoreg_gen, n_prompt, seed, t, image_guidance, name, stream_cli, stream_model) + video_path = opj(where_to_log, name+".mp4") + return video_path + +def enhance(prompt, input_to_enhance): + encoded_video = video2video(prompt, input_to_enhance, result_fol, cfg_v2v, msxl_model) + return encoded_video + + +# -------------------------- +# ----- Gradio-Demo UI ----- +# -------------------------- +with gr.Blocks() as demo: + gr.HTML( + """ +
+

+ StreamingT2V +

+

+ Roberto Henschel1*, Levon Khachatryan1*, Daniil Hayrapetyan1*, Hayk Poghosyan1, Vahram Tadevosyan1, Zhangyang Wang1,2, Shant Navasardyan1, Humphrey Shi1,3 +

+

+ 1Picsart AI Resarch (PAIR), 2UT Austin, 3SHI Labs @ Georgia Tech, Oregon & UIUC +

+

+ *Equal Contribution +

+

+ [arXiv] + [GitHub] +

+

+ StreamingT2V is an advanced autoregressive technique that enables the creation of long videos featuring rich motion dynamics without any stagnation. + It ensures temporal consistency throughout the video, aligns closely with the descriptive text, and maintains high frame-level image quality. + Our demonstrations include successful examples of videos up to 1200 frames, spanning 2 minutes, and can be extended for even longer durations. + Importantly, the effectiveness of StreamingT2V is not limited by the specific Text2Video model used, indicating that improvements in base models could yield even higher-quality videos. +

+
+ """) + + if on_huggingspace: + gr.HTML(""" +

For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. +
+ + Duplicate Space +

""") + + with gr.Row(): + with gr.Column(): + with gr.Row(): + with gr.Column(): + with gr.Row(): + num_frames = gr.Dropdown(["24", "32", "40", "48", "56", "80 - only on local", "240 - only on local", "600 - only on local", "1200 - only on local", "10000 - only on local"], label="Number of Video Frames: Default is 56", info="For >80 frames use local workstation!") + with gr.Row(): + prompt_stage1 = gr.Textbox(label='Textual Prompt', placeholder="Ex: Dog running on the street.") + with gr.Row(): + image_stage1 = gr.Image(label='Image Prompt (only required for I2V base models)', show_label=True, scale=1, show_download_button=True) + with gr.Column(): + video_stage1 = gr.Video(label='Long Video Preview', show_label=True, interactive=False, scale=2, show_download_button=True) + with gr.Row(): + run_button_stage1 = gr.Button("Long Video Preview Generation") + + with gr.Row(): + with gr.Column(): + with gr.Accordion('Advanced options', open=False): + model_name_stage1 = gr.Dropdown( + choices=["T2V: ModelScope", "T2V: ZeroScope", "I2V: AnimateDiff"], + label="Base Model. Default is ModelScope", + info="Currently supports only ModelScope. We will add more options later!", + ) + model_name_stage2 = gr.Dropdown( + choices=["ModelScope-XL", "Another", "Another"], + label="Enhancement Model. Default is ModelScope-XL", + info="Currently supports only ModelScope-XL. We will add more options later!", + ) + n_prompt = gr.Textbox(label="Optional Negative Prompt", value='') + seed = gr.Slider(label='Seed', minimum=0, maximum=65536, value=33,step=1,) + + t = gr.Slider(label="Timesteps", minimum=0, maximum=100, value=50, step=1,) + image_guidance = gr.Slider(label='Image guidance scale', minimum=1, maximum=10, value=9.0, step=1.0) + + with gr.Column(): + with gr.Row(): + video_stage2 = gr.Video(label='Enhanced Long Video', show_label=True, interactive=False, height=473, show_download_button=True) + with gr.Row(): + run_button_stage2 = gr.Button("Long Video Enhancement") + ''' + ''' + gr.HTML( + """ +
+

+ Version: v1.0 +

+

+ Caution: + We would like the raise the awareness of users of this demo of its potential issues and concerns. + Like previous large foundation models, StreamingT2V could be problematic in some cases, partially we use pretrained ModelScope, therefore StreamingT2V can Inherit Its Imperfections. + So far, we keep all features available for research testing both to show the great potential of the StreamingT2V framework and to collect important feedback to improve the model in the future. + We welcome researchers and users to report issues with the HuggingFace community discussion feature or email the authors. +

+

+ Biases and content acknowledgement: + Beware that StreamingT2V may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography, and violence. + StreamingT2V in this demo is meant only for research purposes. +

+
+ """) + + inputs_t2v = [prompt_stage1, num_frames, image_stage1, model_name_stage1, model_name_stage2, n_prompt, seed, t, image_guidance] + run_button_stage1.click(fn=generate, inputs=inputs_t2v, outputs=video_stage1,) + + inputs_v2v = [prompt_stage1, video_stage1] + run_button_stage2.click(fn=enhance, inputs=inputs_v2v, outputs=video_stage2,) + + +if on_huggingspace: + demo.queue(max_size=20) + demo.launch(debug=True) +else: + _, _, link = demo.queue(api_open=False).launch(share=args.public_access) + print(link) \ No newline at end of file diff --git a/t2v_enhanced/inference.py b/t2v_enhanced/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..ffbd16bd0bd6a17914711d6cebeb33e5418e693b --- /dev/null +++ b/t2v_enhanced/inference.py @@ -0,0 +1,82 @@ +# General +import os +from os.path import join as opj +import argparse +import datetime +from pathlib import Path +import torch +import gradio as gr +import tempfile +import yaml +from t2v_enhanced.model.video_ldm import VideoLDM + +# Utilities +from inference_utils import * +from model_init import * +from model_func import * + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--prompt', type=str, default="A cat running on the street", help="The prompt to guide video generation.") + parser.add_argument('--image', type=str, default="", help="Path to image conditioning.") + # parser.add_argument('--video', type=str, default="", help="Path to video conditioning.") + parser.add_argument('--base_model', type=str, default="ModelscopeT2V", help="Base model to generate first chunk from", choices=["ModelscopeT2V", "AnimateDiff", "SVD"]) + parser.add_argument('--num_frames', type=int, default=24, help="The number of video frames to generate.") + parser.add_argument('--negative_prompt', type=str, default="", help="The prompt to guide what to not include in video generation.") + parser.add_argument('--num_steps', type=int, default=50, help="The number of denoising steps.") + parser.add_argument('--image_guidance', type=float, default=9.0, help="The guidance scale.") + + parser.add_argument('--output_dir', type=str, default="results", help="Path where to save the generated videos.") + parser.add_argument('--device', type=str, default="cuda") + parser.add_argument('--seed', type=int, default=33, help="Random seed") + args = parser.parse_args() + + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + result_fol = Path(args.output_dir).absolute() + device = args.device + + + # -------------------------- + # ----- Configurations ----- + # -------------------------- + ckpt_file_streaming_t2v = Path("checkpoints/streaming_t2v.ckpt").absolute() + cfg_v2v = {'downscale': 1, 'upscale_size': (1280, 720), 'model_id': 'damo/Video-to-Video', 'pad': True} + + + # -------------------------- + # ----- Initialization ----- + # -------------------------- + stream_cli, stream_model = init_streamingt2v_model(ckpt_file_streaming_t2v, result_fol) + if args.base_model == "ModelscopeT2V": + model = init_modelscope(device) + elif args.base_model == "AnimateDiff": + model = init_animatediff(device) + elif args.base_model == "SVD": + model = init_svd(device) + sdxl_model = init_sdxl(device) + + + inference_generator = torch.Generator(device="cuda") + + + # ------------------ + # ----- Inputs ----- + # ------------------ + now = datetime.datetime.now() + name = args.prompt[:100].replace(" ", "_") + "_" + str(now.time()).replace(":", "_").replace(".", "_") + + inference_generator = torch.Generator(device="cuda") + inference_generator.manual_seed(args.seed) + + if args.base_model == "ModelscopeT2V": + short_video = ms_short_gen(args.prompt, model, inference_generator) + elif args.base_model == "AnimateDiff": + short_video = ad_short_gen(args.prompt, model, inference_generator) + elif args.base_model == "SVD": + short_video = svd_short_gen(args.image, args.prompt, model, sdxl_model, inference_generator) + + n_autoreg_gen = args.num_frames // 8 - 8 + stream_long_gen(args.prompt, short_video, n_autoreg_gen, args.negative_prompt, args.seed, args.num_steps, args.image_guidance, name, stream_cli, stream_model) + video2video(args.prompt, opj(result_fol, name+".mp4"), result_fol, cfg_v2v, msxl_model) diff --git a/t2v_enhanced/inference_utils.py b/t2v_enhanced/inference_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b706bee7d42089b4969d62d557f4f876f05d3e89 --- /dev/null +++ b/t2v_enhanced/inference_utils.py @@ -0,0 +1,101 @@ +# import argparse +import sys +from pathlib import Path +from pytorch_lightning.cli import LightningCLI +from PIL import Image + +# For streaming +import yaml +from copy import deepcopy +from typing import List, Optional +from jsonargparse.typing import restricted_string_type + + +# -------------------------------------- +# ----------- For Streaming ------------ +# -------------------------------------- +class CustomCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + parser.add_argument("--result_fol", type=Path, + help="Set the path to the result folder", default="results") + parser.add_argument("--exp_name", type=str, help="Experiment name") + parser.add_argument("--run_name", type=str, + help="Current run name") + parser.add_argument("--prompts", type=Optional[List[str]]) + parser.add_argument("--scale_lr", type=bool, + help="Scale lr", default=False) + CodeType = restricted_string_type( + 'CodeType', '(medium)|(high)|(highest)') + parser.add_argument("--matmul_precision", type=CodeType) + parser.add_argument("--ckpt", type=Path,) + parser.add_argument("--n_predictions", type=int) + return parser + +def remove_value(dictionary, x): + for key, value in list(dictionary.items()): + if key == x: + del dictionary[key] + elif isinstance(value, dict): + remove_value(value, x) + return dictionary + +def legacy_transformation(cfg: yaml): + cfg = deepcopy(cfg) + cfg["trainer"]["devices"] = "1" + cfg["trainer"]['num_nodes'] = 1 + + if not "class_path" in cfg["model"]["inference_params"]: + cfg["model"]["inference_params"] = { + "class_path": "t2v_enhanced.model.pl_module_params.InferenceParams", "init_args": cfg["model"]["inference_params"]} + return cfg + + +# --------------------------------------------- +# ----------- For enhancement ----------- +# --------------------------------------------- +def add_margin(pil_img, top, right, bottom, left, color): + width, height = pil_img.size + new_width = width + right + left + new_height = height + top + bottom + result = Image.new(pil_img.mode, (new_width, new_height), color) + result.paste(pil_img, (left, top)) + return result + +def resize_to_fit(image, size): + W, H = size + w, h = image.size + if H / h > W / w: + H_ = int(h * W / w) + W_ = W + else: + W_ = int(w * H / h) + H_ = H + return image.resize((W_, H_)) + +def pad_to_fit(image, size): + W, H = size + w, h = image.size + pad_h = (H - h) // 2 + pad_w = (W - w) // 2 + return add_margin(image, pad_h, pad_w, pad_h, pad_w, (0, 0, 0)) + +def resize_and_keep(pil_img): + myheight = 576 + hpercent = (myheight/float(pil_img.size[1])) + wsize = int((float(pil_img.size[0])*float(hpercent))) + pil_img = pil_img.resize((wsize, myheight)) + return pil_img + +def center_crop(pil_img): + width, height = pil_img.size + new_width = 576 + new_height = 576 + + left = (width - new_width)/2 + top = (height - new_height)/2 + right = (width + new_width)/2 + bottom = (height + new_height)/2 + + # Crop the center of the image + pil_img = pil_img.crop((left, top, right, bottom)) + return pil_img \ No newline at end of file diff --git a/t2v_enhanced/model/__init__.py b/t2v_enhanced/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/t2v_enhanced/model/callbacks.py b/t2v_enhanced/model/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..85f4814114dab8bcfde8afefa2b127cd570dd080 --- /dev/null +++ b/t2v_enhanced/model/callbacks.py @@ -0,0 +1,102 @@ + +from pathlib import Path +from pytorch_lightning import Callback +import os +import torch +from lightning_fabric.utilities.cloud_io import get_filesystem +from pytorch_lightning.cli import LightningArgumentParser +from pytorch_lightning import LightningModule, Trainer +from lightning_utilities.core.imports import RequirementCache +from omegaconf import OmegaConf + +_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache( + "jsonargparse[signatures]>=4.17.0") + +if _JSONARGPARSE_SIGNATURES_AVAILABLE: + import docstring_parser + from jsonargparse import ( + ActionConfigFile, + ArgumentParser, + class_from_function, + Namespace, + register_unresolvable_import_paths, + set_config_read_mode, + ) + + # Required until fix https://github.com/pytorch/pytorch/issues/74483 + register_unresolvable_import_paths(torch) + set_config_read_mode(fsspec_enabled=True) +else: + locals()["ArgumentParser"] = object + locals()["Namespace"] = object + + +class SaveConfigCallback(Callback): + """Saves a LightningCLI config to the log_dir when training starts. + + Args: + parser: The parser object used to parse the configuration. + config: The parsed configuration that will be saved. + config_filename: Filename for the config file. + overwrite: Whether to overwrite an existing config file. + multifile: When input is multiple config files, saved config preserves this structure. + + Raises: + RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run + """ + + def __init__( + self, + parser: LightningArgumentParser, + config: Namespace, + log_dir: str, + config_filename: str = "config.yaml", + overwrite: bool = False, + multifile: bool = False, + + ) -> None: + self.parser = parser + self.config = config + self.config_filename = config_filename + self.overwrite = overwrite + self.multifile = multifile + self.already_saved = False + self.log_dir = log_dir + + def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: + if self.already_saved: + return + + log_dir = self.log_dir + assert log_dir is not None + config_path = os.path.join(log_dir, self.config_filename) + fs = get_filesystem(log_dir) + + if not self.overwrite: + # check if the file exists on rank 0 + file_exists = fs.isfile( + config_path) if trainer.is_global_zero else False + # broadcast whether to fail to all ranks + file_exists = trainer.strategy.broadcast(file_exists) + if file_exists: + raise RuntimeError( + f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting" + " results of a previous run. You can delete the previous config file," + " set `LightningCLI(save_config_callback=None)` to disable config saving," + ' or set `LightningCLI(save_config_kwargs={"overwrite": True})` to overwrite the config file.' + ) + + # save the file on rank 0 + if trainer.is_global_zero: + # save only on rank zero to avoid race conditions. + # the `log_dir` needs to be created as we rely on the logger to do it usually + # but it hasn't logged anything at this point + fs.makedirs(log_dir, exist_ok=True) + self.parser.save( + self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile + ) + self.already_saved = True + trainer.logger.log_hyperparams(OmegaConf.load(config_path)) + + # broadcast so that all ranks are in sync on future calls to .setup() + self.already_saved = trainer.strategy.broadcast(self.already_saved) diff --git a/t2v_enhanced/model/datasets/prompt_reader.py b/t2v_enhanced/model/datasets/prompt_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..4891617bbe121b9da52528c8f5c56224e1ca4cbf --- /dev/null +++ b/t2v_enhanced/model/datasets/prompt_reader.py @@ -0,0 +1,80 @@ +from pathlib import Path +from typing import Dict, List, Optional + +import numpy as np +import pytorch_lightning as pl +import torch +from pytorch_lightning.utilities.types import EVAL_DATALOADERS + +from t2v_enhanced.model.datasets.video_dataset import Annotations +import json + + +class ConcatDataset(torch.utils.data.Dataset): + def __init__(self, datasets): + self.datasets = datasets + self.model_id = datasets["reconstruction_dataset"].model_id + + def __getitem__(self, idx): + sample = {ds: self.datasets[ds].__getitem__( + idx) for ds in self.datasets} + return sample + + def __len__(self): + return min(len(self.datasets[d]) for d in self.datasets) + + +class CustomPromptsDataset(torch.utils.data.Dataset): + + def __init__(self, prompt_cfg: Dict[str, str]): + super().__init__() + + if prompt_cfg["type"] == "prompt": + self.prompts = [prompt_cfg["content"]] + elif prompt_cfg["type"] == "file": + file = Path(prompt_cfg["content"]) + if file.suffix == ".npy": + self.prompts = np.load(file.as_posix()) + elif file.suffix == ".txt": + with open(prompt_cfg["content"]) as f: + lines = [line.rstrip() for line in f] + self.prompts = lines + elif file.suffix == ".json": + with open(prompt_cfg["content"],"r") as file: + metadata = json.load(file) + if "videos_root" in prompt_cfg: + videos_root = Path(prompt_cfg["videos_root"]) + video_path = [str(videos_root / sample["page_dir"] / + f"{sample['videoid']}.mp4") for sample in metadata] + else: + video_path = [str(sample["page_dir"] / + f"{sample['videoid']}.mp4") for sample in metadata] + self.prompts = [sample["prompt"] for sample in metadata] + self.video_path = video_path + + + + + transformed_prompts = [] + for prompt in self.prompts: + transformed_prompts.append( + Annotations.clean_prompt(prompt)) + self.prompts = transformed_prompts + + def __len__(self): + return len(self.prompts) + + def __getitem__(self, index): + output = {"prompt": self.prompts[index]} + if hasattr(self,"video_path"): + output["video"] = self.video_path[index] + return output + + +class PromptReader(pl.LightningDataModule): + def __init__(self, prompt_cfg: Dict[str, str]): + super().__init__() + self.predict_dataset = CustomPromptsDataset(prompt_cfg) + + def predict_dataloader(self) -> EVAL_DATALOADERS: + return torch.utils.data.DataLoader(self.predict_dataset, batch_size=1, pin_memory=False, shuffle=False, drop_last=False) diff --git a/t2v_enhanced/model/datasets/video_dataset.py b/t2v_enhanced/model/datasets/video_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..21c752460db8c8eebba2d4d89cf5f6fd8228b456 --- /dev/null +++ b/t2v_enhanced/model/datasets/video_dataset.py @@ -0,0 +1,57 @@ +from tqdm import tqdm +from einops import repeat +from diffusers import DiffusionPipeline +from decord import VideoReader, cpu +import torchvision +import torch +import numpy as np +import decord +import albumentations as album +import math +import random +from abc import abstractmethod +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, List, Union +from PIL import Image +import json +Image.MAX_IMAGE_PIXELS = None + +decord.bridge.set_bridge("torch") + +class Annotations(): + + def __init__(self, + annotation_cfg: Dict) -> None: + self.annotation_cfg = annotation_cfg + + # TODO find all special characters + + @staticmethod + def process_string(string): + for special_char in [".", ",", ":"]: + result = "" + i = 0 + while i < len(string): + if string[i] == special_char: + if i > 0 and i < len(string) - 1 and string[i-1].isalpha() and string[i+1].isalpha(): + result += special_char+" " + else: + result += special_char + else: + result += string[i] + i += 1 + string = result + string = result + return result + + @staticmethod + def clean_prompt(prompt): + prompt = " ".join(prompt.split()) + prompt = prompt.replace(" , ", ", ") + prompt = prompt.replace(" . ", ". ") + prompt = prompt.replace(" : ", ": ") + prompt = Annotations.process_string(prompt) + return prompt + # return " ".join(prompt.split()) + diff --git a/t2v_enhanced/model/diffusers_conditional/__init__.py b/t2v_enhanced/model/diffusers_conditional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/t2v_enhanced/model/diffusers_conditional/models/__init__.py b/t2v_enhanced/model/diffusers_conditional/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/t2v_enhanced/model/diffusers_conditional/models/controlnet/__init__.py b/t2v_enhanced/model/diffusers_conditional/models/controlnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/t2v_enhanced/model/diffusers_conditional/models/controlnet/attention.py b/t2v_enhanced/model/diffusers_conditional/models/controlnet/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..a7c714e5f22cae06e54fccf11e75633c63c4cb6f --- /dev/null +++ b/t2v_enhanced/model/diffusers_conditional/models/controlnet/attention.py @@ -0,0 +1,291 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Callable, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils.import_utils import is_xformers_available +# from diffusers.models.attention_processor import Attention +# from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention import Attention +from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention_processor import Attention +from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings +# from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention_processor import Attention + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + + +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + is_spatial_attention: bool = False, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + use_image_embedding: bool = False, + unet_params=None, + ): + super().__init__() + + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = ( + num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = ( + num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm( + dim, elementwise_affine=norm_elementwise_affine) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + is_spatial_attention=is_spatial_attention, + use_image_embedding=use_image_embedding, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + is_spatial_attention=is_spatial_attention, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm( + dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward( + dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + cross_attention_kwargs=None, + class_labels=None, + ): + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. Self-Attention + + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2( + hidden_states) + ) + # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly + # prepare attention mask here + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * \ + (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh") + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + + + diff --git a/t2v_enhanced/model/diffusers_conditional/models/controlnet/attention_processor.py b/t2v_enhanced/model/diffusers_conditional/models/controlnet/attention_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..758e3025d773b952b0089bfefcbe2a7b23250851 --- /dev/null +++ b/t2v_enhanced/model/diffusers_conditional/models/controlnet/attention_processor.py @@ -0,0 +1,444 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from einops import repeat +from typing import Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import deprecate, logging +from diffusers.utils.import_utils import is_xformers_available + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + is_spatial_attention: bool, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + processor: Optional["AttnProcessor"] = None, + use_image_embedding: bool = False, + unet_params=None, + ): + super().__init__() + inner_dim = dim_head * heads + self.cross_attention_mode = cross_attention_dim is not None + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.is_spatial_attention = is_spatial_attention + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.train_image_cond_weight = use_image_embedding + self.use_image_embedding = use_image_embedding + + self.scale = dim_head**-0.5 if scale_qk else 1.0 + + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm( + num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + embed_dim = 93 + if self.cross_attention_mode and self.is_spatial_attention and self.use_image_embedding: + self.conv = torch.nn.Conv1d(embed_dim, 77, kernel_size=3, padding="same") + self.conv_ln = nn.LayerNorm(1024) + self.register_parameter("alpha", nn.Parameter(torch.tensor(0.))) + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = ( + AttnProcessor2_0() if hasattr( + F, "scaled_dot_product_attention") and scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None + ): + is_lora = hasattr(self, "processor") and isinstance( + self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor) + ) + + if use_memory_efficient_attention_xformers: + if self.added_kv_proj_dim is not None: + # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP + # which uses this type of cross attention ONLY because the attention mask of format + # [0, ..., -10.000, ..., 0, ...,] is not supported + raise NotImplementedError( + "Memory efficient attention with `xformers` is currently not supported when" + " `self.added_kv_proj_dim` is defined." + ) + elif not is_xformers_available(): + raise ModuleNotFoundError( + ( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers" + ), + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" + " only available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + + if is_lora: + processor = LoRAXFormersAttnProcessor( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + processor.to(self.processor.to_q_lora.up.weight.device) + else: + processor = XFormersAttnProcessor(attention_op=attention_op) + else: + if is_lora: + processor = LoRAAttnProcessor( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + ) + processor.load_state_dict(self.processor.state_dict()) + processor.to(self.processor.to_q_lora.up.weight.device) + else: + processor = AttnProcessor() + + self.set_processor(processor) + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError( + f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + if slice_size is not None and self.added_kv_proj_dim is not None: + processor = SlicedAttnAddedKVProcessor(slice_size) + elif slice_size is not None: + processor = SlicedAttnProcessor(slice_size) + elif self.added_kv_proj_dim is not None: + processor = AttnAddedKVProcessor() + else: + processor = AttnProcessor() + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor"): + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info( + f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs): + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor): + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, + head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape( + batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor, out_dim=3): + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len, + head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape( + batch_size * head_size, seq_len, dim // head_size) + + return tensor + + def get_attention_scores(self, query, key, attention_mask=None): + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3): + if batch_size is None: + deprecate( + "batch_size=None", + "0.0.15", + ( + "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect" + " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to" + " `prepare_attention_mask` when preparing the attention_mask." + ), + ) + batch_size = 1 + + head_size = self.heads + if attention_mask is None: + return attention_mask + + if attention_mask.shape[-1] != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. + padding_shape = ( + attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros( + padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + attention_mask = F.pad( + attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave( + head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states): + assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + + + +class AttnProcessor2_0: + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + inner_dim = hidden_states.shape[-1] + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, + head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, + head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + + +AttentionProcessor = Union[ + AttnProcessor2_0, +] diff --git a/t2v_enhanced/model/diffusers_conditional/models/controlnet/conditioning.py b/t2v_enhanced/model/diffusers_conditional/models/controlnet/conditioning.py new file mode 100644 index 0000000000000000000000000000000000000000..981b103b6a19b9842d0b250326288d1efee875fc --- /dev/null +++ b/t2v_enhanced/model/diffusers_conditional/models/controlnet/conditioning.py @@ -0,0 +1,100 @@ +import diffusers +from diffusers.models.transformer_temporal import TransformerTemporalModel, TransformerTemporalModelOutput +import torch.nn as nn +from einops import rearrange +from diffusers.models.attention_processor import Attention +# from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention_processor import Attention +from t2v_enhanced.model.diffusers_conditional.models.controlnet.transformer_temporal_crossattention import TransformerTemporalModel as TransformerTemporalModelCrossAttn +import torch + + +class CrossAttention(nn.Module): + + def __init__(self, input_channels, attention_head_dim, norm_num_groups=32): + super().__init__() + self.attention = Attention( + query_dim=input_channels, cross_attention_dim=input_channels, heads=input_channels//attention_head_dim, dim_head=attention_head_dim, bias=False, upcast_attention=False) + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, num_channels=input_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(input_channels, input_channels) + self.proj_out = nn.Linear(input_channels, input_channels) + + def forward(self, hidden_state, encoder_hidden_states, num_frames): + h, w = hidden_state.shape[2], hidden_state.shape[3] + hidden_state_norm = rearrange( + hidden_state, "(B F) C H W -> B C F H W", F=num_frames) + hidden_state_norm = self.norm(hidden_state_norm) + hidden_state_norm = rearrange( + hidden_state_norm, "B C F H W -> (B H W) F C") + hidden_state_norm = self.proj_in(hidden_state_norm) + attn = self.attention(hidden_state_norm, + encoder_hidden_states=encoder_hidden_states, + attention_mask=None, + ) + # proj_out + + residual = self.proj_out(attn) + + residual = rearrange( + residual, "(B H W) F C -> (B F) C H W", H=h, W=w) + output = hidden_state + residual + return TransformerTemporalModelOutput(sample=output) + + +class ConditionalModel(nn.Module): + + def __init__(self, input_channels, conditional_model: str, attention_head_dim=64): + super().__init__() + num_layers = 1 + if "_layers_" in conditional_model: + config = conditional_model.split("_layers_") + conditional_model = config[0] + num_layers = int(config[1]) + + if conditional_model == "self_cross_transformer": + self.temporal_transformer = TransformerTemporalModel(num_attention_heads=input_channels//attention_head_dim, attention_head_dim=attention_head_dim, in_channels=input_channels, + double_self_attention=False, cross_attention_dim=input_channels) + elif conditional_model == "cross_transformer": + self.temporal_transformer = TransformerTemporalModelCrossAttn(num_attention_heads=input_channels//attention_head_dim, attention_head_dim=attention_head_dim, in_channels=input_channels, + double_self_attention=False, cross_attention_dim=input_channels, num_layers=num_layers) + elif conditional_model == "cross_attention": + self.temporal_transformer = CrossAttention( + input_channels=input_channels, attention_head_dim=attention_head_dim) + elif conditional_model == "test_conv": + self.temporal_transformer = nn.Conv2d( + input_channels, input_channels, kernel_size=1) + else: + raise NotImplementedError( + f"mode {conditional_model} not implemented") + if conditional_model != "test_conv": + nn.init.zeros_(self.temporal_transformer.proj_out.weight) + nn.init.zeros_(self.temporal_transformer.proj_out.bias) + else: + nn.init.zeros_(self.temporal_transformer.weight) + nn.init.zeros_(self.temporal_transformer.bias) + self.conditional_model = conditional_model + + def forward(self, sample, conditioning, num_frames=None): + + assert conditioning.ndim == 5 + assert sample.ndim == 5 + if self.conditional_model != "test_conv": + conditioning = rearrange(conditioning, "B F C H W -> (B H W) F C") + + num_frames = sample.shape[1] + + sample = rearrange(sample, "B F C H W -> (B F) C H W") + + sample = self.temporal_transformer( + sample, encoder_hidden_states=conditioning, num_frames=num_frames).sample + + sample = rearrange( + sample, "(B F) C H W -> B F C H W", F=num_frames) + else: + + conditioning = rearrange(conditioning, "B F C H W -> (B F) C H W") + f = sample.shape[1] + sample = rearrange(sample, "B F C H W -> (B F) C H W") + sample = sample + self.temporal_transformer(conditioning) + sample = rearrange(sample, "(B F) C H W -> B F C H W", F=f) + return sample diff --git a/t2v_enhanced/model/diffusers_conditional/models/controlnet/controlnet.py b/t2v_enhanced/model/diffusers_conditional/models/controlnet/controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..fe7f20f5d2f85254c2d5ce097075b82c7a753994 --- /dev/null +++ b/t2v_enhanced/model/diffusers_conditional/models/controlnet/controlnet.py @@ -0,0 +1,865 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F +from einops import rearrange, repeat + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor +# from diffusers.models.transformer_temporal import TransformerTemporalModel +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin +from t2v_enhanced.model.diffusers_conditional.models.controlnet.unet_3d_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + transformer_g_c +) +# from diffusers.models.unet_3d_condition import UNet3DConditionModel +from t2v_enhanced.model.diffusers_conditional.models.controlnet.unet_3d_condition import UNet3DConditionModel +from t2v_enhanced.model.diffusers_conditional.models.controlnet.transformer_temporal import TransformerTemporalModel +from t2v_enhanced.model.layers.conv_channel_extension import Conv2D_SubChannels +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class ControlNetOutput(BaseOutput): + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class Merger(nn.Module): + def __init__(self, n_frames_condition: int = 8, n_frames_sample: int = 16, merge_mode: str = "addition", input_channels=0, frame_expansion="last_frame") -> None: + super().__init__() + self.merge_mode = merge_mode + self.n_frames_condition = n_frames_condition + self.n_frames_sample = n_frames_sample + self.frame_expansion = frame_expansion + + if merge_mode.startswith("attention"): + self.attention = ConditionalModel(input_channels=input_channels, + conditional_model=merge_mode.split("attention_")[1]) + + def forward(self, x, condition_signal): + x = rearrange(x, "(B F) C H W -> B F C H W", F=self.n_frames_sample) + + condition_signal = rearrange( + condition_signal, "(B F) C H W -> B F C H W", B=x.shape[0]) + + if x.shape[1] - condition_signal.shape[1] > 0: + if self.frame_expansion == "last_frame": + fillup_latent = repeat( + condition_signal[:, -1], "B C H W -> B F C H W", F=x.shape[1] - condition_signal.shape[1]) + elif self.frame_expansion == "zero": + fillup_latent = torch.zeros( + (x.shape[0], self.n_frames_sample-self.n_frames_condition, *x.shape[2:]), device=x.device, dtype=x.dtype) + + if self.frame_expansion != "none": + condition_signal = torch.cat( + [condition_signal, fillup_latent], dim=1) + + if self.merge_mode == "addition": + out = x + condition_signal + elif self.merge_mode.startswith("attention"): + out = self.attention(x, condition_signal) + out = rearrange(out, "B F C H W -> (B F) C H W") + return out + + +class ZeroConv(nn.Module): + def __init__(self, channels: int, mode: str = "2d", num_frames: int = 8, zero_init=True): + super().__init__() + mode_parts = mode.split("_") + if len(mode_parts) > 1 and mode_parts[1] == "noinit": + zero_init = False + + if mode.startswith("2d"): + model = nn.Conv2d( + channels, channels, kernel_size=1) + model = zero_module(model, reset=zero_init) + elif mode.startswith("3d"): + model = ZeroConv3D(num_frames=num_frames, + channels=channels, zero_init=zero_init) + elif mode == "Identity": + model = nn.Identity() + self.model = model + + def forward(self, x): + return self.model(x) + + + + + +class ControlNetConditioningEmbedding(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the + convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + # TODO why not GAUSSIAN used? + # TODO why not 4x4 kernel? + # TODO why not 2 x2 stride? + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int] = (16, 32, 96, 256), + downsample: bool = True, + final_3d_conv: bool = False, + num_frame_conditioning: int = 8, + num_frames: int = 16, + zero_init: bool = True, + use_controlnet_mask: bool = False, + use_normalization: bool = False, + ): + super().__init__() + self.num_frame_conditioning = num_frame_conditioning + self.num_frames = num_frames + self.final_3d_conv = final_3d_conv + self.conv_in = nn.Conv2d( + conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + if final_3d_conv: + print("USING 3D CONV in ControlNET") + + self.blocks = nn.ModuleList([]) + if use_normalization: + self.norms = nn.ModuleList([]) + self.use_normalization = use_normalization + + stride = 2 if downsample else 1 + if use_normalization: + res = 256 # HARD-CODED Resolution! + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append( + nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + if use_normalization: + self.norms.append(nn.LayerNorm((channel_in, res, res))) + self.blocks.append( + nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=stride)) + if use_normalization: + res = res // 2 + self.norms.append(nn.LayerNorm((channel_out, res, res))) + + if not final_3d_conv: + self.conv_out = zero_module( + nn.Conv2d( + block_out_channels[-1]+int(use_controlnet_mask), conditioning_embedding_channels, kernel_size=3, padding=1), reset=zero_init + ) + else: + self.conv_temp = zero_module(TemporalConvLayer_Custom( + num_frame_conditioning, num_frames, dropout=0.0), reset=zero_init) + self.conv_out = nn.Conv2d( + block_out_channels[-1]+int(use_controlnet_mask), conditioning_embedding_channels, kernel_size=3, padding=1) + # self.conv_temp = zero_module(nn.Conv3d( + # num_frame_conditioning, num_frames, kernel_size=3, padding=1) + # ) + + def forward(self, conditioning, vq_gan=None, controlnet_mask=None): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + if self.use_normalization: + for block, norm in zip(self.blocks, self.norms): + embedding = block(embedding) + embedding = norm(embedding) + embedding = F.silu(embedding) + else: + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + if controlnet_mask is not None: + embedding = rearrange( + embedding, "(B F) C H W -> F B C H W", F=self.num_frames) + controlnet_mask_expanded = controlnet_mask[:, :, None, None, None] + controlnet_mask_expanded = rearrange( + controlnet_mask_expanded, "B F C W H -> F B C W H") + masked_embedding = controlnet_mask_expanded * embedding + embedding = rearrange(masked_embedding, "F B C H W -> (B F) C H W") + controlnet_mask_expanded = rearrange( + controlnet_mask_expanded, "F B C H W -> (B F) C H W") + # controlnet_mask_expanded = repeat(controlnet_mask_expanded,"B C W H -> B (C x) W H",x=embedding.shape[1]) + controlnet_mask_expanded = repeat( + controlnet_mask_expanded, "B C W H -> B C (W y) H", y=embedding.shape[2]) + controlnet_mask_expanded = repeat( + controlnet_mask_expanded, "B C W H -> B C W (H z)", z=embedding.shape[3]) + + embedding = torch.cat([embedding, controlnet_mask_expanded], dim=1) + + embedding = self.conv_out(embedding) + if self.final_3d_conv: + # embedding = F.silu(embedding) + embedding = rearrange( + embedding, "(b f) c h w -> b f c h w", f=self.num_frame_conditioning) + embedding = self.conv_temp(embedding) + embedding = rearrange(embedding, "b f c h w -> (b f) c h w") + + return embedding + +class ControlNetModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + in_channels: int = 4, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + projection_class_embeddings_input_dim: Optional[int] = None, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int]] = ( + 16, 32, 96, 256), + global_pool_conditions: bool = False, + downsample_controlnet_cond: bool = True, + frame_expansion: str = "zero", + condition_encoder: str = "", + num_frames: int = 16, + num_frame_conditioning: int = 8, + num_tranformers: int = 1, + vae=None, + merging_mode: str = "addition", + zero_conv_mode: str = "2d", + use_controlnet_mask: bool = False, + use_image_embedding: bool = False, + use_image_encoder_normalization: bool = False, + unet_params=None, + ): + super().__init__() + self.gradient_checkpointing = False + # Check inputs + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + self.use_image_tokens = unet_params.use_image_tokens_ctrl + self.image_encoder_name = type(unet_params.image_encoder).__name__ + + # input + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + '''Conv2D_SubChannels + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + ''' + self.conv_in = Conv2D_SubChannels( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + # time + time_embed_dim = block_out_channels[0] * 4 + + self.time_proj = Timesteps( + block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + self.transformer_in = TransformerTemporalModel( + num_attention_heads=8, + attention_head_dim=attention_head_dim, + in_channels=block_out_channels[0], + num_layers=1, + ) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding( + num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding( + timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding( + projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + conditioning_channels = 3 if downsample_controlnet_cond else 4 + # control net conditioning embedding + + if condition_encoder == "temp_conv_vq": + controlnet_cond_embedding = ControlNetConditioningEmbeddingVQ( + conditioning_embedding_channels=block_out_channels[0], + conditioning_channels=4, + block_out_channels=conditioning_embedding_out_channels, + downsample=False, + + num_frame_conditioning=num_frame_conditioning, + num_frames=num_frames, + num_tranformers=num_tranformers, + # zero_init=not merging_mode.startswith("attention"), + ) + elif condition_encoder == "vq": + controlnet_cond_embedding = ControlNetConditioningOptVQ(vq=vae, + conditioning_embedding_channels=block_out_channels[ + 0], + conditioning_channels=4, + block_out_channels=conditioning_embedding_out_channels, + num_frame_conditioning=num_frame_conditioning, + num_frames=num_frames, + ) + + else: + controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + conditioning_channels=conditioning_channels, + block_out_channels=conditioning_embedding_out_channels, + downsample=downsample_controlnet_cond, + final_3d_conv=condition_encoder.endswith("3DConv"), + num_frame_conditioning=num_frame_conditioning, + num_frames=num_frames, + # zero_init=not merging_mode.startswith("attention") + use_controlnet_mask=use_controlnet_mask, + use_normalization=use_image_encoder_normalization, + ) + self.use_controlnet_mask = use_controlnet_mask + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + + # conv_in + self.merger = Merger(n_frames_sample=num_frames, n_frames_condition=num_frame_conditioning, + merge_mode=merging_mode, input_channels=block_out_channels[0], frame_expansion=frame_expansion) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [ + only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + self.controlnet_down_blocks.append( + ZeroConv(channels=output_channel, mode=zero_conv_mode, num_frames=num_frames)) + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=False, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + self.controlnet_down_blocks.append( + ZeroConv(channels=output_channel, mode=zero_conv_mode, num_frames=num_frames)) + + if not is_final_block: + self.controlnet_down_blocks.append( + ZeroConv(channels=output_channel, mode=zero_conv_mode, num_frames=num_frames)) + + # mid + mid_block_channel = block_out_channels[-1] + + self.controlnet_mid_block = ZeroConv( + channels=mid_block_channel, mode=zero_conv_mode, num_frames=num_frames) + + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) + self.controlnet_cond_embedding = controlnet_cond_embedding + self.num_frames = num_frames + self.num_frame_conditioning = num_frame_conditioning + + @classmethod + def from_unet( + cls, + unet: UNet3DConditionModel, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int]] = ( + 16, 32, 96, 256), + load_weights_from_unet: bool = True, + downsample_controlnet_cond: bool = True, + num_frames: int = 16, + num_frame_conditioning: int = 8, + frame_expansion: str = "zero", + num_tranformers: int = 1, + vae=None, + zero_conv_mode: str = "2d", + merging_mode: str = "addition", + # [spatial,spatial_3DConv,temp_conv_vq] + condition_encoder: str = "spatial_3DConv", + use_controlnet_mask: bool = False, + use_image_embedding: bool = False, + use_image_encoder_normalization: bool = False, + unet_params=None, + ** kwargs, + ): + r""" + Instantiate Controlnet class from UNet3DConditionModel. + + Parameters: + unet (`UNet3DConditionModel`): + UNet model which weights are copied to the ControlNet. Note that all configuration options are also + copied where applicable. + """ + controlnet = cls( + in_channels=unet.config.in_channels, + down_block_types=unet.config.down_block_types, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + act_fn=unet.config.act_fn, + norm_num_groups=unet.config.norm_num_groups, + norm_eps=unet.config.norm_eps, + cross_attention_dim=unet.config.cross_attention_dim, + attention_head_dim=unet.config.attention_head_dim, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + downsample_controlnet_cond=downsample_controlnet_cond, + num_frame_conditioning=num_frame_conditioning, + num_frames=num_frames, + frame_expansion=frame_expansion, + num_tranformers=num_tranformers, + vae=vae, + zero_conv_mode=zero_conv_mode, + merging_mode=merging_mode, + condition_encoder=condition_encoder, + use_controlnet_mask=use_controlnet_mask, + use_image_embedding=use_image_embedding, + use_image_encoder_normalization=use_image_encoder_normalization, + unet_params=unet_params, + + ) + + if load_weights_from_unet: + controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) + controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) + controlnet.transformer_in.load_state_dict( + unet.transformer_in.state_dict()) + controlnet.time_embedding.load_state_dict( + unet.time_embedding.state_dict()) + + if controlnet.class_embedding: + controlnet.class_embedding.load_state_dict( + unet.class_embedding.state_dict()) + + controlnet.down_blocks.load_state_dict( + unet.down_blocks.state_dict(), strict=False) # can be that the controlnet model does not use image clip encoding + controlnet.mid_block.load_state_dict( + unet.mid_block.state_dict(), strict=False) + + return controlnet + + @property + # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors( + f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Parameters: + `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + of **all** `Attention` layers. + In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.: + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor( + f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * \ + [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError( + f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D)): + module.gradient_checkpointing = value + + # TODO ADD WEIGHT CONTROL + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.FloatTensor, + conditioning_scale: float = 1.0, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + weight_control: float = 1.0, + weight_control_sample: float = 1.0, + controlnet_mask: Optional[torch.Tensor] = None, + vq_gan=None, + ) -> Union[ControlNetOutput, Tuple]: + # check channel order + # TODO SET ATTENTION MASK And WEIGHT CONTROL as in CONTROLNET.PY + ''' + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + ''' + # assert controlnet_mask is None, "Controlnet Mask not implemented yet for clean model" + # 1. time + + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor( + [timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + sample = sample[:, :, :self.num_frames] + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + num_frames = sample.shape[2] + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + emb = emb.repeat_interleave(repeats=num_frames, dim=0) + + if not self.use_image_tokens and encoder_hidden_states.shape[1] > 77: + encoder_hidden_states = encoder_hidden_states[:, :77] + + if encoder_hidden_states.shape[1] > 77: + # assert ( + # encoder_hidden_states.shape[1]-77) % num_frames == 0, f"Encoder shape {encoder_hidden_states.shape}. Num frames = {num_frames}" + context_text, context_img = encoder_hidden_states[:, + :77, :], encoder_hidden_states[:, 77:, :] + context_text = context_text.repeat_interleave( + repeats=num_frames, dim=0) + + if self.image_encoder_name == "FrozenOpenCLIPImageEmbedder": + context_img = context_img.repeat_interleave( + repeats=num_frames, dim=0) + else: + context_img = rearrange( + context_img, 'b (t l) c -> (b t) l c', t=num_frames) + + encoder_hidden_states = torch.cat( + [context_text, context_img], dim=1) + else: + encoder_hidden_states = encoder_hidden_states.repeat_interleave( + repeats=num_frames, dim=0) + + # print(f"ctrl with tokens = {encoder_hidden_states.shape[1]}") + ''' + encoder_hidden_states = encoder_hidden_states.repeat_interleave( + repeats=num_frames, dim=0) + ''' + + # 2. pre-process + sample = sample.permute(0, 2, 1, 3, 4).reshape( + (sample.shape[0] * num_frames, -1) + sample.shape[3:]) + sample = self.conv_in(sample) + + controlnet_cond = self.controlnet_cond_embedding( + controlnet_cond, vq_gan=vq_gan, controlnet_mask=controlnet_mask) + + if num_frames > 1: + if self.gradient_checkpointing: + sample = transformer_g_c( + self.transformer_in, sample, num_frames) + else: + sample = self.transformer_in( + sample, num_frames=num_frames, attention_mask=attention_mask).sample + + sample = self.merger(sample * weight_control_sample, + weight_control * controlnet_cond) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block( + hidden_states=sample, temb=emb, num_frames=num_frames) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + + # 5. Control net blocks + + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + \ + (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. scaling + if guess_mode and not self.config.global_pool_conditions: + # 0.1 to 1.0 + scales = torch.logspace(-1, 0, len(down_block_res_samples) + + 1, device=sample.device) + + scales = scales * conditioning_scale + down_block_res_samples = [ + sample * scale for sample, scale in zip(down_block_res_samples, scales)] + mid_block_res_sample = mid_block_res_sample * \ + scales[-1] # last one + else: + down_block_res_samples = [ + sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample = mid_block_res_sample * conditioning_scale + + if self.config.global_pool_conditions: + down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples + ] + mid_block_res_sample = torch.mean( + mid_block_res_sample, dim=(2, 3), keepdim=True) + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return ControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) + + + +def zero_module(module, reset=True): + if reset: + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/t2v_enhanced/model/diffusers_conditional/models/controlnet/cross_attention.py b/t2v_enhanced/model/diffusers_conditional/models/controlnet/cross_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..06389a35dd71dbc4679b4efc4e304799932e92b2 --- /dev/null +++ b/t2v_enhanced/model/diffusers_conditional/models/controlnet/cross_attention.py @@ -0,0 +1,30 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Callable, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils.import_utils import is_xformers_available +# from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention_processor import Attention + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + diff --git a/t2v_enhanced/model/diffusers_conditional/models/controlnet/image_embedder.py b/t2v_enhanced/model/diffusers_conditional/models/controlnet/image_embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..49a46f792157449b08450d961b3183fdf4b85512 --- /dev/null +++ b/t2v_enhanced/model/diffusers_conditional/models/controlnet/image_embedder.py @@ -0,0 +1,211 @@ +import math +from typing import Any, Mapping +import torch +import torch.nn as nn +import kornia +import open_clip +from transformers import AutoImageProcessor, AutoModel +from transformers.models.bit.image_processing_bit import BitImageProcessor +from einops import rearrange, repeat +# FFN +# from mamba_ssm import Mamba + + + +class ImgEmbContextResampler(nn.Module): + + def __init__( + self, + inner_dim=1280, + cross_attention_dim=1024, + expansion_factor=16, + **kwargs, + ): + super().__init__() + self.context_embedding = nn.Sequential( + nn.Linear(cross_attention_dim, inner_dim), + nn.SiLU(), + nn.Linear(inner_dim, cross_attention_dim * expansion_factor), + ) + self.expansion_factor = expansion_factor + self.cross_attention_dim = cross_attention_dim + + def forward(self, x, batch_size=0): + if x.ndim == 2: + x = rearrange(x, "(B F) C -> B F C", B=batch_size) + assert x.ndim == 3 + x = torch.mean(x, dim=1, keepdim=True) + x = self.context_embedding(x) + x = x.view(-1, self.expansion_factor, self.cross_attention_dim) + return x + + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + self.embedding_dim = -1 + self.num_tokens = -1 + + def encode(self, *args, **kwargs): + raise NotImplementedError + + + +class FrozenOpenCLIPImageEmbedder(AbstractEncoder): + """ + Uses the OpenCLIP vision transformer encoder for images + """ + + def __init__( + self, + arch="ViT-H-14", + version="laion2b_s32b_b79k", + device="cuda", + max_length=77, + freeze=True, + antialias=True, + ucg_rate=0.0, + unsqueeze_dim=False, + repeat_to_max_len=False, + num_image_crops=0, + output_tokens=False, + ): + super().__init__() + model, _, _ = open_clip.create_model_and_transforms( + arch, + device=torch.device("cpu"), + pretrained=version, + ) + del model.transformer + self.model = model + self.max_crops = num_image_crops + self.pad_to_max_len = self.max_crops > 0 + self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + + self.antialias = antialias + + self.register_buffer( + "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False + ) + self.register_buffer( + "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False + ) + self.ucg_rate = ucg_rate + self.unsqueeze_dim = unsqueeze_dim + self.stored_batch = None + self.model.visual.output_tokens = output_tokens + self.output_tokens = output_tokens + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize( + x, + (224, 224), + interpolation="bicubic", + align_corners=True, + antialias=self.antialias, + ) + x = (x + 1.0) / 2.0 + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, image, no_dropout=False): + z = self.encode_with_vision_transformer(image) + tokens = None + if self.output_tokens: + z, tokens = z[0], z[1] + z = z.to(image.dtype) + if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0): + z = ( + torch.bernoulli( + (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device) + )[:, None] + * z + ) + if tokens is not None: + tokens = ( + expand_dims_like( + torch.bernoulli( + (1.0 - self.ucg_rate) + * torch.ones(tokens.shape[0], device=tokens.device) + ), + tokens, + ) + * tokens + ) + if self.unsqueeze_dim: + z = z[:, None, :] + if self.output_tokens: + assert not self.repeat_to_max_len + assert not self.pad_to_max_len + return tokens, z + if self.repeat_to_max_len: + if z.dim() == 2: + z_ = z[:, None, :] + else: + z_ = z + return repeat(z_, "b 1 d -> b n d", n=self.max_length), z + elif self.pad_to_max_len: + assert z.dim() == 3 + z_pad = torch.cat( + ( + z, + torch.zeros( + z.shape[0], + self.max_length - z.shape[1], + z.shape[2], + device=z.device, + ), + ), + 1, + ) + return z_pad, z_pad[:, 0, ...] + return z + + def encode_with_vision_transformer(self, img): + # if self.max_crops > 0: + # img = self.preprocess_by_cropping(img) + if img.dim() == 5: + assert self.max_crops == img.shape[1] + img = rearrange(img, "b n c h w -> (b n) c h w") + img = self.preprocess(img) + if not self.output_tokens: + assert not self.model.visual.output_tokens + x = self.model.visual(img) + tokens = None + else: + assert self.model.visual.output_tokens + x, tokens = self.model.visual(img) + if self.max_crops > 0: + x = rearrange(x, "(b n) d -> b n d", n=self.max_crops) + # drop out between 0 and all along the sequence axis + x = ( + torch.bernoulli( + (1.0 - self.ucg_rate) + * torch.ones(x.shape[0], x.shape[1], 1, device=x.device) + ) + * x + ) + if tokens is not None: + tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops) + print( + f"You are running very experimental token-concat in {self.__class__.__name__}. " + f"Check what you are doing, and then remove this message." + ) + if self.output_tokens: + return x, tokens + return x + + def encode(self, text): + return self(text) \ No newline at end of file diff --git a/t2v_enhanced/model/diffusers_conditional/models/controlnet/mask_generator.py b/t2v_enhanced/model/diffusers_conditional/models/controlnet/mask_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..3333383b224264ff52e2c1c71c3a9178291d5b64 --- /dev/null +++ b/t2v_enhanced/model/diffusers_conditional/models/controlnet/mask_generator.py @@ -0,0 +1,27 @@ +from t2v_enhanced.model.pl_module_params_controlnet import AttentionMaskParams +import torch + + +class MaskGenerator(): + + def __init__(self, params: AttentionMaskParams, num_frame_conditioning, num_frames): + self.params = params + self.num_frame_conditioning = num_frame_conditioning + self.num_frames = num_frames + def get_mask(self, precision, device): + + params = self.params + if params.temporal_self_attention_only_on_conditioning: + with torch.no_grad(): + attention_mask = torch.zeros((1, self.num_frames, self.num_frames), dtype=torch.float16 if precision.startswith( + "16") else torch.float32, device=device) + for frame in range(self.num_frame_conditioning, self.num_frames): + attention_mask[:, frame, + self.num_frame_conditioning:] = float("-inf") + if params.temporal_self_attention_mask_included_itself: + attention_mask[:, frame, frame] = 0 + if params.temp_attend_on_uncond_include_past: + attention_mask[:, frame, :frame] = 0 + else: + attention_mask = None + return attention_mask diff --git a/t2v_enhanced/model/diffusers_conditional/models/controlnet/pipeline_text_to_video_w_controlnet_synth.py b/t2v_enhanced/model/diffusers_conditional/models/controlnet/pipeline_text_to_video_w_controlnet_synth.py new file mode 100644 index 0000000000000000000000000000000000000000..3a107321676a50f2cb9ae5bf185f9490d1e97e8f --- /dev/null +++ b/t2v_enhanced/model/diffusers_conditional/models/controlnet/pipeline_text_to_video_w_controlnet_synth.py @@ -0,0 +1,925 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import PIL.Image +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from diffusers.loaders import TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet3DConditionModel +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + PIL_INTERPOLATION, + is_accelerate_available, + is_accelerate_version, + logging, + replace_example_docstring, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.text_to_video_synthesis import TextToVideoSDPipelineOutput +from einops import rearrange + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import TextToVideoSDPipeline + >>> from diffusers.utils import export_to_video + + >>> pipe = TextToVideoSDPipeline.from_pretrained( + ... "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16" + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "Spiderman is surfing" + >>> video_frames = pipe(prompt).frames + >>> video_path = export_to_video(video_frames) + >>> video_path + ``` +""" + + +def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], output_type="list") -> List[np.ndarray]: + # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 + # reshape to ncfhw + mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) + std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) + # unnormalize back to [0,1] + video = video.mul_(std).add_(mean) + video.clamp_(0, 1) + # prepare the final outputs + i, c, f, h, w = video.shape + images = video.permute(2, 3, 0, 4, 1).reshape( + f, h, i * w, c + ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c) + if output_type == "list": + # prepare a list of indvidual (consecutive frames) + images = images.unbind(dim=0) + images = [(image.cpu().numpy() * 255).astype("uint8") + for image in images] # f h w c + elif output_type == "pt": + pass + return images + + +class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin): + r""" + Pipeline for text-to-video generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Same as Stable Diffusion 2. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet3DConditionModel`]): Conditional U-Net architecture to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet3DConditionModel, + controlnet, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** ( + len(self.vae.config.block_out_channels) - 1) + + def prepare_image( + self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance, cfg_text_image=False, + ): + if not isinstance(image, torch.Tensor): + if isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + images = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = image_.resize( + (width, height), resample=PIL_INTERPOLATION["lanczos"]) + image_ = np.array(image_) + image_ = image_[None, :] + images.append(image_) + + image = images + + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + image_vq_enc = self.vae.encode(rearrange( + image, "B F C W H -> (B F) C W H")).latent_dist.sample() * self.vae.config.scaling_factor + image_vq_enc = rearrange( + image_vq_enc, "(B F) C W H -> B F C W H", B=image_batch_size) + if do_classifier_free_guidance: + if cfg_text_image: + image = torch.cat([torch.zeros_like(image), image], dim=0) + else: + image = torch.cat([image] * 2) + # image_vq_enc = torch.cat([image_vq_enc] * 2) + + return image, image_vq_enc + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded + to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a + submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError( + "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + # otherwise we don't see the memory savings (but they probably exist) + torch.cuda.empty_cache() + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError( + "`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + # otherwise we don't see the memory savings (but they probably exist) + torch.cuda.empty_cache() + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook( + cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + img_cond: Optional[torch.FloatTensor] = None, + img_cond_unc: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer( + prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1: -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to( + dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_images_per_prompt, seq_len, -1) + max_length = prompt_embeds.shape[1] + if img_cond is not None: + if img_cond.ndim == 2: + img_cond = img_cond.unsqueeze(1) + prompt_embeds = torch.cat([prompt_embeds, img_cond], dim=1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt( + uncond_tokens, self.tokenizer) + + # max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1) + + if img_cond_unc is not None: + if img_cond_unc.ndim == 2: + img_cond_unc = img_cond_unc.unsqueeze(1) + negative_prompt_embeds = torch.cat( + [negative_prompt_embeds, img_cond_unc], dim=1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape( + batch_size * num_frames, channels, height, width) + + image = self.vae.decode(latents).sample + video = ( + image[None, :] + .reshape( + ( + batch_size, + num_frames, + -1, + ) + + image.shape[2:] + ) + .permute(0, 2, 1, 3, 4) + ) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature( + self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance( + callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if hasattr(self, "noise_generator"): + latents = self.noise_generator.sample_noise( + shape=shape, generator=generator, device=device, dtype=dtype) + elif latents is None: + latents = randn_tensor( + shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def set_noise_generator(self, noise_generator): + if noise_generator is not None and noise_generator.mode != "vanilla": + self.noise_generator = noise_generator + + def reset_noise_generator_state(self): + if hasattr(self, "noise_generator") and hasattr(self.noise_generator, "reset_noise"): + self.noise_generator.reset_noise_generator_state() + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + # the image input for the controlnet branch + image: Optional[torch.FloatTensor] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: int = 16, + num_inference_steps: int = 50, + guidance_scale: float = 9.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, + List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + callback: Optional[Callable[[ + int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + precision: str = "16", + mask_generator=None, + no_text_condition_control: bool = False, + weight_control_sample: float = 1.0, + use_controlnet_mask: bool = False, + skip_controlnet_branch: bool = False, + img_cond_resampler=None, + img_cond_encoder=None, + input_frames_conditioning=None, + cfg_text_image: bool = False, + use_of: bool = False, + ** kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, + usually at the expense of lower video quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generate video. Choose between `torch.FloatTensor` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated frames. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_images_per_prompt = 1 + controlnet_mask = None + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + # import pdb + # pdb.set_trace() + + if img_cond_resampler is not None and image is not None: + bsz = image.shape[0] + image_for_conditioniong = rearrange( + input_frames_conditioning, "B F C W H -> (B F) C W H") + image_enc = img_cond_encoder(image_for_conditioniong) + img_cond = img_cond_resampler(image_enc, batch_size=bsz) + image_enc_unc = img_cond_encoder( + torch.zeros_like(image_for_conditioniong)) + img_cond_unc = img_cond_resampler(image_enc_unc, batch_size=bsz) + else: + img_cond = None + img_cond_unc = None + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + img_cond=img_cond, + img_cond_unc=img_cond_unc + ) + skip_conditioning = image is None or skip_controlnet_branch + # import pdb + # pdb.set_trace() + if not skip_conditioning: + num_condition_frames = image.shape[1] + image, image_vq_enc = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + cfg_text_image=cfg_text_image, + ) + if len(image.shape) == 5: + image = rearrange(image, "B F C H W -> (B F) C H W") + if use_controlnet_mask: + # num_condition_frames = all possible frames, e.g. 16 + assert num_condition_frames == num_frames + image = rearrange( + image, "(B F) C H W -> B F C H W", F=num_condition_frames) + # image = torch.cat([image, image], dim=1) + controlnet_mask = torch.zeros( + (image.shape[0], num_frames), device=image.device, dtype=image.dtype) + # TODO HARDCODED number of frames! + controlnet_mask[:, :8] = 1.0 + image = rearrange(image, "B F C H W -> (B F) C H W") + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + of_channels = 2 if use_of else 0 + num_channels_ctrl = self.unet.config.in_channels + num_channels_latents = num_channels_ctrl + of_channels + if not skip_conditioning: + image_vq_enc = rearrange( + image_vq_enc, "B F C H W -> B C F H W ", F=num_condition_frames) + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + if self.unet.concat: + image_latents = self.vae.encode(rearrange( + image, "B F C W H -> (B F) C W H")).latent_dist.sample() * self.vae.config.scaling_factor + image_latents = rearrange( + image_latents, "(B F) C W H -> B C F W H", B=latents.shape[0]) + image_shape = image_latents.shape + image_shape = [ax_dim for ax_dim in image_shape] + image_shape[2] = 16-image_shape[2] + image_latents = torch.cat([image_latents, torch.zeros( + image_shape, dtype=image_latents.dtype, device=image_latents.device)], dim=2) + controlnet_mask = torch.zeros( + image_latents.shape, device=image_latents.device, dtype=image_latents.dtype) + controlnet_mask[:, :, :8] = 1.0 + image_latents = image_latents * controlnet_mask + # torch.cat([latents, image_latents, controlnet_mask[:, :1]], dim=1) + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - \ + num_inference_steps * self.scheduler.order + + if mask_generator is not None: + attention_mask = mask_generator.get_mask( + device=latents.device, precision=precision) + else: + attention_mask = None + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat( + [latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t) + + if self.unet.concat: + latent_model_input = torch.cat([latent_model_input, image_latents.repeat( + 2, 1, 1, 1, 1), controlnet_mask[:, :1].repeat(2, 1, 1, 1, 1)], dim=1) + if not skip_conditioning: + down_block_res_samples, mid_block_res_sample = self.controlnet( + latent_model_input[:, :num_channels_ctrl], + t, + encoder_hidden_states=prompt_embeds if (not no_text_condition_control) else torch.stack([ + prompt_embeds[0], prompt_embeds[0]]), + controlnet_cond=image, + attention_mask=attention_mask, + vq_gan=self.vae, + weight_control_sample=weight_control_sample, + return_dict=False, + controlnet_mask=controlnet_mask, + ) + else: + down_block_res_samples = None + mid_block_res_sample = None + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + down_block_additional_residuals=[ + sample.to(dtype=latent_model_input.dtype) for sample in down_block_res_samples + ] if down_block_res_samples is not None else None, + mid_block_additional_residual=mid_block_res_sample.to( + dtype=latent_model_input.dtype) if mid_block_res_sample is not None else None, + fps=None, + + ).sample + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk( + 2) + noise_pred = noise_pred_uncond + guidance_scale * \ + (noise_pred_text - noise_pred_uncond) + + # reshape latents + bsz, channel, frames, width, height = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape( + bsz * frames, channel, width, height) + noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape( + bsz * frames, channel, width, height) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_step = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs) + latents = scheduler_step.prev_sample + + # reshape latents back + latents = latents[None, :].reshape( + bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + latents_video = latents[:, :num_channels_ctrl] + if of_channels > 0: + latents_of = latents[:, num_channels_ctrl:] + latents_of = rearrange(latents_of, "B C F W H -> (B F) C W H") + video_tensor = self.decode_latents(latents_video) + + if output_type == "pt": + video = video_tensor + elif output_type == "pt_t2v": + video = tensor2vid(video_tensor, output_type="pt") + video = rearrange(video, "f h w c -> f c h w") + elif output_type == "concat_image": + image_video = image.unsqueeze(2)[0:1].repeat([1, 1, 24, 1, 1]) + video_tensor_concat = torch.concat( + [image_video, video_tensor], dim=4) + video = tensor2vid(video_tensor_concat) + else: + video = tensor2vid(video_tensor) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + if of_channels == 0: + return video + else: + return video, latents_of + + return TextToVideoSDPipelineOutput(frames=video) diff --git a/t2v_enhanced/model/diffusers_conditional/models/controlnet/processor.py b/t2v_enhanced/model/diffusers_conditional/models/controlnet/processor.py new file mode 100644 index 0000000000000000000000000000000000000000..f80d8bb5ff2d285167a7bc899ca108114c01ce88 --- /dev/null +++ b/t2v_enhanced/model/diffusers_conditional/models/controlnet/processor.py @@ -0,0 +1,240 @@ +from einops import repeat, rearrange +from typing import Callable, Optional, Union +from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention_processor import Attention +# from t2v_enhanced.model.diffusers_conditional.controldiffusers.models.attention import Attention +from diffusers.utils.import_utils import is_xformers_available +from t2v_enhanced.model.pl_module_params_controlnet import AttentionMaskParams +import torch +import torch.nn.functional as F +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +def set_use_memory_efficient_attention_xformers( + model, num_frame_conditioning: int, num_frames: int, attention_mask_params: AttentionMaskParams, valid: bool = True, attention_op: Optional[Callable] = None +) -> None: + # Recursively walk through all the children. + # Any children which exposes the set_use_memory_efficient_attention_xformers method + # gets the message + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, "set_processor"): + + module.set_processor(XFormersAttnProcessor(attention_op=attention_op, + num_frame_conditioning=num_frame_conditioning, + num_frames=num_frames, + attention_mask_params=attention_mask_params,) + ) + + for child in module.children(): + fn_recursive_set_mem_eff(child) + + for module in model.children(): + if isinstance(module, torch.nn.Module): + fn_recursive_set_mem_eff(module) + + +class XFormersAttnProcessor: + def __init__(self, + attention_mask_params: AttentionMaskParams, + attention_op: Optional[Callable] = None, + num_frame_conditioning: int = None, + num_frames: int = None, + use_image_embedding: bool = False, + ): + self.attention_op = attention_op + self.num_frame_conditioning = num_frame_conditioning + self.num_frames = num_frames + self.temp_attend_on_neighborhood_of_condition_frames = attention_mask_params.temp_attend_on_neighborhood_of_condition_frames + self.spatial_attend_on_condition_frames = attention_mask_params.spatial_attend_on_condition_frames + self.use_image_embedding = use_image_embedding + + def __call__(self, attn: Attention, hidden_states, hidden_state_height=None, hidden_state_width=None, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + key_img = None + value_img = None + hidden_states_img = None + if attention_mask is not None: + attention_mask = repeat( + attention_mask, "1 F D -> B F D", B=batch_size) + + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size) + + query = attn.to_q(hidden_states) + + is_cross_attention = encoder_hidden_states is not None + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states) + default_attention = not hasattr(attn, "is_spatial_attention") + if default_attention: + assert not self.temp_attend_on_neighborhood_of_condition_frames, "special attention must be implemented with new interface" + assert not self.spatial_attend_on_condition_frames, "special attention must be implemented with new interface" + is_spatial_attention = attn.is_spatial_attention if hasattr( + attn, "is_spatial_attention") else False + use_image_embedding = attn.use_image_embedding if hasattr( + attn, "use_image_embedding") else False + + if is_spatial_attention and use_image_embedding and attn.cross_attention_mode: + assert not self.spatial_attend_on_condition_frames, "Not implemented together with image embedding" + + alpha = attn.alpha + encoder_hidden_states_txt = encoder_hidden_states[:, :77, :] + + encoder_hidden_states_mixed = attn.conv(encoder_hidden_states) + encoder_hidden_states_mixed = attn.conv_ln(encoder_hidden_states_mixed) + encoder_hidden_states = encoder_hidden_states_txt + encoder_hidden_states_mixed * F.silu(alpha) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + else: + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + + + + if not default_attention and not is_spatial_attention and self.temp_attend_on_neighborhood_of_condition_frames and not attn.cross_attention_mode: + # normal attention + query_condition = query[:, :self.num_frame_conditioning] + query_condition = attn.head_to_batch_dim( + query_condition).contiguous() + key_condition = key + value_condition = value + key_condition = attn.head_to_batch_dim(key_condition).contiguous() + value_condition = attn.head_to_batch_dim( + value_condition).contiguous() + hidden_states_condition = xformers.ops.memory_efficient_attention( + query_condition, key_condition, value_condition, attn_bias=None, op=self.attention_op, scale=attn.scale + ) + hidden_states_condition = hidden_states_condition.to(query.dtype) + hidden_states_condition = attn.batch_to_head_dim( + hidden_states_condition) + # + query_uncondition = query[:, self.num_frame_conditioning:] + + key = key[:, :self.num_frame_conditioning] + value = value[:, :self.num_frame_conditioning] + key = rearrange(key, "(B W H) F C -> B W H F C", + H=hidden_state_height, W=hidden_state_width) + value = rearrange(value, "(B W H) F C -> B W H F C", + H=hidden_state_height, W=hidden_state_width) + + keys = [] + values = [] + for shifts_width in [-1, 0, 1]: + for shifts_height in [-1, 0, 1]: + keys.append(torch.roll(key, shifts=( + shifts_width, shifts_height), dims=(1, 2))) + values.append(torch.roll(value, shifts=( + shifts_width, shifts_height), dims=(1, 2))) + key = rearrange(torch.cat(keys, dim=3), "B W H F C -> (B W H) F C") + value = rearrange(torch.cat(values, dim=3), + 'B W H F C -> (B W H) F C') + + query = attn.head_to_batch_dim(query_uncondition).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = torch.cat( + [hidden_states_condition, hidden_states], dim=1) + elif not default_attention and is_spatial_attention and self.spatial_attend_on_condition_frames and not attn.cross_attention_mode: + # (B F) W H C -> B F W H C + query_condition = rearrange( + query, "(B F) S C -> B F S C", F=self.num_frames) + query_condition = query_condition[:, :self.num_frame_conditioning] + query_condition = rearrange( + query_condition, "B F S C -> (B F) S C") + query_condition = attn.head_to_batch_dim( + query_condition).contiguous() + + key_condition = rearrange( + key, "(B F) S C -> B F S C", F=self.num_frames) + key_condition = key_condition[:, :self.num_frame_conditioning] + key_condition = rearrange(key_condition, "B F S C -> (B F) S C") + + value_condition = rearrange( + value, "(B F) S C -> B F S C", F=self.num_frames) + value_condition = value_condition[:, :self.num_frame_conditioning] + value_condition = rearrange( + value_condition, "B F S C -> (B F) S C") + + key_condition = attn.head_to_batch_dim(key_condition).contiguous() + value_condition = attn.head_to_batch_dim( + value_condition).contiguous() + hidden_states_condition = xformers.ops.memory_efficient_attention( + query_condition, key_condition, value_condition, attn_bias=None, op=self.attention_op, scale=attn.scale + ) + hidden_states_condition = hidden_states_condition.to(query.dtype) + hidden_states_condition = attn.batch_to_head_dim( + hidden_states_condition) + + query_uncondition = rearrange( + query, "(B F) S C -> B F S C", F=self.num_frames) + query_uncondition = query_uncondition[:, + self.num_frame_conditioning:] + key_uncondition = rearrange( + key, "(B F) S C -> B F S C", F=self.num_frames) + value_uncondition = rearrange( + value, "(B F) S C -> B F S C", F=self.num_frames) + key_uncondition = key_uncondition[:, + self.num_frame_conditioning-1, None] + value_uncondition = value_uncondition[:, + self.num_frame_conditioning-1, None] + # if self.trainer.training: + # import pdb + # pdb.set_trace() + # print("now") + query_uncondition = rearrange( + query_uncondition, "B F S C -> (B F) S C") + key_uncondition = repeat(rearrange( + key_uncondition, "B F S C -> B (F S) C"), "B T C -> (B F) T C", F=self.num_frames-self.num_frame_conditioning) + value_uncondition = repeat(rearrange( + value_uncondition, "B F S C -> B (F S) C"), "B T C -> (B F) T C", F=self.num_frames-self.num_frame_conditioning) + query_uncondition = attn.head_to_batch_dim( + query_uncondition).contiguous() + key_uncondition = attn.head_to_batch_dim( + key_uncondition).contiguous() + value_uncondition = attn.head_to_batch_dim( + value_uncondition).contiguous() + hidden_states_uncondition = xformers.ops.memory_efficient_attention( + query_uncondition, key_uncondition, value_uncondition, attn_bias=None, op=self.attention_op, scale=attn.scale + ) + hidden_states_uncondition = hidden_states_uncondition.to( + query.dtype) + hidden_states_uncondition = attn.batch_to_head_dim( + hidden_states_uncondition) + hidden_states = torch.cat([rearrange(hidden_states_condition, "(B F) S C -> B F S C", F=self.num_frame_conditioning), rearrange( + hidden_states_uncondition, "(B F) S C -> B F S C", F=self.num_frames-self.num_frame_conditioning)], dim=1) + hidden_states = rearrange(hidden_states, "B F S C -> (B F) S C") + else: + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale + ) + + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + return hidden_states diff --git a/t2v_enhanced/model/diffusers_conditional/models/controlnet/transformer_2d.py b/t2v_enhanced/model/diffusers_conditional/models/controlnet/transformer_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..f57509308cf25c160e14e7f6530a678b1cfa2da6 --- /dev/null +++ b/t2v_enhanced/model/diffusers_conditional/models/controlnet/transformer_2d.py @@ -0,0 +1,333 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import ImagePositionalEmbeddings +from diffusers.utils import BaseOutput, deprecate +from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention import BasicTransformerBlock +from diffusers.models.embeddings import PatchEmbed +from diffusers.models.modeling_utils import ModelMixin + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions + for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class Transformer2DModel(ModelMixin, ConfigMixin): + """ + Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual + embeddings) inputs. + + When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard + transformer action. Finally, reshape to image. + + When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional + embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict + classes of unnoised image. + + Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised + image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + use_image_embedding: bool = False, + unet_params=None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = ( + in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", + deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + is_spatial_attention=True, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = nn.Linear(inner_dim, in_channels) + else: + self.proj_out = nn.Conv2d( + inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches: + self.norm_out = nn.LayerNorm( + inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear( + inner_dim, patch_size * patch_size * self.out_channels) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + return_dict: bool = True, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels + conditioning. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: + [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # 1. Input + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute( + 0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute( + 0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + hidden_states = self.pos_embed(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = hidden_states.reshape( + batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape( + batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + elif self.is_input_patches: + # TODO: cleanup! + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1( + F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out( + hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, + self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * + self.patch_size, width * self.patch_size) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/t2v_enhanced/model/diffusers_conditional/models/controlnet/transformer_temporal.py b/t2v_enhanced/model/diffusers_conditional/models/controlnet/transformer_temporal.py new file mode 100644 index 0000000000000000000000000000000000000000..b5e7b3d8979e75c0f624e13802f6497a88f89b83 --- /dev/null +++ b/t2v_enhanced/model/diffusers_conditional/models/controlnet/transformer_temporal.py @@ -0,0 +1,190 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput +# from diffusers.models.attention import BasicTransformerBlock +# from t2v_enhanced.model.diffusers_conditional.models.attention import BasicTransformerBlock +from diffusers.models.modeling_utils import ModelMixin +from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention import BasicTransformerBlock + + +@dataclass +class TransformerTemporalModelOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`) + Hidden states conditioned on `encoder_hidden_states` input. + """ + + sample: torch.FloatTensor + + +class TransformerTemporalModel(ModelMixin, ConfigMixin): + """ + Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + double_self_attention (`bool`, *optional*): + Configure if each TransformerBlock should contain two self-attention layers + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + double_self_attention: bool = True, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + double_self_attention=double_self_attention, + norm_elementwise_affine=norm_elementwise_affine, + is_spatial_attention=False, + ) + for d in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + num_frames=1, + cross_attention_kwargs=None, + return_dict: bool = True, + attention_mask=None, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels + conditioning. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.transformer_2d.TransformerTemporalModelOutput`] or `tuple`: + [`~models.transformer_2d.TransformerTemporalModelOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + """ + # 1. Input + batch_frames, channel, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + residual = hidden_states + + hidden_states = hidden_states[None, :].reshape( + batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape( + batch_size * height * width, num_frames, channel) + + hidden_states = self.proj_in(hidden_states) + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + cross_attention_kwargs["hidden_state_height"] = height + cross_attention_kwargs["hidden_state_width"] = width + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + attention_mask=attention_mask, + encoder_attention_mask=attention_mask, + ) + + # 3. Output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states[None, None, :] + .reshape(batch_size, height, width, channel, num_frames) + .permute(0, 3, 4, 1, 2) + .contiguous() + ) + hidden_states = hidden_states.reshape( + batch_frames, channel, height, width) + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return TransformerTemporalModelOutput(sample=output) diff --git a/t2v_enhanced/model/diffusers_conditional/models/controlnet/transformer_temporal_crossattention.py b/t2v_enhanced/model/diffusers_conditional/models/controlnet/transformer_temporal_crossattention.py new file mode 100644 index 0000000000000000000000000000000000000000..53361f056cf4c2e110792b355e4512bf2dce15cf --- /dev/null +++ b/t2v_enhanced/model/diffusers_conditional/models/controlnet/transformer_temporal_crossattention.py @@ -0,0 +1,182 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput + +from diffusers.models.modeling_utils import ModelMixin + + +@dataclass +class TransformerTemporalModelOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`) + Hidden states conditioned on `encoder_hidden_states` input. + """ + + sample: torch.FloatTensor + + +class TransformerTemporalModel(ModelMixin, ConfigMixin): + """ + Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + double_self_attention (`bool`, *optional*): + Configure if each TransformerBlock should contain two self-attention layers + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + double_self_attention: bool = True, + ): + super().__init__() + + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + double_self_attention=double_self_attention, + norm_elementwise_affine=norm_elementwise_affine, + only_cross_attention=True, + ) + for d in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + num_frames=1, + cross_attention_kwargs=None, + return_dict: bool = True, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels + conditioning. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.transformer_2d.TransformerTemporalModelOutput`] or `tuple`: + [`~models.transformer_2d.TransformerTemporalModelOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + """ + # 1. Input + batch_frames, channel, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + residual = hidden_states + + hidden_states = hidden_states[None, :].reshape( + batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape( + batch_size * height * width, num_frames, channel) + + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states[None, None, :] + .reshape(batch_size, height, width, channel, num_frames) + .permute(0, 3, 4, 1, 2) + .contiguous() + ) + hidden_states = hidden_states.reshape( + batch_frames, channel, height, width) + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return TransformerTemporalModelOutput(sample=output) diff --git a/t2v_enhanced/model/diffusers_conditional/models/controlnet/unet_3d_blocks.py b/t2v_enhanced/model/diffusers_conditional/models/controlnet/unet_3d_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..e54475b95c5434e7634cabff312310497a7cefc7 --- /dev/null +++ b/t2v_enhanced/model/diffusers_conditional/models/controlnet/unet_3d_blocks.py @@ -0,0 +1,930 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.utils.checkpoint as checkpoint +from torch import nn +from diffusers.models.resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D +# from diffusers.models.transformer_2d import Transformer2DModel +from t2v_enhanced.model.diffusers_conditional.models.controlnet.transformer_2d import Transformer2DModel +# from diffusers.models.transformer_temporal import TransformerTemporalModel +from t2v_enhanced.model.diffusers_conditional.models.controlnet.transformer_temporal import TransformerTemporalModel + + +# Assign gradient checkpoint function to simple variable for readability. +g_c = checkpoint.checkpoint + + +def is_video(num_frames, only_video=True): + if num_frames == 1 and not only_video: + return False + return num_frames > 1 + + +def custom_checkpoint(module, mode=None): + if mode == None: + raise ValueError('Mode for gradient checkpointing cannot be none.') + + custom_forward = None + + if mode == 'resnet': + def custom_forward(hidden_states, temb): + inputs = module(hidden_states, temb) + return inputs + + if mode == 'attn': + def custom_forward( + hidden_states, + encoder_hidden_states=None, + cross_attention_kwargs=None, + attention_mask=None, + ): + inputs = module( + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + attention_mask + ) + return inputs.sample + + if mode == 'temp': + # If inputs are not None, we can assume that this was a single image. + # Otherwise, do temporal convolutions / attention. + def custom_forward(hidden_states, num_frames=None): + if not is_video(num_frames): + return hidden_states + else: + inputs = module( + hidden_states, + num_frames=num_frames + ) + if isinstance(module, TransformerTemporalModel): + return inputs.sample + else: + return inputs + + return custom_forward + + +def transformer_g_c(transformer, sample, num_frames): + sample = g_c(custom_checkpoint(transformer, mode='temp'), + sample, num_frames, use_reentrant=False, + ) + return sample + + +def cross_attn_g_c( + attn, + temp_attn, + resnet, + temp_conv, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + temb, + num_frames, + inverse_temp=False, + attention_mask=None, +): + + def ordered_g_c(idx): + + # Self and CrossAttention + if idx == 0: + return g_c(custom_checkpoint(attn, mode='attn'), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + attention_mask, + use_reentrant=False + ) + + # Temporal Self and CrossAttention + if idx == 1: + return g_c(custom_checkpoint(temp_attn, mode='temp'), + hidden_states, + num_frames, + use_reentrant=False + ) + + # Resnets + if idx == 2: + return g_c(custom_checkpoint(resnet, mode='resnet'), + hidden_states, + temb, + use_reentrant=False + ) + + # Temporal Convolutions + if idx == 3: + return g_c(custom_checkpoint(temp_conv, mode='temp'), + hidden_states, + num_frames, + use_reentrant=False + ) + + # Here we call the function depending on the order in which they are called. + # For some layers, the orders are different, so we access the appropriate one by index. + + if not inverse_temp: + for idx in [0, 1, 2, 3]: + hidden_states = ordered_g_c(idx) + else: + for idx in [2, 3, 0, 1]: + hidden_states = ordered_g_c(idx) + + return hidden_states + + +def up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames): + hidden_states = g_c(custom_checkpoint(resnet, mode='resnet'), + hidden_states, + temb, + use_reentrant=False + ) + hidden_states = g_c(custom_checkpoint(temp_conv, mode='temp'), + hidden_states, + num_frames, + use_reentrant=False + ) + return hidden_states + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=True, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_image_embedding=False, + unet_params=None, +): + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=True, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_image_embedding=False, + unet_params=None, +): + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=True, + upcast_attention=False, + use_image_embedding=False, + unet_params=None, + ): + super().__init__() + self.gradient_checkpointing = False + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min( + in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + temp_convs = [ + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1 + ) + ] + attentions = [] + temp_attentions = [] + + for _ in range(num_layers): + attentions.append( + Transformer2DModel( + in_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + in_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1 + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + if self.gradient_checkpointing: + hidden_states = up_down_g_c( + self.resnets[0], + self.temp_convs[0], + hidden_states, + temb, + num_frames + ) + else: + hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.temp_convs[0]( + hidden_states, num_frames=num_frames) + + for attn, temp_attn, resnet, temp_conv in zip( + self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:] + ): + if self.gradient_checkpointing: + hidden_states = cross_attn_g_c( + attn, + temp_attn, + resnet, + temp_conv, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + temb, + num_frames + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + if num_frames > 1: + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, attention_mask=attention_mask, + + ).sample + + hidden_states = resnet(hidden_states, temb) + + if num_frames > 1: + hidden_states = temp_conv( + hidden_states, num_frames=num_frames) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_image_embedding=False, + unet_params=None, + ): + super().__init__() + resnets = [] + attentions = [] + temp_attentions = [] + temp_convs = [] + + self.gradient_checkpointing = False + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1 + ) + ) + attentions.append( + Transformer2DModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + # TODO(Patrick, William) - attention mask is not used + output_states = () + layer_idx = 0 + + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + if self.gradient_checkpointing: + hidden_states = cross_attn_g_c( + attn, + temp_attn, + resnet, + temp_conv, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + temb, + num_frames, + inverse_temp=True + ) + else: + hidden_states = resnet(hidden_states, temb) + if num_frames > 1: + hidden_states = temp_conv( + hidden_states, num_frames=num_frames) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + if num_frames > 1: + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, attention_mask=attention_mask, + ).sample + layer_idx += 1 + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + temp_convs = [] + + self.gradient_checkpointing = False + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1 + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None, num_frames=1): + output_states = () + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + if self.gradient_checkpointing: + hidden_states = up_down_g_c( + resnet, temp_conv, hidden_states, temb, num_frames) + else: + hidden_states = resnet(hidden_states, temb) + + if num_frames > 1: + hidden_states = temp_conv( + hidden_states, num_frames=num_frames) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_image_embedding=False, + unet_params=None, + ): + super().__init__() + resnets = [] + temp_convs = [] + attentions = [] + temp_attentions = [] + + self.gradient_checkpointing = False + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if ( + i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1 + ) + ) + attentions.append( + Transformer2DModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + # TODO(Patrick, William) - attention mask is not used + output_states = () + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat( + [hidden_states, res_hidden_states], dim=1) + + if self.gradient_checkpointing: + hidden_states = cross_attn_g_c( + attn, + temp_attn, + resnet, + temp_conv, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + temb, + num_frames, + inverse_temp=True + ) + else: + hidden_states = resnet(hidden_states, temb) + + if num_frames > 1: + hidden_states = temp_conv( + hidden_states, num_frames=num_frames) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + if num_frames > 1: + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, + attention_mask=attention_mask, + ).sample + output_states += (hidden_states,) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + output_states += (hidden_states,) + + return hidden_states, output_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + temp_convs = [] + self.gradient_checkpointing = False + for i in range(num_layers): + res_skip_channels = in_channels if ( + i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1 + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1): + output_states = () + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat( + [hidden_states, res_hidden_states], dim=1) + + if self.gradient_checkpointing: + hidden_states = up_down_g_c( + resnet, temp_conv, hidden_states, temb, num_frames) + else: + hidden_states = resnet(hidden_states, temb) + + if num_frames > 1: + hidden_states = temp_conv( + hidden_states, num_frames=num_frames) + output_states += (hidden_states,) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + output_states += (hidden_states,) + + return hidden_states, output_states diff --git a/t2v_enhanced/model/diffusers_conditional/models/controlnet/unet_3d_condition.py b/t2v_enhanced/model/diffusers_conditional/models/controlnet/unet_3d_condition.py new file mode 100644 index 0000000000000000000000000000000000000000..6f0bd7724750e9c30e6968af1e9f77ca004e65a2 --- /dev/null +++ b/t2v_enhanced/model/diffusers_conditional/models/controlnet/unet_3d_condition.py @@ -0,0 +1,635 @@ +# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. +# Copyright 2023 The ModelScope Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin +# from diffusers.models.transformer_temporal import TransformerTemporalModel +from t2v_enhanced.model.diffusers_conditional.models.controlnet.transformer_temporal import TransformerTemporalModel +from .unet_3d_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + transformer_g_c +) +from t2v_enhanced.model.diffusers_conditional.models.controlnet.conditioning import ConditionalModel +from einops import rearrange +from t2v_enhanced.model.layers.conv_channel_extension import Conv2D_ExtendedChannels +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): + Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin): + r""" + UNet3DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep + and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, it will skip the normalization and activation layers in post-processing + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + up_block_types: Tuple[str] = ( + "UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1024, + attention_head_dim: Union[int, Tuple[int]] = 64, + merging_mode: str = "addition", + use_image_embedding: bool = False, + use_fps_conditioning: bool = False, + unet_params=None, + ): + super().__init__() + channel_expansion = unet_params.use_of + self.concat = unet_params.concat + self.use_image_tokens = unet_params.use_image_tokens_main + self.image_encoder_name = type(unet_params.image_encoder).__name__ + self.use_image_embedding = use_image_embedding + self.sample_size = sample_size + self.gradient_checkpointing = False + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_kernel = 3 + conv_out_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + ''' + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + ''' + self.conv_in = Conv2D_ExtendedChannels( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding, in_channel_extension=5 if self.concat else 0, + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], True, 0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + self.use_fps_conditioning = use_fps_conditioning + if use_fps_conditioning: + fps_embed_dim = block_out_channels[0] * 4 + fps_input_dim = block_out_channels[0] + self.fps_embedding = TimestepEmbedding( + fps_input_dim, fps_embed_dim, act_fn=act_fn) + self.fps_proj = Timesteps(block_out_channels[0], True, 0) + + self.transformer_in = TransformerTemporalModel( + num_attention_heads=8, + attention_head_dim=attention_head_dim, + in_channels=block_out_channels[0], + num_layers=1, + ) + + # class embedding + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + self.merging_mode = merging_mode + print("self.merging_mode", self.merging_mode) + if self.merging_mode.startswith("attention"): + self.cross_attention_merger_down_blocks = nn.ModuleList([]) + self.cross_attention_merger_mid_block = nn.ModuleList([]) + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=False, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) + self.down_blocks.append(down_block) + + if self.merging_mode.startswith("attention"): + for idx in range(3): + self.cross_attention_merger_down_blocks.append(ConditionalModel( + input_channels=input_channel if idx == 0 else output_channel, conditional_model=self.merging_mode.split("attention_")[1])) + + # mid + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) + if self.merging_mode.startswith("attention"): + self.cross_attention_merger_mid_block = ConditionalModel( + input_channels=block_out_channels[-1], conditional_model=self.merging_mode.split("attention_")[1]) + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min( + i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=False, + use_image_embedding=use_image_embedding, + unet_params=unet_params, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + self.conv_act = nn.SiLU() + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + ''' + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + ''' + self.conv_out = Conv2D_ExtendedChannels( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding, out_channel_extension=2 if channel_expansion else 0, + ) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * \ + [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError( + f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, value=False): + self.gradient_checkpointing = value + self.mid_block.gradient_checkpointing = value + for module in self.down_blocks + self.up_blocks: + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + fps: Optional[torch.Tensor] = None, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, num_frames, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet3DConditionOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Returns: + [`~models.unet_2d_condition.UNet3DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info( + "Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + ''' + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + ''' + debug = False + if self.use_fps_conditioning: + + if torch.is_tensor(fps): + assert (fps > -1).all(), "FPS not set" + if len(fps.shape) == 0: + fps = fps[None].to(sample.device) + else: + assert (fps > -1), "FPS not set" + is_mps = sample.device.type == "mps" + if isinstance(fps, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + fps = torch.tensor([fps], dtype=dtype, device=sample.device) + fps = fps.expand(sample.shape[0]) + fps_proj = self.fps_proj(fps) + fps_proj = fps_proj.to(dtype=self.dtype) + fps_emb = self.fps_embedding(fps_proj) + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor( + [timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + num_frames = sample.shape[2] + timesteps = timesteps.expand(sample.shape[0]) + batch_size = sample.shape[0] + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + + emb = emb.repeat_interleave(repeats=num_frames, dim=0) + if self.use_fps_conditioning: + fps_emb = fps_emb.repeat_interleave(repeats=num_frames, dim=0) + emb = emb + fps_emb + + if not self.use_image_tokens and encoder_hidden_states.shape[1] > 77: + encoder_hidden_states = encoder_hidden_states[:, :77] + # print(f"MAIN with tokens = {encoder_hidden_states.shape[1]}") + if encoder_hidden_states.shape[1] > 77: + # assert ( + # encoder_hidden_states.shape[1]-77) % num_frames == 0, f"Encoder shape {encoder_hidden_states.shape}. Num frames = {num_frames}" + context_text, context_img = encoder_hidden_states[:, + :77, :], encoder_hidden_states[:, 77:, :] + context_text = context_text.repeat_interleave( + repeats=num_frames, dim=0) + + if self.image_encoder_name == "FrozenOpenCLIPImageEmbedder": + context_img = context_img.repeat_interleave( + repeats=num_frames, dim=0) + else: + context_img = rearrange( + context_img, 'b (t l) c -> (b t) l c', t=num_frames) + + encoder_hidden_states = torch.cat( + [context_text, context_img], dim=1) + else: + encoder_hidden_states = encoder_hidden_states.repeat_interleave( + repeats=num_frames, dim=0) + + # 2. pre-process + sample = sample.permute(0, 2, 1, 3, 4).reshape( + (sample.shape[0] * num_frames, -1) + sample.shape[3:]) + sample = self.conv_in(sample) + + if num_frames > 1: + if self.gradient_checkpointing: + sample = transformer_g_c( + self.transformer_in, sample, num_frames) + else: + sample = self.transformer_in( + sample, num_frames=num_frames, attention_mask=attention_mask).sample + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block( + hidden_states=sample, temb=emb, num_frames=num_frames) + + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + if self.merging_mode == "addition": + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + elif self.merging_mode.startswith("attention"): + for down_block_res_sample, down_block_additional_residual, merger in zip( + down_block_res_samples, down_block_additional_residuals, self.cross_attention_merger_down_blocks + ): + + down_block_res_sample = merger( + rearrange(down_block_res_sample, "(B F) C H W -> B F C H W", B=batch_size), rearrange(down_block_additional_residual, "(B F) C H W -> B F C H W", B=batch_size)) + down_block_res_sample = rearrange( + down_block_res_sample, "B F C H W -> (B F) C H W") + new_down_block_res_samples += (down_block_res_sample,) + elif self.merging_mode == "overwrite": + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if mid_block_additional_residual is not None: + if self.merging_mode == "addition": + sample = sample + mid_block_additional_residual + elif self.merging_mode == "overwrite": + sample = sample + mid_block_additional_residual + elif self.merging_mode.startswith("attention"): + sample = self.cross_attention_merger_mid_block( + rearrange(sample, "(B F) C H W -> B F C H W", B=batch_size), rearrange(mid_block_additional_residual, "(B F) C H W -> B F C H W", B=batch_size)) + sample = rearrange(sample, "B F C H W -> (B F) C H W") + + if debug: + upblockout = (sample,) + # 5. up + # import pdb + # pdb.set_trace() + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets):] + down_block_res_samples = down_block_res_samples[: -len( + upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample, output_states = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, output_states = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + num_frames=num_frames, + ) + if debug: + upblockout += output_states + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = self.conv_out(sample) + + # reshape to (batch, channel, framerate, width, height) + sample = sample[None, :].reshape( + (-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) diff --git a/t2v_enhanced/model/flags.py b/t2v_enhanced/model/flags.py new file mode 100644 index 0000000000000000000000000000000000000000..ff4edac51898752e3a1069eb93ad23ea8b7c7b0a --- /dev/null +++ b/t2v_enhanced/model/flags.py @@ -0,0 +1 @@ +TORCH_DISTRIBUTED_DEBUG = DETAIL diff --git a/t2v_enhanced/model/layers/conv_channel_extension.py b/t2v_enhanced/model/layers/conv_channel_extension.py new file mode 100644 index 0000000000000000000000000000000000000000..bed009c21f5958758d93af9af397754b9fab7a6f --- /dev/null +++ b/t2v_enhanced/model/layers/conv_channel_extension.py @@ -0,0 +1,143 @@ +import torch +import torch.nn as nn +from typing import Union +from torch.nn.common_types import _size_2_t + + +class Conv2D_SubChannels(nn.Conv2d): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', + device=None, + dtype=None, + ) -> None: + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, device, dtype) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + + if prefix+"weight" in state_dict and ((state_dict[prefix+"weight"].shape[0] > self.out_channels) or (state_dict[prefix+"weight"].shape[1] > self.in_channels)): + print( + f"Model checkpoint has too many channels. Excluding channels of convolution {prefix}.") + if self.bias is not None: + bias = state_dict[prefix+"bias"][:self.out_channels] + state_dict[prefix+"bias"] = bias + del bias + + weight = state_dict[prefix+"weight"] + state_dict[prefix+"weight"] = weight[:self.out_channels, + :self.in_channels] + del weight + + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + +class Conv2D_ExtendedChannels(nn.Conv2d): + + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', + device=None, + dtype=None, + in_channel_extension: int = 0, + out_channel_extension: int = 0, + ) -> None: + super().__init__(in_channels+in_channel_extension, out_channels+out_channel_extension, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, device, dtype) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + print(f"Call extend channel loader with {prefix}") + if prefix+"weight" in state_dict and (state_dict[prefix+"weight"].shape[0] < self.out_channels or state_dict[prefix+"weight"].shape[1] < self.in_channels): + print( + f"Model checkpoint has insufficient channels. Extending channels of convolution {prefix} by adding zeros.") + if self.bias is not None: + bias = state_dict[prefix+"bias"] + state_dict[prefix+"bias"] = torch.cat( + [bias, torch.zeros(self.out_channels-len(bias), dtype=bias.dtype, layout=bias.layout, device=bias.device)]) + del bias + + weight = state_dict[prefix+"weight"] + extended_weight = torch.zeros(self.out_channels, self.in_channels, + weight.shape[2], weight.shape[3], device=weight.device, dtype=weight.dtype, layout=weight.layout) + extended_weight[:weight.shape[0], :weight.shape[1]] = weight + state_dict[prefix+"weight"] = extended_weight + del extended_weight + del weight + + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + +if __name__ == "__main__": + class MyModel(nn.Module): + + def __init__(self, conv_type: str, c_in, c_out, in_extension, out_extension) -> None: + super().__init__() + + if not conv_type == "normal": + + self.conv1 = Conv2D_ExtendedChannels( + c_in, c_out, 3, padding=1, in_channel_extension=in_extension, out_channel_extension=out_extension, bias=True) + + else: + self.conv1 = nn.Conv2d(c_in, c_out, 3, padding=1, bias=True) + + def forward(self, x): + return self.conv1(x) + + c_in = 9 + c_out = 12 + c_in_ext = 0 + c_out_ext = 3 + model = MyModel("normal", c_in, c_out, c_in_ext, c_out_ext) + + input = torch.randn((4, c_in+c_in_ext, 128, 128)) + out_normal = model(input[:, :c_in]) + torch.save(model.state_dict(), "model_dummy.py") + + model_2 = MyModel("special", c_in, c_out, c_in_ext, c_out_ext) + model_2.load_state_dict(torch.load("model_dummy.py")) + out_model_2 = model_2(input) + out_special = out_model_2[:, :c_out] + + out_new = out_model_2[:, c_out:] + model_3 = MyModel("special", c_in, c_out, c_in_ext, c_out_ext) + model_3.load_state_dict(model_2.state_dict()) + # out_model_2 = model_2(input) + # out_special = out_model_2[:, :c_out] + + print( + f"Difference: Forward pass with extended convolution minus initial convolution: {(out_normal-out_special).abs().max()}") + + print(f"Compared tensors with shape: ", + out_normal.shape, out_special.shape) + + if model_3.conv1.bias is not None: + criterion = nn.MSELoss() + + before_opt = model_3.conv1.bias.detach().clone() + target = torch.ones_like(out_model_2) + optimizer = torch.optim.SGD( + model_3.parameters(), lr=0.01, momentum=0.9) + for iter in range(10): + optimizer.zero_grad() + out = model_3(input) + loss = criterion(out, target) + loss.backward() + optimizer.step() + print( + f"Weights before and after are the same? {before_opt[c_out:].detach()} | {model_3.conv1.bias[c_out:].detach()} ") + print(model_3.conv1.bias, model_2.conv1.bias) diff --git a/t2v_enhanced/model/pl_module_extension.py b/t2v_enhanced/model/pl_module_extension.py new file mode 100644 index 0000000000000000000000000000000000000000..544465e746e4d0b70b7a4d85c307ff0ff428e421 --- /dev/null +++ b/t2v_enhanced/model/pl_module_extension.py @@ -0,0 +1,297 @@ +import torch +from copy import deepcopy +from einops import repeat +import math + + +class FrameConditioning(): + def __init__(self, + add_frame_to_input: bool = False, + add_frame_to_layers: bool = False, + fill_zero: bool = False, + randomize_mask: bool = False, + concatenate_mask: bool = False, + injection_probability: float = 0.9, + ) -> None: + self.use = None + self.add_frame_to_input = add_frame_to_input + self.add_frame_to_layers = add_frame_to_layers + self.fill_zero = fill_zero + self.randomize_mask = randomize_mask + self.concatenate_mask = concatenate_mask + self.injection_probability = injection_probability + self.add_frame_to_input or self.add_frame_to_layers + + assert not add_frame_to_layers or not add_frame_to_input + + def set_random_mask(self, random_mask: bool): + frame_conditioning = deepcopy(self) + frame_conditioning.randomize_mask = random_mask + return frame_conditioning + + @property + def use(self): + return self.add_frame_to_input or self.add_frame_to_layers + + @use.setter + def use(self, value): + if value is not None: + raise NotImplementedError("Direct access not allowed") + + def attach_video_frames(self, pl_module, z_0: torch.Tensor = None, batch: torch.Tensor = None, random_mask: bool = False): + assert self.fill_zero, "Not filling with zero not implemented yet" + n_frames_inference = self.inference_params.video_length + with torch.no_grad(): + if z_0 is None: + assert batch is not None + z_0 = pl_module.encode_frame(batch) + assert n_frames_inference == z_0.shape[1], "For frame injection, the number of frames sampled by the dataloader must match the number of frames used for video generation" + shape = list(z_0.shape) + + shape[1] = pl_module.inference_params.video_length + M = torch.zeros(shape, dtype=z_0.dtype, + device=pl_module.device) # [B F C W H] + bsz = z_0.shape[0] + if random_mask: + p_inject_frame = self.injection_probability + use_masks = torch.bernoulli( + torch.tensor(p_inject_frame).repeat(bsz)).long() + keep_frame_idx = torch.randint( + 0, n_frames_inference, (bsz,), device=pl_module.device).long() + else: + use_masks = torch.ones((bsz,), device=pl_module.device).long() + # keep only first frame + keep_frame_idx = 0 * use_masks + frame_idx = [] + + for batch_idx, (keep_frame, use_mask) in enumerate(zip(keep_frame_idx, use_masks)): + M[batch_idx, keep_frame] = use_mask + frame_idx.append(keep_frame if use_mask == 1 else -1) + + x0 = z_0*M + if self.concatenate_mask: + # flatten mask + M = M[:, :, 0, None] + x0 = torch.cat([x0, M], dim=2) + if getattr(pl_module.opt_params.noise_decomposition, "use", False) and random_mask: + assert x0.shape[0] == 1, "randomizing frame injection with noise decomposition not implemented for batch size >1" + return x0, frame_idx + + +class NoiseDecomposition(): + + def __init__(self, + use: bool = False, + random_frame: bool = False, + lambda_f: float = 0.5, + use_base_model: bool = True, + ): + self.use = use + self.random_frame = random_frame + self.lambda_f = lambda_f + self.use_base_model = use_base_model + + def get_loss(self, x0, unet_base, unet, noise_scheduler, frame_idx, z_t_base, timesteps, encoder_hidden_states, base_noise, z_t_residual, composed_noise): + if x0 is not None: + # x0.shape = [B,F,C,W,H], if extrapolation_params.fill_zero=true, only one frame per batch non-zero + assert not self.random_frame + + # TODO add x0 injection + x0_base = [] + for batch_idx, frame in enumerate(frame_idx): + x0_base.append(x0[batch_idx, frame, None, None]) + + x0_base = torch.cat(x0_base, dim=0) + x0_residual = repeat( + x0[:, 0], "B C W H -> B F C W H", F=x0.shape[1]-1) + else: + x0_residual = None + + if self.use_base_model: + base_pred = unet_base(z_t_base, timesteps, + encoder_hidden_states, x0=x0_base).sample + else: + base_pred = base_noise + + timesteps_alphas = [ + noise_scheduler.alphas_cumprod[t.cpu()] for t in timesteps] + timesteps_alphas = torch.stack( + timesteps_alphas).to(base_pred.device) + timesteps_alphas = repeat(timesteps_alphas, "B -> B F C W H", + F=base_pred.shape[1], C=base_pred.shape[2], W=base_pred.shape[3], H=base_pred.shape[4]) + base_correction = math.sqrt( + lambda_f) * torch.sqrt(1-timesteps_alphas) * base_pred + + z_t_residual_dash = z_t_residual - base_correction + + residual_pred = unet( + z_t_residual_dash, timesteps, encoder_hidden_states, x0=x0_residual).sample + composed_pred = math.sqrt( + lambda_f)*base_pred.detach() + math.sqrt(1-lambda_f) * residual_pred + + loss_residual = torch.nn.functional.mse_loss( + composed_noise.float(), composed_pred.float(), reduction=reduction) + if self.use_base_model: + loss_base = torch.nn.functional.mse_loss( + base_noise.float(), base_pred.float(), reduction=reduction) + loss = loss_residual+loss_base + else: + loss = loss_residual + return loss + + def add_noise(self, z_base, base_noise, z_residual, composed_noise, noise_scheduler, timesteps): + z_t_base = noise_scheduler.add_noise( + z_base, base_noise, timesteps) + z_t_residual = noise_scheduler.add_noise( + z_residual, composed_noise, timesteps) + return z_t_base, z_t_residual + + def split_latent_into_base_residual(self, z_0, pl_module, noise_generator): + if self.random_frame: + raise NotImplementedError("Must be synced with x0 mask!") + fr_select = torch.randint( + 0, z_0.shape[1], (bsz,), device=pl_module.device).long() + z_base = z_0[:, fr_Select, None] + fr_residual = [fr for fr in range( + z_0.shape[1]) if fr != fr_select] + z_residual = z_0[:, fr_residual, None] + else: + if not pl_module.unet_params.frame_conditioning.randomize_mask: + z_base = z_0[:, 0, None] + z_residual = z_0[:, 1:] + else: + z_base = [] + for batch_idx, frame_at_batch in enumerate(frame_idx): + z_base.append( + z_0[batch_idx, frame_at_batch, None, None]) + z_base = torch.cat(z_base, dim=0) + # z_residual = z_0[[:, 1:] + z_residual = [] + + for batch_idx, frame_idx_batch in enumerate(frame_idx): + z_residual_batch = [] + for frame in range(z_0.shape[1]): + if frame_idx_batch != frame: + z_residual_batch.append( + z_0[batch_idx, frame, None, None]) + z_residual_batch = torch.cat( + z_residual_batch, dim=1) + z_residual.append(z_residual_batch) + z_residual = torch.cat(z_residual, dim=0) + base_noise = noise_generator.sample_noise(z_base) # b_t + residual_noise = noise_generator.sample_noise(z_residual) # r^f_t + lambda_f = self.lambda_f + composed_noise = math.sqrt( + lambda_f) * base_noise + math.sqrt(1-lambda_f) * residual_noise # dimension issue? + + return z_base, base_noise, z_residual, composed_noise + + +class NoiseGenerator(): + + def __init__(self, mode="vanilla") -> None: + self.mode = mode + + def set_seed(self, seed: int): + self.seed = seed + + def reset_seed(self, seed: int): + pass + + def sample_noise(self, z_0: torch.tensor = None, shape=None, device=None, dtype=None, generator=None): + + assert (z_0 is not None) != ( + shape is not None), f"either z_0 must be None, or shape must be None. Both provided." + kwargs = {} + + if z_0 is None: + if device is not None: + kwargs["device"] = device + if dtype is not None: + kwargs["dtype"] = dtype + + else: + kwargs["device"] = z_0.device + kwargs["dtype"] = z_0.dtype + shape = z_0.shape + + if generator is not None: + kwargs["generator"] = generator + + B, F, C, W, H = shape + + if self.mode == "vanilla": + noise = torch.randn( + shape, **kwargs) + elif self.mode == "free_noise": + noise = torch.randn(shape, **kwargs) + if noise.shape[1] > 4: + # HARD CODED + noise = noise[:, :8] + noise = torch.cat( + [noise, noise[:, torch.randperm(noise.shape[1])]], dim=1) + elif noise.shape[2] > 4: + noise = noise[:, :, :8] + noise = torch.cat( + [noise, noise[:, :, torch.randperm(noise.shape[2])]], dim=2) + else: + raise NotImplementedError( + f"Shape of noise vector not as expected {noise.shape}") + elif self.mode == "equal": + shape = list(shape) + shape[1] = 1 + noise_init = torch.randn( + shape, **kwargs) + shape[1] = F + noise = torch.zeros( + shape, device=noise_init.device, dtype=noise_init.dtype) + for fr in range(F): + noise[:, fr] = noise_init[:, 0] + elif self.mode == "fusion": + shape = list(shape) + shape[1] = 1 + noise_init = torch.randn( + shape, **kwargs) + noises = [] + noises.append(noise_init) + for fr in range(F-1): + + shift = 2*(fr+1) + local_copy = noise_init + shifted_noise = torch.cat( + [local_copy[:, :, :, shift:, :], local_copy[:, :, :, :shift, :]], dim=3) + noises.append(math.sqrt(0.2)*shifted_noise + + math.sqrt(1-0.2)*torch.rand(shape, **kwargs)) + noise = torch.cat(noises, dim=1) + + elif self.mode == "motion_dynamics" or self.mode == "equal_noise_per_sequence": + + shape = list(shape) + normal_frames = 1 + shape[1] = normal_frames + init_noise = torch.randn( + shape, **kwargs) + noises = [] + noises.append(init_noise) + init_noise = init_noise[:, -1, None] + print(f"UPDATE with noise = {init_noise.shape}") + + if self.mode == "motion_dynamics": + for fr in range(F-normal_frames): + + shift = 2*(fr+1) + print(fr, shift) + local_copy = init_noise + shifted_noise = torch.cat( + [local_copy[:, :, :, shift:, :], local_copy[:, :, :, :shift, :]], dim=3) + noises.append(shifted_noise) + elif self.mode == "equal_noise_per_sequence": + for fr in range(F-1): + noises.append(init_noise) + else: + raise NotImplementedError() + # noises[0] = noises[0] * 0 + noise = torch.cat(noises, dim=1) + print(noise.shape) + + return noise diff --git a/t2v_enhanced/model/pl_module_params_controlnet.py b/t2v_enhanced/model/pl_module_params_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..81463056c7b33af75c6b9cd076e8aab18a78ac21 --- /dev/null +++ b/t2v_enhanced/model/pl_module_params_controlnet.py @@ -0,0 +1,356 @@ +from typing import Union, Any, Dict, List, Optional, Callable +from t2v_enhanced.model import pl_module_extension +from t2v_enhanced.model.diffusers_conditional.models.controlnet.image_embedder import AbstractEncoder +from t2v_enhanced.model.requires_grad_setter import LayerConfig as LayerConfigNew +from t2v_enhanced.model import video_noise_generator + + +def auto_str(cls): + def __str__(self): + return '%s(%s)' % ( + type(self).__name__, + ', '.join('%s=%s' % item for item in vars(self).items()) + ) + cls.__str__ = __str__ + return cls + + +class LayerConfig(): + def __init__(self, + update_with_full_lr: Optional[Union[List[str], + List[List[str]]]] = None, + exclude: Optional[List[str]] = None, + deactivate_all_grads: bool = True, + ) -> None: + self.deactivate_all_grads = deactivate_all_grads + if exclude is not None: + self.exclude = exclude + if update_with_full_lr is not None: + self.update_with_full_lr = update_with_full_lr + + def __str__(self) -> str: + str = f"Deactivate all gradients first={self.deactivate_all_grads}. " + if hasattr(self, "update_with_full_lr"): + str += f"Then activating gradients for: {self.update_with_full_lr}. " + if hasattr(self, "exclude"): + str += f"Finally, excluding: {self.exclude}. " + return str + + +class OptimizerParams(): + def __init__(self, + learning_rate: float, + # Default value due to legacy + layers_config: Union[LayerConfig, LayerConfigNew] = None, + layers_config_base: LayerConfig = None, # Default value due to legacy + use_warmup: bool = False, + warmup_steps: int = 10000, + warmup_start_factor: float = 1e-5, + learning_rate_spatial: float = 0.0, + use_8_bit_adam: bool = False, + noise_generator: Union[pl_module_extension.NoiseGenerator, + video_noise_generator.NoiseGenerator] = None, + noise_decomposition: pl_module_extension.NoiseDecomposition = None, + perceptual_loss: bool = False, + noise_offset: float = 0.0, + split_opt_by_node: bool = False, + reset_prediction_type_to_eps: bool = False, + train_val_sampler_may_differ: bool = False, + measure_similarity: bool = False, + similarity_loss: bool = False, + similarity_loss_weight: float = 1.0, + loss_conditional_weight: float = 0.0, + loss_conditional_weight_convex: bool = False, + loss_conditional_change_after_step: int = 0, + mask_conditional_frames: bool = False, + sample_from_noise: bool = True, + mask_alternating: bool = False, + uncondition_freq: int = -1, + no_text_condition_control: bool = False, + inject_image_into_input: bool = False, + inject_at_T: bool = False, + resampling_steps: int = 1, + control_freq_in_resample: int = 1, + resample_to_T: bool = False, + adaptive_loss_reweight: bool = False, + load_resampler_from_ckpt: str = "", + skip_controlnet_branch: bool = False, + use_fps_conditioning: bool = False, + num_frame_embeddings_range: int = 16, + start_frame_training: int = 0, + start_frame_ctrl: int = 0, + load_trained_base_model_and_resampler_from_ckpt: str = "", + load_trained_controlnet_from_ckpt: str = "", + # fill_up_frame_to_video: bool = False, + ) -> None: + self.use_warmup = use_warmup + self.warmup_steps = warmup_steps + self.warmup_start_factor = warmup_start_factor + self.learning_rate_spatial = learning_rate_spatial + self.learning_rate = learning_rate + self.use_8_bit_adam = use_8_bit_adam + self.layers_config = layers_config + self.noise_generator = noise_generator + self.perceptual_loss = perceptual_loss + self.noise_decomposition = noise_decomposition + self.noise_offset = noise_offset + self.split_opt_by_node = split_opt_by_node + self.reset_prediction_type_to_eps = reset_prediction_type_to_eps + self.train_val_sampler_may_differ = train_val_sampler_may_differ + self.measure_similarity = measure_similarity + self.similarity_loss = similarity_loss + self.similarity_loss_weight = similarity_loss_weight + self.loss_conditional_weight = loss_conditional_weight + self.loss_conditional_change_after_step = loss_conditional_change_after_step + self.mask_conditional_frames = mask_conditional_frames + self.loss_conditional_weight_convex = loss_conditional_weight_convex + self.sample_from_noise = sample_from_noise + self.layers_config_base = layers_config_base + self.mask_alternating = mask_alternating + self.uncondition_freq = uncondition_freq + self.no_text_condition_control = no_text_condition_control + self.inject_image_into_input = inject_image_into_input + self.inject_at_T = inject_at_T + self.resampling_steps = resampling_steps + self.control_freq_in_resample = control_freq_in_resample + self.resample_to_T = resample_to_T + self.adaptive_loss_reweight = adaptive_loss_reweight + self.load_resampler_from_ckpt = load_resampler_from_ckpt + self.skip_controlnet_branch = skip_controlnet_branch + self.use_fps_conditioning = use_fps_conditioning + self.num_frame_embeddings_range = num_frame_embeddings_range + self.start_frame_training = start_frame_training + self.load_trained_base_model_and_resampler_from_ckpt = load_trained_base_model_and_resampler_from_ckpt + self.load_trained_controlnet_from_ckpt = load_trained_controlnet_from_ckpt + self.start_frame_ctrl = start_frame_ctrl + if start_frame_ctrl < 0: + print("new format start frame cannot be negative") + exit() + + # self.fill_up_frame_to_video = fill_up_frame_to_video + + @property + def learning_rate_spatial(self): + return self._learning_rate_spatial + + # legacy code that maps the state None or '-1' to '0.0' + # so 0.0 indicated no spatial learning rate is selected + @learning_rate_spatial.setter + def learning_rate_spatial(self, value): + if value is None or value == -1: + value = 0 + self._learning_rate_spatial = value + + +# Legacy class +class SchedulerParams(): + def __init__(self, + use_warmup: bool = False, + warmup_steps: int = 10000, + warmup_start_factor: float = 1e-5, + ) -> None: + self.use_warmup = use_warmup + self.warmup_steps = warmup_steps + self.warmup_start_factor = warmup_start_factor + + + +class CrossFrameAttentionParams(): + + def __init__(self, attent_on: List[int], masking=False) -> None: + self.attent_on = attent_on + self.masking = masking + + +class InferenceParams(): + def __init__(self, + width: int, + height: int, + video_length: int, + guidance_scale: float = 7.5, + use_dec_scaling: bool = True, + frame_rate: int = 2, + num_inference_steps: int = 50, + eta: float = 0.0, + n_autoregressive_generations: int = 1, + mode: str = "long_video", + start_from_real_input: bool = True, + eval_loss_metrics: bool = False, + scheduler_cls: str = "", + negative_prompt: str = "", + conditioning_from_all_past: bool = False, + validation_samples: int = 80, + conditioning_type: str = "last_chunk", + result_formats: List[str] = ["eval_gif", "gif", "mp4"], + concat_video: bool = True, + seed: int = 33, + ): + self.width = width + self.height = height + self.video_length = video_length if isinstance( + video_length, int) else int(video_length) + self.guidance_scale = guidance_scale + self.use_dec_scaling = use_dec_scaling + self.frame_rate = frame_rate + self.num_inference_steps = num_inference_steps + self.eta = eta + self.negative_prompt = negative_prompt + self.n_autoregressive_generations = n_autoregressive_generations + self.mode = mode + self.start_from_real_input = start_from_real_input + self.eval_loss_metrics = eval_loss_metrics + self.scheduler_cls = scheduler_cls + self.conditioning_from_all_past = conditioning_from_all_past + self.validation_samples = validation_samples + self.conditioning_type = conditioning_type + self.result_formats = result_formats + self.concat_video = concat_video + self.seed = seed + + def to_dict(self): + + keys = [entry for entry in dir(self) if not callable(getattr( + self, entry)) and not entry.startswith("__")] + + result_dict = {} + for key in keys: + result_dict[key] = getattr(self, key) + return result_dict + + +@auto_str +class AttentionMaskParams(): + + def __init__(self, + temporal_self_attention_only_on_conditioning: bool = False, + temporal_self_attention_mask_included_itself: bool = False, + spatial_attend_on_condition_frames: bool = False, + temp_attend_on_neighborhood_of_condition_frames: bool = False, + temp_attend_on_uncond_include_past: bool = False, + ) -> None: + self.temporal_self_attention_mask_included_itself = temporal_self_attention_mask_included_itself + self.spatial_attend_on_condition_frames = spatial_attend_on_condition_frames + self.temp_attend_on_neighborhood_of_condition_frames = temp_attend_on_neighborhood_of_condition_frames + self.temporal_self_attention_only_on_conditioning = temporal_self_attention_only_on_conditioning + self.temp_attend_on_uncond_include_past = temp_attend_on_uncond_include_past + + assert not temp_attend_on_neighborhood_of_condition_frames or not temporal_self_attention_only_on_conditioning + + +class UNetParams(): + + def __init__(self, + conditioning_embedding_out_channels: List[int], + ckpt_spatial_layers: str = "", + pipeline_repo: str = "", + unet_from_diffusers: bool = True, + spatial_latent_input: bool = False, + num_frame_conditioning: int = 1, + pipeline_class: str = "t2v_enhanced.model.model.controlnet.pipeline_text_to_video_w_controlnet_synth.TextToVideoSDPipeline", + frame_expansion: str = "last_frame", + downsample_controlnet_cond: bool = True, + num_frames: int = 1, + pre_transformer_in_cond: bool = False, + num_tranformers: int = 1, + zero_conv_3d: bool = False, + merging_mode: str = "addition", + compute_only_conditioned_frames: bool = False, + condition_encoder: str = "", + zero_conv_mode: str = "2d", + clean_model: bool = False, + merging_mode_base: str = "addition", + attention_mask_params: AttentionMaskParams = None, + attention_mask_params_base: AttentionMaskParams = None, + modelscope_input_format: bool = True, + temporal_self_attention_only_on_conditioning: bool = False, + temporal_self_attention_mask_included_itself: bool = False, + use_post_merger_zero_conv: bool = False, + weight_control_sample: float = 1.0, + use_controlnet_mask: bool = False, + random_mask_shift: bool = False, + random_mask: bool = False, + use_resampler: bool = False, + unet_from_pipe: bool = False, + unet_operates_on_2d: bool = False, + image_encoder: str = "CLIP", + use_standard_attention_processor: bool = True, + num_frames_before_chunk: int = 0, + resampler_type: str = "single_frame", + resampler_cls: str = "", + resampler_merging_layers: int = 1, + image_encoder_obj: AbstractEncoder = None, + cfg_text_image: bool = False, + aggregation: str = "last_out", + resampler_random_shift: bool = False, + img_cond_alpha_per_frame: bool = False, + num_control_input_frames: int = -1, + use_image_encoder_normalization: bool = False, + use_of: bool = False, + ema_param: float = -1.0, + concat: bool = False, + use_image_tokens_main: bool = True, + use_image_tokens_ctrl: bool = False, + ): + + self.ckpt_spatial_layers = ckpt_spatial_layers + self.pipeline_repo = pipeline_repo + self.unet_from_diffusers = unet_from_diffusers + self.spatial_latent_input = spatial_latent_input + self.pipeline_class = pipeline_class + self.num_frame_conditioning = num_frame_conditioning + if num_control_input_frames == -1: + self.num_control_input_frames = num_frame_conditioning + else: + self.num_control_input_frames = num_control_input_frames + + self.conditioning_embedding_out_channels = conditioning_embedding_out_channels + self.frame_expansion = frame_expansion + self.downsample_controlnet_cond = downsample_controlnet_cond + self.num_frames = num_frames + self.pre_transformer_in_cond = pre_transformer_in_cond + self.num_tranformers = num_tranformers + self.zero_conv_3d = zero_conv_3d + self.merging_mode = merging_mode + self.compute_only_conditioned_frames = compute_only_conditioned_frames + self.clean_model = clean_model + self.condition_encoder = condition_encoder + self.zero_conv_mode = zero_conv_mode + self.merging_mode_base = merging_mode_base + self.modelscope_input_format = modelscope_input_format + assert not temporal_self_attention_only_on_conditioning, "This parameter is only here for backward compatibility. Set AttentionMaskParams instead." + assert not temporal_self_attention_mask_included_itself, "This parameter is only here for backward compatibility. Set AttentionMaskParams instead." + if attention_mask_params is not None and attention_mask_params_base is None: + attention_mask_params_base = attention_mask_params + if attention_mask_params is None: + attention_mask_params = AttentionMaskParams() + if attention_mask_params_base is None: + attention_mask_params_base = AttentionMaskParams() + self.attention_mask_params = attention_mask_params + self.attention_mask_params_base = attention_mask_params_base + self.weight_control_sample = weight_control_sample + self.use_controlnet_mask = use_controlnet_mask + self.random_mask_shift = random_mask_shift + self.random_mask = random_mask + self.use_resampler = use_resampler + self.unet_from_pipe = unet_from_pipe + self.unet_operates_on_2d = unet_operates_on_2d + self.image_encoder = image_encoder_obj + self.use_standard_attention_processor = use_standard_attention_processor + self.num_frames_before_chunk = num_frames_before_chunk + self.resampler_type = resampler_type + self.resampler_cls = resampler_cls + self.resampler_merging_layers = resampler_merging_layers + self.cfg_text_image = cfg_text_image + self.aggregation = aggregation + self.resampler_random_shift = resampler_random_shift + self.img_cond_alpha_per_frame = img_cond_alpha_per_frame + self.use_image_encoder_normalization = use_image_encoder_normalization + self.use_of = use_of + self.ema_param = ema_param + self.concat = concat + self.use_image_tokens_main = use_image_tokens_main + self.use_image_tokens_ctrl = use_image_tokens_ctrl + assert not use_post_merger_zero_conv + + if spatial_latent_input: + assert unet_from_diffusers, "Spatial latent input only implemented by original diffusers model. Set 'model.unet_params.unet_from_diffusers=True'." diff --git a/t2v_enhanced/model/requires_grad_setter.py b/t2v_enhanced/model/requires_grad_setter.py new file mode 100644 index 0000000000000000000000000000000000000000..5d6cb997ac3b91c8a76390b5885985678bb402cb --- /dev/null +++ b/t2v_enhanced/model/requires_grad_setter.py @@ -0,0 +1,36 @@ +from typing import Union, Any, Dict, List, Optional, Tuple +import pytorch_lightning as pl + + +class LayerConfig(): + def __init__(self, + gradient_setup: List[Tuple[bool, List[str]]] = None, + ) -> None: + + if gradient_setup is not None: + self.gradient_setup = gradient_setup + self.new_config = True + # TODO add option to specify quantization per layer + + def set_requires_grad(self, pl_module: pl.LightningModule): + # [["True","unet.a.b","c"],["True,[]"]] + + for selected_module_setup in self.gradient_setup: + for model_name, p in pl_module.named_parameters(): + grad_mode = selected_module_setup[0] == True + selected_module_path = selected_module_setup[1] + path_is_matching = True + model_name_selection = model_name + for selected_module in selected_module_path: + position = model_name_selection.find(selected_module) + if position == -1: + path_is_matching = False + continue + else: + shift = len(selected_module) + model_name_selection = model_name_selection[position+shift:] + if path_is_matching: + # if grad_mode: + # print( + # f"Setting gradient for {model_name} to {grad_mode}") + p.requires_grad = grad_mode diff --git a/t2v_enhanced/model/video_ldm.py b/t2v_enhanced/model/video_ldm.py new file mode 100644 index 0000000000000000000000000000000000000000..13e8634789df9a7c808a8f1e90504e8915dcc0ae --- /dev/null +++ b/t2v_enhanced/model/video_ldm.py @@ -0,0 +1,327 @@ +from pathlib import Path +from typing import Any, Optional, Union, Callable + +import pytorch_lightning as pl +import torch +from diffusers import DDPMScheduler, DiffusionPipeline, AutoencoderKL, DDIMScheduler +from diffusers.utils.import_utils import is_xformers_available +from einops import rearrange, repeat + +from transformers import CLIPTextModel, CLIPTokenizer +from utils.video_utils import ResultProcessor, save_videos_grid, video_naming + +from t2v_enhanced.model import pl_module_params_controlnet + +from t2v_enhanced.model.diffusers_conditional.models.controlnet.controlnet import ControlNetModel +from t2v_enhanced.model.diffusers_conditional.models.controlnet.unet_3d_condition import UNet3DConditionModel +from t2v_enhanced.model.diffusers_conditional.models.controlnet.pipeline_text_to_video_w_controlnet_synth import TextToVideoSDPipeline + +from t2v_enhanced.model.diffusers_conditional.models.controlnet.processor import set_use_memory_efficient_attention_xformers +from t2v_enhanced.model.diffusers_conditional.models.controlnet.mask_generator import MaskGenerator + +import warnings +# from warnings import warn +from t2v_enhanced.utils.iimage import IImage +from t2v_enhanced.utils.object_loader import instantiate_object +from t2v_enhanced.utils.object_loader import get_class + + +class VideoLDM(pl.LightningModule): + + def __init__(self, + inference_params: pl_module_params_controlnet.InferenceParams, + opt_params: pl_module_params_controlnet.OptimizerParams = None, + unet_params: pl_module_params_controlnet.UNetParams = None, + ): + super().__init__() + + self.inference_generator = torch.Generator(device=self.device) + + self.opt_params = opt_params + self.unet_params = unet_params + + print(f"Base pipeline from: {unet_params.pipeline_repo}") + print(f"Pipeline class {unet_params.pipeline_class}") + # load entire pipeline (unet, vq, text encoder,..) + state_dict_control_model = None + state_dict_fusion = None + state_dict_base_model = None + + if len(opt_params.load_trained_controlnet_from_ckpt) > 0: + state_dict_ckpt = torch.load(opt_params.load_trained_controlnet_from_ckpt, map_location=torch.device("cpu")) + state_dict_ckpt = state_dict_ckpt["state_dict"] + state_dict_control_model = dict(filter(lambda x: x[0].startswith("unet"), state_dict_ckpt.items())) + state_dict_control_model = {k.split("unet.")[1]: v for (k, v) in state_dict_control_model.items()} + + state_dict_fusion = dict(filter(lambda x: "cross_attention_merger" in x[0], state_dict_ckpt.items())) + state_dict_fusion = {k.split("base_model.")[1]: v for (k, v) in state_dict_fusion.items()} + del state_dict_ckpt + + state_dict_proj = None + state_dict_ckpt = None + + if hasattr(unet_params, "use_resampler") and unet_params.use_resampler: + num_queries = unet_params.num_frames if unet_params.num_frames > 1 else None + if unet_params.use_image_tokens_ctrl: + num_queries = unet_params.num_control_input_frames + assert unet_params.frame_expansion == "none" + image_encoder = self.unet_params.image_encoder + embedding_dim = image_encoder.embedding_dim + + resampler = instantiate_object(self.unet_params.resampler_cls, video_length=num_queries, embedding_dim=embedding_dim, input_tokens=image_encoder.num_tokens, num_layers=self.unet_params.resampler_merging_layers, aggregation=self.unet_params.aggregation) + + state_dict_proj = None + + self.resampler = resampler + self.image_encoder = image_encoder + + + noise_scheduler = DDPMScheduler.from_pretrained(self.unet_params.pipeline_repo, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained(self.unet_params.pipeline_repo, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(self.unet_params.pipeline_repo, subfolder="text_encoder") + vae = AutoencoderKL.from_pretrained(self.unet_params.pipeline_repo, subfolder="vae") + base_model = UNet3DConditionModel.from_pretrained(self.unet_params.pipeline_repo, subfolder="unet", low_cpu_mem_usage=False, device_map=None, merging_mode=self.unet_params.merging_mode_base, use_image_embedding=unet_params.use_resampler and unet_params.use_image_tokens_main, use_fps_conditioning=self.opt_params.use_fps_conditioning, unet_params=unet_params) + + if state_dict_base_model is not None: + miss, unex = base_model.load_state_dict(state_dict_base_model, strict=False) + assert len(unex) == 0 + if len(miss) > 0: + warnings.warn(f"Missing keys when loading base_mode:{miss}") + del state_dict_base_model + if state_dict_fusion is not None: + miss, unex = base_model.load_state_dict(state_dict_fusion, strict=False) + assert len(unex) == 0 + del state_dict_fusion + + print("PIPE LOADING DONE") + self.noise_scheduler = noise_scheduler + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.vae = vae + + self.unet = ControlNetModel.from_unet( + unet=base_model, + conditioning_embedding_out_channels=unet_params.conditioning_embedding_out_channels, + downsample_controlnet_cond=unet_params.downsample_controlnet_cond, + num_frames=unet_params.num_frames if (unet_params.frame_expansion != "none" or self.unet_params.use_controlnet_mask) else unet_params.num_control_input_frames, + num_frame_conditioning=unet_params.num_control_input_frames, + frame_expansion=unet_params.frame_expansion, + pre_transformer_in_cond=unet_params.pre_transformer_in_cond, + num_tranformers=unet_params.num_tranformers, + vae=AutoencoderKL.from_pretrained(self.unet_params.pipeline_repo, subfolder="vae"), + zero_conv_mode=unet_params.zero_conv_mode, + merging_mode=unet_params.merging_mode, + condition_encoder=unet_params.condition_encoder, + use_controlnet_mask=unet_params.use_controlnet_mask, + use_image_embedding=unet_params.use_resampler and unet_params.use_image_tokens_ctrl, + unet_params=unet_params, + use_image_encoder_normalization=unet_params.use_image_encoder_normalization, + ) + if state_dict_control_model is not None: + miss, unex = self.unet.load_state_dict( + state_dict_control_model, strict=False) + if len(miss) > 0: + print("WARNING: Loading checkpoint for controlnet misses states") + print(miss) + + if unet_params.frame_expansion == "none": + attention_params = self.unet_params.attention_mask_params + assert not attention_params.temporal_self_attention_only_on_conditioning and not attention_params.spatial_attend_on_condition_frames and not attention_params.temp_attend_on_neighborhood_of_condition_frames + + self.mask_generator = MaskGenerator( + self.unet_params.attention_mask_params, num_frame_conditioning=self.unet_params.num_control_input_frames, num_frames=self.unet_params.num_frames) + self.mask_generator_base = MaskGenerator( + self.unet_params.attention_mask_params_base, num_frame_conditioning=self.unet_params.num_control_input_frames, num_frames=self.unet_params.num_frames) + + if state_dict_proj is not None and unet_params.use_image_tokens_main: + if unet_params.use_image_tokens_main: + missing, unexpected = base_model.load_state_dict( + state_dict_proj, strict=False) + elif unet_params.use_image_tokens_ctrl: + missing, unexpected = unet.load_state_dict( + state_dict_proj, strict=False) + assert len(unexpected) == 0, f"Unexpected entries {unexpected}" + print(f"Missing keys state proj = {missing}") + del state_dict_proj + + base_model.requires_grad_(False) + self.base_model = base_model + self.unet.requires_grad_(False) + self.text_encoder.requires_grad_(False) + self.vae.requires_grad_(False) + + layers_config = opt_params.layers_config + layers_config.set_requires_grad(self) + + print("CUSTOM XFORMERS ATTENTION USED.") + if is_xformers_available(): + set_use_memory_efficient_attention_xformers(self.unet, num_frame_conditioning=self.unet_params.num_control_input_frames, + num_frames=self.unet_params.num_frames, + attention_mask_params=self.unet_params.attention_mask_params + ) + set_use_memory_efficient_attention_xformers(self.base_model, num_frame_conditioning=self.unet_params.num_control_input_frames, + num_frames=self.unet_params.num_frames, + attention_mask_params=self.unet_params.attention_mask_params_base) + + if len(inference_params.scheduler_cls) > 0: + inf_scheduler_class = get_class(inference_params.scheduler_cls) + else: + inf_scheduler_class = DDIMScheduler + + inf_scheduler = inf_scheduler_class.from_pretrained( + self.unet_params.pipeline_repo, subfolder="scheduler") + inference_pipeline = TextToVideoSDPipeline(vae=self.vae, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + unet=self.base_model, + controlnet=self.unet, + scheduler=inf_scheduler + ) + + inference_pipeline.set_noise_generator(self.opt_params.noise_generator) + inference_pipeline.enable_vae_slicing() + + inference_pipeline.set_progress_bar_config(disable=True) + + self.inference_params = inference_params + self.inference_pipeline = inference_pipeline + + self.result_processor = ResultProcessor(fps=self.inference_params.frame_rate, n_frames=self.inference_params.video_length) + + def on_start(self): + datamodule = self.trainer._data_connector._datahook_selector.datamodule + pipe_id_model = self.unet_params.pipeline_repo + for dataset_key in ["video_dataset", "image_dataset", "predict_dataset"]: + dataset = getattr(datamodule, dataset_key, None) + if dataset is not None and hasattr(dataset, "model_id"): + pipe_id_data = dataset.model_id + assert pipe_id_model == pipe_id_data, f"Model and Dataloader need the same pipeline path. Found '{pipe_id_model}' and '{dataset_key}.model_id={pipe_id_data}'. Consider setting '--data.{dataset_key}.model_id={pipe_id_data}'" + self.result_processor.set_logger(self.logger) + + def on_predict_start(self) -> None: + self.on_start() + # pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16") + # pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + # pipe.set_progress_bar_config(disable=True) + # self.first_stage = pipe.to(self.device) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + cfg = self.trainer.predict_cfg + + result_file_stem = cfg["result_file_stem"] + storage_fol = Path(cfg['predict_dir']) + prompts = [cfg["prompt"]] + + inference_params: pl_module_params_controlnet.InferenceParams = self.inference_params + conditioning_type = inference_params.conditioning_type + n_autoregressive_generations = inference_params.n_autoregressive_generations + mode = inference_params.mode + start_from_real_input = inference_params.start_from_real_input + assert isinstance(prompts, list) + + prompts = n_autoregressive_generations * prompts + + self.inference_generator.manual_seed(self.inference_params.seed) + + assert self.unet_params.num_control_input_frames == self.inference_params.video_length//2, f"currently we assume to have an equal size for and second half of the frame interval, e.g. 16 frames, and we condition on 8. Current setup: {self.unet_params.num_frame_conditioning} and {self.inference_params.video_length}" + + chunks_conditional = [] + batch_size = 1 + shape = (batch_size, self.inference_pipeline.unet.config.in_channels, self.inference_params.video_length, + self.inference_pipeline.unet.config.sample_size, self.inference_pipeline.unet.config.sample_size) + for idx, prompt in enumerate(prompts): + if idx > 0: + content = sample*2-1 + content_latent = self.vae.encode(content).latent_dist.sample() * self.vae.config.scaling_factor + content_latent = rearrange(content_latent, "F C W H -> 1 C F W H") + content_latent = content_latent[:, :, self.unet_params.num_control_input_frames:].detach().clone() + + if hasattr(self.inference_pipeline, "noise_generator"): + latents = self.inference_pipeline.noise_generator.sample_noise(shape=shape, device=self.device, dtype=self.dtype, generator=self.inference_generator, content=content_latent if idx > 0 else None) + else: + latents = None + if idx == 0: + sample = cfg["video"] + else: + if inference_params.conditioning_type == "fixed": + context = chunks_conditional[0][:self.unet_params.num_frame_conditioning] + context = [context] + context = [2*sample-1 for sample in context] + + input_frames_conditioning = torch.cat(context).detach().clone() + input_frames_conditioning = rearrange(input_frames_conditioning, "F C W H -> 1 F C W H") + elif inference_params.conditioning_type == "last_chunk": + input_frames_conditioning = condition_input[:, -self.unet_params.num_frame_conditioning:].detach().clone() + elif inference_params.conditioning_type == "past": + context = [sample[:self.unet_params.num_control_input_frames] for sample in chunks_conditional] + context = [2*sample-1 for sample in context] + + input_frames_conditioning = torch.cat(context).detach().clone() + input_frames_conditioning = rearrange(input_frames_conditioning, "F C W H -> 1 F C W H") + else: + raise NotImplementedError() + + input_frames = condition_input[:, self.unet_params.num_control_input_frames:].detach().clone() + + sample = self(prompt, input_frames=input_frames, input_frames_conditioning=input_frames_conditioning, latents=latents) + + if hasattr(self.inference_pipeline, "reset_noise_generator_state"): + self.inference_pipeline.reset_noise_generator_state() + + condition_input = rearrange(sample, "F C W H -> 1 F C W H") + condition_input = (2*condition_input)-1 # range: [-1,1] + + # store first 16 frames, then always last 8 of a chunk + chunks_conditional.append(sample) + + result_formats = self.inference_params.result_formats + # result_formats = [gif", "mp4"] + concat_video = self.inference_params.concat_video + + def IImage_normalized(x): return IImage(x, vmin=0, vmax=1) + for result_format in result_formats: + save_format = result_format.replace("eval_", "") + + merged_video = None + for chunk_idx, (prompt, video) in enumerate(zip(prompts, chunks_conditional)): + if chunk_idx == 0: + current_video = IImage_normalized(video) + else: + current_video = IImage_normalized(video[self.unet_params.num_control_input_frames:]) + + if merged_video is None: + merged_video = current_video + else: + merged_video &= current_video + + if concat_video: + filename = video_naming(prompts[0], save_format, batch_idx, 0) + result_file_video = (storage_fol / filename).absolute().as_posix() + result_file_video = (Path(result_file_video).parent / (result_file_stem+Path(result_file_video).suffix)).as_posix() + self.result_processor.save_to_file(video=merged_video.torch(vmin=0, vmax=1), prompt=prompts[0], video_filename=result_file_video, prompt_on_vid=False) + + def forward(self, prompt, input_frames=None, input_frames_conditioning=None, latents=None): + call_params = self.inference_params.to_dict() + print(f"INFERENCE PARAMS = {call_params}") + call_params["prompt"] = prompt + + call_params["image"] = input_frames + call_params["num_frames"] = self.inference_params.video_length + call_params["return_dict"] = False + call_params["output_type"] = "pt_t2v" + call_params["mask_generator"] = self.mask_generator + call_params["precision"] = "16" if self.trainer.precision.startswith("16") else "32" + call_params["no_text_condition_control"] = self.opt_params.no_text_condition_control + call_params["weight_control_sample"] = self.unet_params.weight_control_sample + call_params["use_controlnet_mask"] = self.unet_params.use_controlnet_mask + call_params["skip_controlnet_branch"] = self.opt_params.skip_controlnet_branch + call_params["img_cond_resampler"] = self.resampler if self.unet_params.use_resampler else None + call_params["img_cond_encoder"] = self.image_encoder if self.unet_params.use_resampler else None + call_params["input_frames_conditioning"] = input_frames_conditioning + call_params["cfg_text_image"] = self.unet_params.cfg_text_image + call_params["use_of"] = self.unet_params.use_of + if latents is not None: + call_params["latents"] = latents + + sample = self.inference_pipeline(generator=self.inference_generator, **call_params) + return sample diff --git a/t2v_enhanced/model/video_noise_generator.py b/t2v_enhanced/model/video_noise_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..49ffabe721a44c45242bbd3d7811925b948c3184 --- /dev/null +++ b/t2v_enhanced/model/video_noise_generator.py @@ -0,0 +1,225 @@ +import torch +import torch.fft as fft +from torch import nn +from torch.nn import functional +from math import sqrt +from einops import rearrange +import math +import numbers +from typing import List + +# adapted from https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/10 +# and https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/19 + + +def gaussian_smoothing_kernel(shape, kernel_size, sigma, dim=2): + """ + Apply gaussian smoothing on a + 1d, 2d or 3d tensor. Filtering is performed seperately for each channel + in the input using a depthwise convolution. + Arguments: + channels (int, sequence): Number of channels of the input tensors. Output will + have this number of channels as well. + kernel_size (int, sequence): Size of the gaussian kernel. + sigma (float, sequence): Standard deviation of the gaussian kernel. + dim (int, optional): The number of dimensions of the data. + Default value is 2 (spatial). + """ + if isinstance(kernel_size, numbers.Number): + kernel_size = [kernel_size] * dim + if isinstance(sigma, numbers.Number): + sigma = [sigma] * dim + + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid( + [ + torch.arange(size, dtype=torch.float32) + for size in kernel_size + ] + ) + + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + + kernel *= torch.exp(-((mgrid - mean) / std) ** 2 / 2) + # kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ + # torch.exp(-((mgrid - mean) / std) ** 2 / 2) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + + pad_length = (math.floor( + (shape[-1]-kernel_size[-1])/2), math.floor((shape[-1]-kernel_size[-1])/2), math.floor((shape[-2]-kernel_size[-2])/2), math.floor((shape[-2]-kernel_size[-2])/2), math.floor((shape[-3]-kernel_size[-3])/2), math.floor((shape[-3]-kernel_size[-3])/2)) + + kernel = functional.pad(kernel, pad_length) + assert kernel.shape == shape[-3:] + return kernel + + ''' + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + + self.register_buffer('weight', kernel) + self.groups = channels + + if dim == 1: + self.conv = functional.conv1d + elif dim == 2: + self.conv = functional.conv2d + elif dim == 3: + self.conv = functional.conv3d + else: + raise RuntimeError( + 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format( + dim) + ) + ''' + + +class NoiseGenerator(): + + def __init__(self, alpha: float = 0.0, shared_noise_across_chunks: bool = False, mode="vanilla", forward_steps: int = 850, radius: List[float] = None) -> None: + self.mode = mode + self.alpha = alpha + self.shared_noise_across_chunks = shared_noise_across_chunks + self.forward_steps = forward_steps + self.radius = radius + + def set_seed(self, seed: int): + self.seed = seed + + def reset_seed(self, seed: int): + pass + + def reset_noise_generator_state(self): + if hasattr(self, "e_shared"): + del self.e_shared + + def sample_noise(self, z_0: torch.tensor = None, shape=None, device=None, dtype=None, generator=None, content=None): + assert (z_0 is not None) != ( + shape is not None), f"either z_0 must be None, or shape must be None. Both provided." + kwargs = {} + noise = torch.randn(shape, **kwargs) + + if z_0 is None: + if device is not None: + kwargs["device"] = device + if dtype is not None: + kwargs["dtype"] = dtype + + else: + kwargs["device"] = z_0.device + kwargs["dtype"] = z_0.dtype + shape = z_0.shape + + if generator is not None: + kwargs["generator"] = generator + + B, F, C, W, H = shape + if F == 4 and C > 4: + frame_idx = 2 + F, C = C, F + else: + frame_idx = 1 + + if "mixed_noise" in self.mode: + + shape_per_frame = [dim for dim in shape] + shape_per_frame[frame_idx] = 1 + zero_mean = torch.zeros( + shape_per_frame, device=kwargs["device"], dtype=kwargs["dtype"]) + std = torch.ones( + shape_per_frame, device=kwargs["device"], dtype=kwargs["dtype"]) + alpha = self.alpha + std_coeff_shared = (alpha**2) / (1 + alpha**2) + if self.shared_noise_across_chunks and hasattr(self, "e_shared"): + e_shared = self.e_shared + else: + e_shared = torch.normal(mean=zero_mean, std=sqrt( + std_coeff_shared)*std, generator=kwargs["generator"] if "generator" in kwargs else None) + if self.shared_noise_across_chunks: + self.e_shared = e_shared + + e_inds = [] + for frame in range(shape[frame_idx]): + std_coeff_ind = 1 / (1 + alpha**2) + e_ind = torch.normal( + mean=zero_mean, std=sqrt(std_coeff_ind)*std, generator=kwargs["generator"] if "generator" in kwargs else None) + e_inds.append(e_ind) + noise = torch.cat( + [e_shared + e_ind for e_ind in e_inds], dim=frame_idx) + + if "consistI2V" in self.mode and content is not None: + # if self.mode == "mixed_noise_consistI2V", we will use 'noise' from 'mixed_noise'. Otherwise, it is randn noise. + + if frame_idx == 1: + assert content.shape[0] == noise.shape[0] and content.shape[2:] == noise.shape[2:] + content = torch.concat([content, content[:, -1:].repeat( + 1, noise.shape[1]-content.shape[1], 1, 1, 1)], dim=1) + noise = rearrange(noise, "B F C W H -> (B C) F W H") + content = rearrange(content, "B F C W H -> (B C) F W H") + + else: + assert content.shape[:2] == noise.shape[: + 2] and content.shape[3:] == noise.shape[3:] + content = torch.concat( + [content, content[:, :, -1:].repeat(1, 1, noise.shape[2]-content.shape[2], 1, 1)], dim=2) + noise = rearrange(noise, "B C F W H -> (B C) F W H") + content = rearrange(content, "B C F W H -> (B C) F W H") + + # TODO implement DDPM_forward using diffusers framework + ''' + content_noisy = ddpm_forward( + content, noise, self.forward_steps) + ''' + + # A 2D low pass filter was given in the blog: + # see https://pytorch.org/blog/the-torch.fft-module-accelerated-fast-fourier-transforms-with-autograd-in-pyTorch/ + + # alternative + # do we have to specify more (s,dim,norm?) + noise_fft = fft.fftn(noise) + content_noisy_fft = fft.fftn(content_noisy) + + # shift low frequency parts to center + noise_fft_shifted = fft.fftshift(noise_fft) + content_noisy_fft_shifted = fft.fftshift(content_noisy_fft) + + # create gaussian low pass filter 'gaussian_low_pass_filter' (specify std!) + # mask out high frequencies using 'cutoff_frequence', something like gaussian_low_pass_filter[freq > cut_off_frequency] = 0.0 + # TODO define 'gaussian_low_pass_filter', apply frequency cutoff filter using self.cutoff_frequency. We need to apply fft.fftshift too probably. + # TODO what exactly is the "normalized space-time stop frequency" used for the cutoff? + + gaussian_3d = gaussian_smoothing_kernel(noise_fft.shape, kernel_size=( + noise_fft.shape[-3], noise_fft.shape[-2], noise_fft.shape[-1]), sigma=1, dim=3).to(noise.device) + + # define cutoff frequency around the kernel center + # TODO define center and cut off radius, e.g. somethink like gaussian_3d[...,:c_x-r_x,:c_y-r_y:,:c_z-r_z] = 0.0 and gaussian_3d[...,c_x+r_x:,c_y+r_y:,c_z+r_z:] = 0.0 + # as we have 16 x 32 x 32, center should be (7.5,15.5,15.5) + radius = self.radius + + # TODO we need to use rounding (ceil?) + + gaussian_3d[:center[0]-radius[0], :center[1] - + radius[1], :center[2]-radius[2]] = 0.0 + gaussian_3d[center[0]+radius[0]:, + center[1]+radius[1]:, center[2]+radius[2]:] = 0.0 + + noise_fft_shifted_hp = noise_fft_shifted * (1 - gaussian_3d) + content_noisy_fft_shifted_lp = content_noisy_fft_shifted * gaussian_3d + + noise = fft.ifftn(fft.ifftshift( + noise_fft_shifted_hp+content_noisy_fft_shifted_lp)) + if frame_idx == 1: + noise = rearrange( + noise, "(B C) F W H -> B F C W H", B=B) + else: + noise = rearrange( + noise, "(B C) F W H -> B C F W H", B=B) + + assert noise.shape == shape + return noise diff --git a/t2v_enhanced/model_func.py b/t2v_enhanced/model_func.py new file mode 100644 index 0000000000000000000000000000000000000000..0221678e35378ab1dca57476247e5792860eea63 --- /dev/null +++ b/t2v_enhanced/model_func.py @@ -0,0 +1,117 @@ +# General +import os +from os.path import join as opj +import datetime +import torch +from einops import rearrange, repeat + +# Utilities +from inference_utils import * + +from modelscope.outputs import OutputKeys +import imageio +from PIL import Image +import numpy as np + +import torch.nn.functional as F +import torchvision.transforms as transforms +from diffusers.utils import load_image +transform = transforms.Compose([ + transforms.PILToTensor() +]) + + +def ms_short_gen(prompt, ms_model, inference_generator, t=50, device="cuda"): + frames = ms_model(prompt, + num_inference_steps=t, + generator=inference_generator, + eta=1.0, + height=256, + width=256, + latents=None).frames + frames = torch.stack([torch.from_numpy(frame) for frame in frames]) + frames = frames.to(device).to(torch.float32) + return rearrange(frames[0], "F W H C -> F C W H") + +def ad_short_gen(prompt, ad_model, inference_generator, t=25, device="cuda"): + frames = ad_model(prompt, + negative_prompt="bad quality, worse quality", + num_frames=16, + num_inference_steps=t, + generator=inference_generator, + guidance_scale=7.5).frames[0] + frames = torch.stack([transform(frame) for frame in frames]) + frames = frames.to(device).to(torch.float32) + frames = F.interpolate(frames, size=256) + frames = frames/255.0 + return frames + +def sdxl_image_gen(prompt, sdxl_model): + image = sdxl_model(prompt=prompt).images[0] + return image + +def svd_short_gen(image, prompt, svd_model, sdxl_model, inference_generator, t=25, device="cuda"): + if image is None or image == "": + image = sdxl_image_gen(prompt, sdxl_model) + image = image.resize((576, 576)) + image = add_margin(image, 0, 224, 0, 224, (0, 0, 0)) + else: + image = load_image(image) + image = resize_and_keep(image) + image = center_crop(image) + image = add_margin(image, 0, 224, 0, 224, (0, 0, 0)) + + frames = svd_model(image, decode_chunk_size=8, generator=inference_generator).frames[0] + frames = torch.stack([transform(frame) for frame in frames]) + frames = frames.to(device).to(torch.float32) + frames = frames[:16,:,:,224:-224] + frames = F.interpolate(frames, size=256) + frames = frames/255.0 + return frames + + +def stream_long_gen(prompt, short_video, n_autoreg_gen, n_prompt, seed, t, image_guidance, result_file_stem, stream_cli, stream_model): + trainer = stream_cli.trainer + trainer.limit_predict_batches = 1 + trainer.predict_cfg = { + "predict_dir": stream_cli.config["result_fol"].as_posix(), + "result_file_stem": result_file_stem, + "prompt": prompt, + "video": short_video, + "seed": seed, + "num_inference_steps": t, + "guidance_scale": image_guidance, + 'n_autoregressive_generations': n_autoreg_gen, + } + + trainer.predict(model=stream_model, datamodule=stream_cli.datamodule) + + +def video2video(prompt, video, where_to_log, cfg_v2v, model_v2v, square=True): + downscale = cfg_v2v['downscale'] + upscale_size = cfg_v2v['upscale_size'] + pad = cfg_v2v['pad'] + + now = datetime.datetime.now() + name = prompt[:100].replace(" ", "_") + "_" + str(now.time()).replace(":", "_").replace(".", "_") + enhanced_video_mp4 = opj(where_to_log, name+"_enhanced.mp4") + + video_frames = imageio.mimread(video) + h, w, _ = video_frames[0].shape + + # Downscale video, then resize to fit the upscale size + video = [Image.fromarray(frame).resize((w//downscale, h//downscale)) for frame in video_frames] + video = [resize_to_fit(frame, upscale_size) for frame in video] + + if pad: + video = [pad_to_fit(frame, upscale_size) for frame in video] + # video = [np.array(frame) for frame in video] + + imageio.mimsave(opj(where_to_log, 'temp.mp4'), video, fps=8) + + p_input = { + 'video_path': opj(where_to_log, 'temp.mp4'), + 'text': prompt + } + output_video_path = model_v2v(p_input, output_video=enhanced_video_mp4)[OutputKeys.OUTPUT_VIDEO] + return enhanced_video_mp4 diff --git a/t2v_enhanced/model_init.py b/t2v_enhanced/model_init.py new file mode 100644 index 0000000000000000000000000000000000000000..3d299881153fdc5de52a6a3378eb3a771dddf473 --- /dev/null +++ b/t2v_enhanced/model_init.py @@ -0,0 +1,112 @@ +# General +import sys +from pathlib import Path +import torch +from pytorch_lightning import LightningDataModule + +# For Stage-1 +from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler +from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter +from diffusers import StableVideoDiffusionPipeline, AutoPipelineForText2Image + +# For Stage-2 +import tempfile +import yaml +from t2v_enhanced.model.video_ldm import VideoLDM +from model.callbacks import SaveConfigCallback +from inference_utils import legacy_transformation, remove_value, CustomCLI + +# For Stage-3 +from modelscope.pipelines import pipeline + + +# Initialize Stage-1 model1. +def init_modelscope(device="cuda"): + pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16") + # pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + # pipe.set_progress_bar_config(disable=True) + pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + pipe.enable_model_cpu_offload() + pipe.enable_vae_slicing() + pipe.set_progress_bar_config(disable=True) + return pipe.to(device) + +def init_zeroscope(device="cuda"): + pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16) + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe.enable_model_cpu_offload() + return pipe.to(device) + +def init_animatediff(device="cuda"): + adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16) + model_id = "SG161222/Realistic_Vision_V5.1_noVAE" + pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16) + scheduler = DDIMScheduler.from_pretrained( + model_id, + subfolder="scheduler", + clip_sample=False, + timestep_spacing="linspace", + beta_schedule="linear", + steps_offset=1, + ) + pipe.scheduler = scheduler + pipe.enable_vae_slicing() + pipe.enable_model_cpu_offload() + return pipe.to(device) + +def init_sdxl(device="cuda"): + pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True) + # pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True) + return pipe.to(device) + +def init_svd(device="cuda"): + pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16") + pipe.enable_model_cpu_offload() + return pipe.to(device) + + +# Initialize StreamingT2V model. +def init_streamingt2v_model(ckpt_file, result_fol): + config_file = "configs/text_to_video/config.yaml" + sys.argv = sys.argv[:1] + with tempfile.TemporaryDirectory() as tmpdirname: + storage_fol = Path(tmpdirname) + with open(config_file, "r") as yaml_handle: + yaml_obj = yaml.safe_load(yaml_handle) + + yaml_obj_orig_data_cfg = legacy_transformation(yaml_obj) + yaml_obj_orig_data_cfg = remove_value(yaml_obj_orig_data_cfg, "video_dataset") + + with open(storage_fol / 'config.yaml', 'w') as outfile: + yaml.dump(yaml_obj_orig_data_cfg, outfile, default_flow_style=False) + sys.argv.append("--config") + sys.argv.append((storage_fol / 'config.yaml').as_posix()) + sys.argv.append("--ckpt") + sys.argv.append(ckpt_file.as_posix()) + sys.argv.append("--result_fol") + sys.argv.append(result_fol.as_posix()) + sys.argv.append("--config") + sys.argv.append("configs/inference/inference_long_video.yaml") + sys.argv.append("--data.prompt_cfg.type=prompt") + sys.argv.append(f"--data.prompt_cfg.content='test prompt for initialization'") + sys.argv.append("--trainer.devices=1") + sys.argv.append("--trainer.num_nodes=1") + sys.argv.append(f"--model.inference_params.num_inference_steps=50") + sys.argv.append(f"--model.inference_params.n_autoregressive_generations=4") + sys.argv.append("--model.inference_params.concat_video=True") + sys.argv.append("--model.inference_params.result_formats=[eval_mp4]") + + cli = CustomCLI(VideoLDM, LightningDataModule, run=False, subclass_mode_data=True, + auto_configure_optimizers=False, parser_kwargs={"parser_mode": "omegaconf"}, save_config_callback=SaveConfigCallback, save_config_kwargs={"log_dir": result_fol, "overwrite": True}) + + model = cli.model + model.load_state_dict(torch.load( + cli.config["ckpt"].as_posix())["state_dict"]) + return cli, model + + +# Initialize Stage-3 model. +def init_v2v_model(cfg): + model_id = cfg['model_id'] + pipe_enhance = pipeline(task="video-to-video", model=model_id, model_revision='v1.1.0', device='cuda') + return pipe_enhance diff --git a/t2v_enhanced/utils/conversions.py b/t2v_enhanced/utils/conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..aeb8ba1b7ebb402aa9fdf77c7bf02f68f0d951ce --- /dev/null +++ b/t2v_enhanced/utils/conversions.py @@ -0,0 +1,48 @@ +from pathlib import Path +import PIL +from PIL import Image +import numpy as np +from dataclasses import dataclass + +# TODO add register new converter so that it is accessible via converters.to_x + +def ensure_class(func, params): + def func_wrapper(function): + def wrapper(self=None, *args, **kwargs): + for key in kwargs: + if key in params: + kwargs[key] = func(kwargs[key]) + if self is not None: + return function(self, *args, **kwargs) + else: + return function(*args, **kwargs) + + return wrapper + + return func_wrapper + + +def as_PIL(img): + if not isinstance(img, PIL.Image.Image): + if isinstance(img, Path): + img = img.as_posix() + if isinstance(img, str): + img = Image.open(img) + elif isinstance(img, np.ndarray): + img = Image.fromarray(img) + + else: + raise NotImplementedError + return img + + +def to_ndarray(input): + if not isinstance(input, np.ndarray): + input = np.array(input) + return input + + +def to_Path(input): + if not isinstance(input, Path): + input = Path(input) + return input diff --git a/t2v_enhanced/utils/iimage.py b/t2v_enhanced/utils/iimage.py new file mode 100644 index 0000000000000000000000000000000000000000..17be19ab7631f12bc1422ea037270a292ae9a79d --- /dev/null +++ b/t2v_enhanced/utils/iimage.py @@ -0,0 +1,517 @@ +import io +import math +import os +import PIL.Image +import numpy as np +import imageio.v3 as iio +import warnings + + +import torch +import torchvision.transforms.functional as TF +from scipy.ndimage import binary_dilation, binary_erosion +import cv2 + +import re + +import matplotlib.pyplot as plt +from matplotlib import animation +from IPython.display import HTML, Image, display + + +IMG_THUMBSIZE = None + +def torch2np(x, vmin=-1, vmax=1): + if x.ndim != 4: + # raise Exception("Please only use (B,C,H,W) torch tensors!") + warnings.warn( + "Warning! Shape of the image was not provided in (B,C,H,W) format, the shape was inferred automatically!") + if x.ndim == 3: + x = x[None] + if x.ndim == 2: + x = x[None, None] + x = x.detach().cpu().float() + if x.dtype == torch.uint8: + return x.numpy().astype(np.uint8) + elif vmin is not None and vmax is not None: + x = (255 * (x.clip(vmin, vmax) - vmin) / (vmax - vmin)) + x = x.permute(0, 2, 3, 1).to(torch.uint8) + return x.numpy() + else: + raise NotImplementedError() + + +class IImage: + ''' + Generic media storage. Can store both images and videos. + Stores data as a numpy array by default. + Can be viewed in a jupyter notebook. + ''' + @staticmethod + def open(path): + + iio_obj = iio.imopen(path, 'r') + data = iio_obj.read() + try: + # .properties() does not work for images but for gif files + if not iio_obj.properties().is_batch: + data = data[None] + except AttributeError as e: + # this one works for gif files + if not "duration" in iio_obj.metadata(): + data = data[None] + if data.ndim == 3: + data = data[..., None] + image = IImage(data) + image.link = os.path.abspath(path) + return image + + @staticmethod + def normalized(x, dims=[-1, -2]): + x = (x - x.amin(dims, True)) / \ + (x.amax(dims, True) - x.amin(dims, True)) + return IImage(x, 0) + + def numpy(self): return self.data + + def torch(self, vmin=-1, vmax=1): + if self.data.ndim == 3: + data = self.data.transpose(2, 0, 1) / 255. + else: + data = self.data.transpose(0, 3, 1, 2) / 255. + return vmin + torch.from_numpy(data).float().to(self.device) * (vmax - vmin) + + def cuda(self): + self.device = 'cuda' + return self + + def cpu(self): + self.device = 'cpu' + return self + + def pil(self): + ans = [] + for x in self.data: + if x.shape[-1] == 1: + x = x[..., 0] + + ans.append(PIL.Image.fromarray(x)) + if len(ans) == 1: + return ans[0] + return ans + + def is_iimage(self): + return True + + @property + def shape(self): return self.data.shape + @property + def size(self): return (self.data.shape[-2], self.data.shape[-3]) + + def setFps(self, fps): + self.fps = fps + self.generate_display() + return self + + def __init__(self, x, vmin=-1, vmax=1, fps=None): + if isinstance(x, PIL.Image.Image): + self.data = np.array(x) + if self.data.ndim == 2: + self.data = self.data[..., None] # (H,W,C) + self.data = self.data[None] # (B,H,W,C) + elif isinstance(x, IImage): + self.data = x.data.copy() # Simple Copy + elif isinstance(x, np.ndarray): + self.data = x.copy().astype(np.uint8) + if self.data.ndim == 2: + self.data = self.data[None, ..., None] + if self.data.ndim == 3: + warnings.warn( + "Inferred dimensions for a 3D array as (H,W,C), but could've been (B,H,W)") + self.data = self.data[None] + elif isinstance(x, torch.Tensor): + self.data = torch2np(x, vmin, vmax) + self.display_str = None + self.device = 'cpu' + self.fps = fps if fps is not None else ( + 1 if len(self.data) < 10 else 30) + self.link = None + + def generate_display(self): + if IMG_THUMBSIZE is not None: + if self.size[1] < self.size[0]: + thumb = self.resize( + (self.size[1]*IMG_THUMBSIZE//self.size[0], IMG_THUMBSIZE)) + else: + thumb = self.resize( + (IMG_THUMBSIZE, self.size[0]*IMG_THUMBSIZE//self.size[1])) + else: + thumb = self + if self.is_video(): + self.anim = Animation(thumb.data, fps=self.fps) + self.anim.render() + self.display_str = self.anim.anim_str + else: + b = io.BytesIO() + data = thumb.data[0] + if data.shape[-1] == 1: + data = data[..., 0] + PIL.Image.fromarray(data).save(b, "PNG") + self.display_str = b.getvalue() + return self.display_str + + def resize(self, size, *args, **kwargs): + if size is None: + return self + use_small_edge_when_int = kwargs.pop('use_small_edge_when_int', False) + + # Backward compatibility + resample = kwargs.pop('filter', PIL.Image.BICUBIC) + resample = kwargs.pop('resample', resample) + + if isinstance(size, int): + if use_small_edge_when_int: + h, w = self.data.shape[1:3] + aspect_ratio = h / w + size = (max(size, int(size * aspect_ratio)), + max(size, int(size / aspect_ratio))) + else: + h, w = self.data.shape[1:3] + aspect_ratio = h / w + size = (min(size, int(size * aspect_ratio)), + min(size, int(size / aspect_ratio))) + + if self.size == size[::-1]: + return self + return stack([IImage(x.pil().resize(size[::-1], *args, resample=resample, **kwargs)) for x in self]) + + def pad(self, padding, *args, **kwargs): + return IImage(TF.pad(self.torch(0), padding=padding, *args, **kwargs), 0) + + def padx(self, multiplier, *args, **kwargs): + size = np.array(self.size) + padding = np.concatenate( + [[0, 0], np.ceil(size / multiplier).astype(int) * multiplier - size]) + return self.pad(list(padding), *args, **kwargs) + + def pad2wh(self, w=0, h=0, **kwargs): + cw, ch = self.size + return self.pad([0, 0, max(0, w - cw), max(0, h-ch)], **kwargs) + + def pad2square(self, *args, **kwargs): + if self.size[0] > self.size[1]: + dx = self.size[0] - self.size[1] + return self.pad([0, dx//2, 0, dx-dx//2], *args, **kwargs) + elif self.size[0] < self.size[1]: + dx = self.size[1] - self.size[0] + return self.pad([dx//2, 0, dx-dx//2, 0], *args, **kwargs) + return self + + def crop2square(self, *args, **kwargs): + if self.size[0] > self.size[1]: + dx = self.size[0] - self.size[1] + return self.crop([dx//2, 0, self.size[1], self.size[1]], *args, **kwargs) + elif self.size[0] < self.size[1]: + dx = self.size[1] - self.size[0] + return self.crop([0, dx//2, self.size[0], self.size[0]], *args, **kwargs) + return self + + def alpha(self): + return IImage(self.data[..., -1, None], fps=self.fps) + + def rgb(self): + return IImage(self.pil().convert('RGB'), fps=self.fps) + + def png(self): + return IImage(np.concatenate([self.data, 255 * np.ones_like(self.data)[..., :1]], -1)) + + def grid(self, nrows=None, ncols=None): + if nrows is not None: + ncols = math.ceil(self.data.shape[0] / nrows) + elif ncols is not None: + nrows = math.ceil(self.data.shape[0] / ncols) + else: + warnings.warn( + "No dimensions specified, creating a grid with 5 columns (default)") + ncols = 5 + nrows = math.ceil(self.data.shape[0] / ncols) + + pad = nrows * ncols - self.data.shape[0] + data = np.pad(self.data, ((0, pad), (0, 0), (0, 0), (0, 0))) + rows = [np.concatenate(x, 1, dtype=np.uint8) + for x in np.array_split(data, nrows)] + return IImage(np.concatenate(rows, 0, dtype=np.uint8)[None]) + + def hstack(self): + return IImage(np.concatenate(self.data, 1, dtype=np.uint8)[None]) + + def vstack(self): + return IImage(np.concatenate(self.data, 0, dtype=np.uint8)[None]) + + def vsplit(self, number_of_splits): + return IImage(np.concatenate(np.split(self.data, number_of_splits, 1))) + + def hsplit(self, number_of_splits): + return IImage(np.concatenate(np.split(self.data, number_of_splits, 2))) + + def heatmap(self, resize=None, cmap=cv2.COLORMAP_JET): + data = np.stack([cv2.cvtColor(cv2.applyColorMap( + x, cmap), cv2.COLOR_BGR2RGB) for x in self.data]) + return IImage(data).resize(resize, use_small_edge_when_int=True) + + def display(self): + try: + display(self) + except: + print("No display") + return self + + def dilate(self, iterations=1, *args, **kwargs): + if iterations == 0: + return IImage(self.data) + return IImage((binary_dilation(self.data, iterations=iterations, *args, *kwargs)*255.).astype(np.uint8)) + + def erode(self, iterations=1, *args, **kwargs): + return IImage((binary_erosion(self.data, iterations=iterations, *args, *kwargs)*255.).astype(np.uint8)) + + def hull(self): + convex_hulls = [] + for frame in self.data: + contours, hierarchy = cv2.findContours( + frame, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + contours = [x.astype(np.int32) for x in contours] + mask_contours = [cv2.convexHull(np.concatenate(contours))] + canvas = np.zeros(self.data[0].shape, np.uint8) + convex_hull = cv2.drawContours( + canvas, mask_contours, -1, (255, 0, 0), -1) + convex_hulls.append(convex_hull) + return IImage(np.array(convex_hulls)) + + def is_video(self): + return self.data.shape[0] > 1 + + def __getitem__(self, idx): + return IImage(self.data[None, idx], fps=self.fps) + # if self.is_video(): return IImage(self.data[idx], fps = self.fps) + # return self + + def _repr_png_(self): + if self.is_video(): + return None + if self.display_str is None: + self.generate_display() + return self.display_str + + def _repr_html_(self): + if not self.is_video(): + return None + if self.display_str is None: + self.generate_display() + return self.display_str + + def save(self, path): + _, ext = os.path.splitext(path) + if self.is_video(): + # if ext in ['.jpg', '.png']: + if self.display_str is None: + self.generate_display() + if ext == ".apng": + self.anim.anim_obj.save(path, writer="pillow") + else: + self.anim.anim_obj.save(path) + else: + data = self.data if self.data.ndim == 3 else self.data[0] + if data.shape[-1] == 1: + data = data[:, :, 0] + PIL.Image.fromarray(data).save(path) + return self + + def write(self, text, center=(0, 25), font_scale=0.8, color=(255, 255, 255), thickness=2): + if not isinstance(text, list): + text = [text for _ in self.data] + data = np.stack([cv2.putText(x.copy(), t, center, cv2.FONT_HERSHEY_COMPLEX, + font_scale, color, thickness) for x, t in zip(self.data, text)]) + return IImage(data) + + def append_text(self, text, padding, font_scale=0.8, color=(255, 255, 255), thickness=2, scale_factor=0.9, center=(0, 0), fill=0): + + assert np.count_nonzero(padding) == 1 + axis_padding = np.nonzero(padding)[0][0] + scale_padding = padding[axis_padding] + + y_0 = 0 + x_0 = 0 + if axis_padding == 0: + width = scale_padding + y_max = self.shape[1] + elif axis_padding == 1: + width = self.shape[2] + y_max = scale_padding + elif axis_padding == 2: + x_0 = self.shape[2] + width = scale_padding + y_max = self.shape[1] + elif axis_padding == 3: + width = self.shape[2] + y_0 = self.shape[1] + y_max = self.shape[1]+scale_padding + + width -= center[0] + x_0 += center[0] + y_0 += center[1] + + self = self.pad(padding, fill=fill) + + def wrap_text(text, width, _font_scale): + allowed_seperator = ' |-|_|/|\n' + words = re.split(allowed_seperator, text) + # words = text.split() + lines = [] + current_line = words[0] + sep_list = [] + start_idx = 0 + for start_word in words[:-1]: + pos = text.find(start_word, start_idx) + pos += len(start_word) + sep_list.append(text[pos]) + start_idx = pos+1 + + for word, separator in zip(words[1:], sep_list): + if cv2.getTextSize(current_line + separator + word, cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][0] <= width: + current_line += separator + word + else: + if cv2.getTextSize(current_line, cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][0] <= width: + lines.append(current_line) + current_line = word + else: + return [] + + if cv2.getTextSize(current_line, cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][0] <= width: + lines.append(current_line) + else: + return [] + return lines + + def wrap_text_and_scale(text, width, _font_scale, y_0, y_max): + height = y_max+1 + while height > y_max: + text_lines = wrap_text(text, width, _font_scale) + if len(text) > 0 and len(text_lines) == 0: + + height = y_max+1 + else: + line_height = cv2.getTextSize( + text_lines[0], cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][1] + height = line_height * len(text_lines) + y_0 + + # scale font if out of frame + if height > y_max: + _font_scale = _font_scale * scale_factor + + return text_lines, line_height, _font_scale + + result = [] + if not isinstance(text, list): + text = [text for _ in self.data] + else: + assert len(text) == len(self.data) + + for x, t in zip(self.data, text): + x = x.copy() + text_lines, line_height, _font_scale = wrap_text_and_scale( + t, width, font_scale, y_0, y_max) + y = line_height + for line in text_lines: + x = cv2.putText( + x, line, (x_0, y_0+y), cv2.FONT_HERSHEY_COMPLEX, _font_scale, color, thickness) + y += line_height + result.append(x) + data = np.stack(result) + + return IImage(data) + + # ========== OPERATORS ============= + + def __or__(self, other): + # TODO: fix for variable sizes + return IImage(np.concatenate([self.data, other.data], 2)) + + def __truediv__(self, other): + # TODO: fix for variable sizes + return IImage(np.concatenate([self.data, other.data], 1)) + + def __and__(self, other): + return IImage(np.concatenate([self.data, other.data], 0)) + + def __add__(self, other): + return IImage(0.5 * self.data + 0.5 * other.data) + + def __mul__(self, other): + if isinstance(other, IImage): + return IImage(self.data / 255. * other.data) + return IImage(self.data * other / 255.) + + def __xor__(self, other): + return IImage(0.5 * self.data + 0.5 * other.data + 0.5 * self.data * (other.data.sum(-1, keepdims=True) == 0)) + + def __invert__(self): + return IImage(255 - self.data) + __rmul__ = __mul__ + + def bbox(self): + return [cv2.boundingRect(x) for x in self.data] + + def fill_bbox(self, bbox_list, fill=255): + data = self.data.copy() + for bbox in bbox_list: + x, y, w, h = bbox + data[:, y:y+h, x:x+w, :] = fill + return IImage(data) + + def crop(self, bbox): + assert len(bbox) in [2, 4] + if len(bbox) == 2: + x, y = 0, 0 + w, h = bbox + elif len(bbox) == 4: + x, y, w, h = bbox + return IImage(self.data[:, y:y+h, x:x+w, :]) + +def stack(images, axis = 0): + return IImage(np.concatenate([x.data for x in images], axis)) + +class Animation: + JS = 0 + HTML = 1 + ANIMATION_MODE = HTML + def __init__(self, frames, fps = 30): + """_summary_ + + Args: + frames (np.ndarray): _description_ + """ + self.frames = frames + self.fps = fps + self.anim_obj = None + self.anim_str = None + def render(self): + size = (self.frames.shape[2],self.frames.shape[1]) + self.fig = plt.figure(figsize = size, dpi = 1) + plt.axis('off') + img = plt.imshow(self.frames[0], cmap = 'gray') + self.fig.subplots_adjust(0,0,1,1) + self.anim_obj = animation.FuncAnimation( + self.fig, + lambda i: img.set_data(self.frames[i,:,:,:]), + frames=self.frames.shape[0], + interval = 1000 / self.fps + ) + plt.close() + if Animation.ANIMATION_MODE == Animation.HTML: + self.anim_str = self.anim_obj.to_html5_video() + elif Animation.ANIMATION_MODE == Animation.JS: + self.anim_str = self.anim_obj.to_jshtml() + return self.anim_obj + def _repr_html_(self): + if self.anim_obj is None: self.render() + return self.anim_str \ No newline at end of file diff --git a/t2v_enhanced/utils/image_converter.py b/t2v_enhanced/utils/image_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..6891da9ed39bacea8699599f76037727e07a5156 --- /dev/null +++ b/t2v_enhanced/utils/image_converter.py @@ -0,0 +1,45 @@ +import cv2 +import numpy as np +from albumentations.augmentations.geometric import functional as F +from albumentations.core.transforms_interface import DualTransform + +__all__ = ["ProportionalMinScale"] + + +class ProportionalMinScale(DualTransform): + + def __init__( + self, + width: int, + height: int, + interpolation: int = cv2.INTER_LINEAR, + always_apply: bool = False, + p: float = 1, + ): + super(ProportionalMinScale, self).__init__(always_apply, p) + self.width = width + self.height = height + + def apply( + self, img: np.ndarray, width: int = 256, height: int = 256, interpolation: int = cv2.INTER_LINEAR, **params): + h_img, w_img, _ = img.shape + + min_side = np.min([h_img, w_img]) + + if (height/h_img)*w_img >= width: + if h_img == min_side: + return F.smallest_max_size(img, max_size=height, interpolation=interpolation) + else: + return F.longest_max_size(img, max_size=height, interpolation=interpolation) + if (width/w_img)*h_img >= height: + if w_img == min_side: + return F.smallest_max_size(img, max_size=width, interpolation=interpolation) + else: + return F.longest_max_size(img, max_size=width, interpolation=interpolation) + return F.longest_max_size(img, max_size=width, interpolation=interpolation) + + def get_params(self): + return {"width": self.width, "height": self.height} + + def get_transform_init_args_names(self): + return ("width", "height", "intepolation") diff --git a/t2v_enhanced/utils/object_loader.py b/t2v_enhanced/utils/object_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..a121cea14154c4e61c39e6bd961737e735f75879 --- /dev/null +++ b/t2v_enhanced/utils/object_loader.py @@ -0,0 +1,26 @@ +import importlib +from functools import partialmethod + + +def instantiate_object(cls_path: str, *args, **kwargs): + class_ = get_class(cls_path, *args, **kwargs) + obj = class_() + return obj + + +def get_class(cls_path: str, *args, **kwargs): + module_name = ".".join(cls_path.split(".")[:-1]) + module = importlib.import_module(module_name) + + class_ = getattr(module, cls_path.split(".")[-1]) + class_.__init__ = partialmethod(class_.__init__, *args, **kwargs) + return class_ + + +if __name__ == "__main__": + + class_ = get_class( + "diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler") + scheduler = class_.from_config("stabilityai/stable-diffusion-2-1", + subfolder="scheduler") + print(scheduler) diff --git a/t2v_enhanced/utils/video_utils.py b/t2v_enhanced/utils/video_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c10db14ffe831d77c22eb22be10635e625813871 --- /dev/null +++ b/t2v_enhanced/utils/video_utils.py @@ -0,0 +1,376 @@ +import os +import subprocess +import tempfile +from pathlib import Path +from typing import Union +import shutil + +import cv2 +import imageio +import numpy as np +import torch +import torchvision +from decord import VideoReader, cpu +from einops import rearrange, repeat +from t2v_enhanced.utils.iimage import IImage +from PIL import Image, ImageDraw, ImageFont +from torchvision.utils import save_image + +channel_first = 0 +channel_last = -1 + + +def video_naming(prompt, extension, batch_idx, idx): + prompt_identifier = prompt.replace(" ", "_") + prompt_identifier = prompt_identifier.replace("/", "_") + if len(prompt_identifier) > 40: + prompt_identifier = prompt_identifier[:40] + filename = f"{batch_idx:04d}_{idx:04d}_{prompt_identifier}.{extension}" + return filename + + +def video_naming_chunk(prompt, extension, batch_idx, idx, chunk_idx): + prompt_identifier = prompt.replace(" ", "_") + prompt_identifier = prompt_identifier.replace("/", "_") + if len(prompt_identifier) > 40: + prompt_identifier = prompt_identifier[:40] + filename = f"{batch_idx}_{idx}_{chunk_idx}_{prompt_identifier}.{extension}" + return filename + + +class ResultProcessor(): + + def __init__(self, fps: int, n_frames: int, logger=None) -> None: + self.fps = fps + self.logger = logger + self.n_frames = n_frames + + def set_logger(self, logger): + self.logger = logger + + def _create_video(self, video, prompt, filename: Union[str, Path], append_video: torch.FloatTensor = None, input_flow=None): + + if video.ndim == 5: + # can be batches if we provide list of filenames + assert video.shape[0] == 1 + video = video[0] + + if video.shape[0] == 3 and video.shape[1] == self.n_frames: + video = rearrange(video, "C F W H -> F C W H") + assert video.shape[1] == 3, f"Wrong video format. Got {video.shape}" + if isinstance(filename, Path): + filename = filename.as_posix() + # assert video.max() <= 1 and video.min() >= 0 + assert video.max() <=1.1 and video.min() >= -0.1, f"video has unexpected range: [{video.min()}, {video.max()}]" + vid_obj = IImage(video, vmin=0, vmax=1) + + if prompt is not None: + vid_obj = vid_obj.append_text(prompt, padding=(0, 50, 0, 0)) + + if append_video is not None: + if append_video.ndim == 5: + assert append_video.shape[0] == 1 + append_video = append_video[0] + if append_video.shape[0] < video.shape[0]: + append_video = torch.concat([append_video, + repeat(append_video[-1, None], "F C W H -> (rep F) C W H", rep=video.shape[0]-append_video.shape[0])], dim=0) + if append_video.ndim == 3 and video.ndim == 4: + append_video = repeat( + append_video, "C W H -> F C W H", F=video.shape[0]) + append_video = IImage(append_video, vmin=-1, vmax=1) + if prompt is not None: + append_video = append_video.append_text( + "input_frame", padding=(0, 50, 0, 0)) + vid_obj = vid_obj | append_video + vid_obj = vid_obj.setFps(self.fps) + vid_obj.save(filename) + + def _create_prompt_file(self, prompt, filename, video_path: str = None): + filename = Path(filename) + filename = filename.parent / (filename.stem+".txt") + + with open(filename.as_posix(), "w") as file_writer: + file_writer.write(prompt) + file_writer.write("\n") + if video_path is not None: + file_writer.write(video_path) + else: + file_writer.write(" no_source") + + def log_video(self, video: torch.FloatTensor, prompt: str, video_id: str, log_folder: str, input_flow=None, video_path_input: str = None, extension: str = "gif", prompt_on_vid: bool = True, append_video: torch.FloatTensor = None): + + with tempfile.TemporaryDirectory() as tmpdirname: + storage_fol = Path(tmpdirname) + filename = f"{video_id}.{extension}".replace("/", "_") + vid_filename = storage_fol / filename + self._create_video( + video, prompt if prompt_on_vid else None, vid_filename, append_video, input_flow=input_flow) + + prompt_file = storage_fol / f"{video_id}.txt" + self._create_prompt_file(prompt, prompt_file, video_path_input) + + if self.logger.experiment.__class__.__name__ == "_DummyExperiment": + run_fol = Path(self.logger.save_dir) / \ + self.logger.experiment_id / self.logger.run_id / "artifacts" / log_folder + if not run_fol.exists(): + run_fol.mkdir(parents=True, exist_ok=True) + shutil.copy(prompt_file.as_posix(), + (run_fol / f"{video_id}.txt").as_posix()) + shutil.copy(vid_filename, + (run_fol / filename).as_posix()) + else: + self.logger.experiment.log_artifact( + self.logger.run_id, prompt_file.as_posix(), log_folder) + self.logger.experiment.log_artifact( + self.logger.run_id, vid_filename, log_folder) + + def save_to_file(self, video: torch.FloatTensor, prompt: str, video_filename: Union[str, Path], input_flow=None, conditional_video_path: str = None, prompt_on_vid: bool = True, conditional_video: torch.FloatTensor = None): + self._create_video( + video, prompt if prompt_on_vid else None, video_filename, conditional_video, input_flow=input_flow) + self._create_prompt_file( + prompt, video_filename, conditional_video_path) + + +def add_text_to_image(image_array, text, position, font_size, text_color, font_path=None): + + # Convert the NumPy array to PIL Image + image_pil = Image.fromarray(image_array) + + # Create a drawing object + draw = ImageDraw.Draw(image_pil) + + if font_path is not None: + font = ImageFont.truetype(font_path, font_size) + else: + try: + # Load the font + font = ImageFont.truetype( + "/usr/share/fonts/truetype/liberation/LiberationMono-Regular.ttf", font_size) + except: + font = ImageFont.load_default() + + # Draw the text on the image + draw.text(position, text, font=font, fill=text_color) + + # Convert the PIL Image back to NumPy array + modified_image_array = np.array(image_pil) + + return modified_image_array + + +def add_text_to_video(video_path, prompt): + + outputs_with_overlay = [] + with open(video_path, "rb") as f: + vr = VideoReader(f, ctx=cpu(0)) + + for i in range(len(vr)): + frame = vr[i] + frame = add_text_to_image(frame, prompt, position=( + 10, 10), font_size=15, text_color=(255, 0, 0),) + outputs_with_overlay.append(frame) + outputs = outputs_with_overlay + video_path = video_path.replace("mp4", "gif") + imageio.mimsave(video_path, outputs, duration=100, loop=0) + + +def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=30, prompt=None): + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).numpy().astype(np.uint8) + outputs.append(x) + + os.makedirs(os.path.dirname(path), exist_ok=True) + + if prompt is not None: + outputs_with_overlay = [] + for frame in outputs: + frame_out = add_text_to_image( + frame, prompt, position=(10, 10), font_size=10, text_color=(255, 0, 0),) + outputs_with_overlay.append(frame_out) + outputs = outputs_with_overlay + imageio.mimsave(path, outputs, duration=round(1/fps*1000), loop=0) + # iio.imwrite(path, outputs) + # optimize(path) + + +def set_channel_pos(data, shape_dict, channel_pos): + + assert data.ndim == 5 or data.ndim == 4 + batch_dim = data.shape[0] + frame_dim = shape_dict["frame_dim"] + channel_dim = shape_dict["channel_dim"] + width_dim = shape_dict["width_dim"] + height_dim = shape_dict["height_dim"] + + assert batch_dim != frame_dim + assert channel_dim != frame_dim + assert channel_dim != batch_dim + + video_shape = list(data.shape) + batch_pos = video_shape.index(batch_dim) + + channel_pos = video_shape.index(channel_dim) + w_pos = video_shape.index(width_dim) + h_pos = video_shape.index(height_dim) + if w_pos == h_pos: + video_shape[w_pos] = -1 + h_pos = video_shape.index(height_dim) + pattern_order = {} + pattern_order[batch_pos] = "B" + pattern_order[channel_pos] = "C" + + pattern_order[w_pos] = "W" + pattern_order[h_pos] = "H" + + if data.ndim == 5: + frame_pos = video_shape.index(frame_dim) + pattern_order[frame_pos] = "F" + if channel_pos == channel_first: + pattern = " -> B F C W H" + else: + pattern = " -> B F W H C" + else: + if channel_pos == channel_first: + pattern = " -> B C W H" + else: + pattern = " -> B W H C" + pattern_input = [pattern_order[idx] for idx in range(data.ndim)] + pattern_input = " ".join(pattern_input) + pattern = pattern_input + pattern + data = rearrange(data, pattern) + + +def merge_first_two_dimensions(tensor): + dims = tensor.ndim + letters = [] + for letter_idx in range(dims-2): + letters.append(chr(letter_idx+67)) + latters_pattern = " ".join(letters) + tensor = rearrange(tensor, "A B "+latters_pattern + + " -> (A B) "+latters_pattern) + # TODO merging first two dimensions might be easier with reshape so no need to create letters + # should be 'tensor.view(*tensor.shape[:2], -1)' + return tensor + + +def apply_spatial_function_to_video_tensor(video, shape, func): + # TODO detect batch, frame, channel, width, and height + + assert video.ndim == 5 + batch_dim = shape["batch_dim"] + frame_dim = shape["frame_dim"] + channel_dim = shape["channel_dim"] + width_dim = shape["width_dim"] + height_dim = shape["height_dim"] + + assert batch_dim != frame_dim + assert channel_dim != frame_dim + assert channel_dim != batch_dim + + video_shape = list(video.shape) + batch_pos = video_shape.index(batch_dim) + frame_pos = video_shape.index(frame_dim) + channel_pos = video_shape.index(channel_dim) + w_pos = video_shape.index(width_dim) + h_pos = video_shape.index(height_dim) + if w_pos == h_pos: + video_shape[w_pos] = -1 + h_pos = video_shape.index(height_dim) + pattern_order = {} + pattern_order[batch_pos] = "B" + pattern_order[channel_pos] = "C" + pattern_order[frame_pos] = "F" + pattern_order[w_pos] = "W" + pattern_order[h_pos] = "H" + pattern_order = sorted(pattern_order.items(), key=lambda x: x[1]) + pattern_order = [x[0] for x in pattern_order] + input_pattern = " ".join(pattern_order) + video = rearrange(video, input_pattern+" -> (B F) C W H") + + video = func(video) + video = rearrange(video, "(B F) C W H -> "+input_pattern, F=frame_dim) + return video + + +def dump_frames(videos, as_mosaik, storage_fol, save_image_kwargs): + + # assume videos is in format B F C H W, range [0,1] + num_frames = videos.shape[1] + num_videos = videos.shape[0] + + if videos.shape[2] != 3 and videos.shape[-1] == 3: + videos = rearrange(videos, "B F W H C -> B F C W H") + + frame_counter = 0 + if not isinstance(storage_fol, Path): + storage_fol = Path(storage_fol) + + for frame_idx in range(num_frames): + print(f" Creating frame {frame_idx}") + batch_frame = videos[:, frame_idx, ...] + + if as_mosaik: + filename = storage_fol / f"frame_{frame_counter:03d}.png" + save_image(batch_frame, fp=filename.as_posix(), + **save_image_kwargs) + frame_counter += 1 + else: + for video_idx in range(num_videos): + frame = batch_frame[video_idx] + + filename = storage_fol / f"frame_{frame_counter:03d}.png" + save_image(frame, fp=filename.as_posix(), + **save_image_kwargs) + frame_counter += 1 + + +def gif_from_videos(videos): + + assert videos.dim() == 5 + assert videos.min() >= 0 + assert videos.max() <= 1 + gif_file = Path("tmp.gif").absolute() + + with tempfile.TemporaryDirectory() as tmpdirname: + storage_fol = Path(tmpdirname) + nrows = min(4, videos.shape[0]) + dump_frames( + videos=videos, storage_fol=storage_fol, as_mosaik=True, save_image_kwargs={"nrow": nrows}) + cmd = f"ffmpeg -y -f image2 -framerate 4 -i {storage_fol / 'frame_%03d.png'} {gif_file.as_posix()}" + subprocess.check_call( + cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT) + return gif_file + + + +def add_margin(pil_img, top, right, bottom, left, color): + width, height = pil_img.size + new_width = width + right + left + new_height = height + top + bottom + result = Image.new(pil_img.mode, (new_width, new_height), color) + result.paste(pil_img, (left, top)) + return result + +def resize_to_fit(image, size): + W, H = size + w, h = image.size + if H / h > W / w: + H_ = int(h * W / w) + W_ = W + else: + W_ = int(w * H / h) + H_ = H + return image.resize((W_, H_)) + +def pad_to_fit(image, size): + W, H = size + w, h = image.size + pad_h = (H - h) // 2 + pad_w = (W - w) // 2 + return add_margin(image, pad_h, pad_w, pad_h, pad_w, (0, 0, 0)) \ No newline at end of file diff --git a/t2v_enhanced/utils/visualisation.py b/t2v_enhanced/utils/visualisation.py new file mode 100644 index 0000000000000000000000000000000000000000..1a749cb955f27a029645ef2c4f2a2f2e5f199317 --- /dev/null +++ b/t2v_enhanced/utils/visualisation.py @@ -0,0 +1,139 @@ +from collections import defaultdict +import torch +from torchvision.utils import make_grid +from torchvision.transforms import ToPILImage +import numpy as np +from PIL import Image, ImageDraw, ImageFont +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +from matplotlib.colors import Normalize +from matplotlib import cm + +def pil_concat_v(images): + width = images[0].width + height = sum([image.height for image in images]) + dst = Image.new('RGB', (width, height)) + h = 0 + for image_idx, image in enumerate(images): + dst.paste(image, (0, h)) + h += image.height + return dst + +def pil_concat_h(images): + width = sum([image.width for image in images]) + height = images[0].height + dst = Image.new('RGB', (width, height)) + w = 0 + for image_idx, image in enumerate(images): + dst.paste(image, (w, 0)) + w += image.width + return dst + +def add_label(image, text, fontsize=12): + dst = Image.new('RGB', (image.width, image.height + fontsize*3)) + dst.paste(image, (0, 0)) + draw = ImageDraw.Draw(dst) + font = ImageFont.truetype("../misc/fonts/OpenSans.ttf", fontsize) + draw.text((fontsize, image.height + fontsize),text,(255,255,255),font=font) + return dst + +def pil_concat(images, labels=None, col=8, fontsize=12): + col = min(col, len(images)) + if labels is not None: + labeled_images = [add_label(image, labels[image_idx], fontsize=fontsize) for image_idx, image in enumerate(images)] + else: + labeled_images = images + labeled_images_rows = [] + for row_idx in range(int(np.ceil(len(labeled_images) / col))): + labeled_images_rows.append(pil_concat_h(labeled_images[col*row_idx:col*(row_idx+1)])) + return pil_concat_v(labeled_images_rows) + + +def draw_panoptic_segmentation(model, segmentation, segments_info): + # get the used color map + viridis = cm.get_cmap('viridis') + norm = Normalize(vmin=segmentation.min().item(), vmax=segmentation.max().item()) + fig, ax = plt.subplots() + ax.imshow(segmentation, cmap=viridis, norm=norm) + instances_counter = defaultdict(int) + handles = [] + for segment in segments_info: + segment_id = segment['id'] + segment_label_id = segment['label_id'] + segment_label = model.config.id2label[segment_label_id] + label = f"{segment_label}-{instances_counter[segment_label_id]}" + instances_counter[segment_label_id] += 1 + color = viridis(norm(segment_id)) + handles.append(mpatches.Patch(color=color, label=label)) + ax.legend(handles=handles) + + + +rescale_ = lambda x: (x + 1.) / 2. + +def pil_grid_display(x, mask=None, nrow=4, rescale=True): + if rescale: + x = rescale_(x) + if mask is not None: + mask = mask_to_3_channel(mask) + x = torch.concat([mask, x]) + grid = make_grid(torch.clip(x, 0, 1), nrow=nrow) + return ToPILImage()(grid) + +def pil_display(x, rescale=True): + if rescale: + x = rescale_(x) + image = torch.clip(rescale_(x), 0, 1) + return ToPILImage()(image) + +def mask_to_3_channel(mask): + if mask.dim() == 3: + mask_c_idx = 0 + elif mask.dim() == 4: + mask_c_idx = 1 + else: + raise Exception("mask should be a 3d or 4d tensor") + + if mask.shape[mask_c_idx] == 3: + return mask + elif mask.shape[mask_c_idx] == 1: + sizes = [1] * mask.dim() + sizes[mask_c_idx] = 3 + mask = mask.repeat(*sizes) + else: + raise Exception("mask should have size 1 in channel dim") + return mask + + +def get_first_k_token_head_att_maps(atts_normed, k, h, w, output_h=256, output_w=256, labels=None, max_scale=False): + n_heads = atts_normed.shape[0] + att_images = [] + for head_idx in range(n_heads): + atts_head = atts_normed[head_idx, :, :k].reshape(h, w, k).movedim(2, 0) + for token_idx in range(k): + att_head_np = atts_head[token_idx].detach().cpu().numpy() + if max_scale: + att_head_np = att_head_np / att_head_np.max() + att_image = Image.fromarray((att_head_np * 255).astype(np.uint8)) + att_image = att_image.resize((output_h, output_w), Image.Resampling.NEAREST) + att_images.append(att_image) + return pil_concat(att_images, col=k, labels=None) + +def get_first_k_token_att_maps(atts_normed, k, h, w, output_h=256, output_w=256, labels=None, max_scale=False): + att_images = [] + atts_head = atts_normed.mean(0)[:, :k].reshape(h, w, k).movedim(2, 0) + for token_idx in range(k): + att_head_np = atts_head[token_idx].detach().cpu().numpy() + if max_scale: + att_head_np = att_head_np / att_head_np.max() + att_image = Image.fromarray((att_head_np * 255).astype(np.uint8)) + att_image = att_image.resize((output_h, output_w), Image.Resampling.NEAREST) + att_images.append(att_image) + return pil_concat(att_images, col=k, labels=None) + +def draw_bbox(image, bbox): + image = image.copy() + left, top, right, bottom = bbox + image_draw = ImageDraw.Draw(image) + image_draw.rectangle(((left, top),(right, bottom)), outline='Red') + return image \ No newline at end of file