Spaces:
Running
Running
import torch | |
from diffusers.loaders.lora import LoraLoaderMixin | |
from typing import Dict, Union | |
import numpy as np | |
import imageio | |
def load_lora_weights(unet, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs):
    """Load a LoRA checkpoint and inject it into ``unet``.

    Args:
        unet: Model that receives the LoRA weights; passed as ``unet=`` to
            ``LoraLoaderMixin.load_lora_into_unet``.
        pretrained_model_name_or_path_or_dict: Either something
            ``LoraLoaderMixin.lora_state_dict`` can resolve (e.g. a path or
            model id) or an already-loaded state dict. A dict argument is
            shallow-copied so the caller's mapping is not modified in place.
        adapter_name: Optional adapter name forwarded to ``load_lora_into_unet``.
        **kwargs: ``low_cpu_mem_usage`` (default ``True``) is consumed here and
            forwarded to ``load_lora_into_unet``. All other keyword arguments
            are forwarded to ``LoraLoaderMixin.lora_state_dict``.

    Raises:
        ValueError: If the resulting state dict is empty, or if any key contains
            neither ``"lora"`` nor ``"dora_scale"``.
    """
    # If a dict is passed, copy it instead of modifying it in place.
    if isinstance(pretrained_model_name_or_path_or_dict, dict):
        pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

    # Pop this option before forwarding kwargs so it reaches the UNet loader
    # rather than lora_state_dict. The default preserves the previous hard-coded value.
    low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", True)

    # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
    state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

    # Remove the "base_model.model." prefix if it was not removed when saved.
    # NOTE: str.replace strips every occurrence, not only a leading one.
    state_dict = {name.replace('base_model.model.', ''): param for name, param in state_dict.items()}

    # all() is vacuously True for an empty mapping, so emptiness is checked explicitly.
    is_correct_format = bool(state_dict) and all(
        "lora" in key or "dora_scale" in key for key in state_dict
    )
    if not is_correct_format:
        raise ValueError("Invalid LoRA checkpoint.")

    LoraLoaderMixin.load_lora_into_unet(
        state_dict,
        network_alphas=network_alphas,
        unet=unet,
        low_cpu_mem_usage=low_cpu_mem_usage,
        adapter_name=adapter_name,
    )
def save_video(frames, save_path, fps, quality=9):
    """Write a sequence of frames to a video (or GIF) file using imageio.

    Args:
        frames: Iterable of frames. Each frame is converted with ``np.asarray``
            before being appended, so any array-like frame is accepted.
        save_path: Output path (``str`` or ``os.PathLike``). If it does not end
            in ``.mp4``, ``.avi`` or ``.gif`` (case-insensitive), ``.mp4`` is
            appended.
        fps: Frame rate passed to ``imageio.get_writer``.
        quality: Quality setting passed to ``imageio.get_writer`` (default 9).
    """
    import os

    # Accept pathlib.Path and other PathLike objects; a str is returned unchanged.
    save_path = os.fspath(save_path)
    # Case-insensitive so "clip.MP4" is not turned into "clip.MP4.mp4".
    if not save_path.lower().endswith(('.mp4', '.avi', '.gif')):
        save_path += '.mp4'  # Default to mp4 if no valid extension is provided

    writer = imageio.get_writer(save_path, fps=fps, quality=quality)
    try:
        for frame in frames:
            # asarray avoids copying frames that are already ndarrays.
            writer.append_data(np.asarray(frame))
    finally:
        # Always close the writer, even if converting/appending a frame raises.
        writer.close()