mkshing commited on
Commit
eba4833
·
verified ·
1 Parent(s): 4009728

Upload evosdxl_jp_v1.py

Browse files
Files changed (1) hide show
  1. evosdxl_jp_v1.py +204 -0
evosdxl_jp_v1.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Dict, Union
3
+ from tqdm import tqdm
4
+ import torch
5
+ import safetensors
6
+ from huggingface_hub import hf_hub_download
7
+ from transformers import AutoTokenizer, CLIPTextModelWithProjection
8
+ from diffusers import (
9
+ StableDiffusionXLPipeline,
10
+ UNet2DConditionModel,
11
+ EulerDiscreteScheduler,
12
+ )
13
+ from diffusers.loaders import LoraLoaderMixin
14
+
15
+ SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
16
+ JSDXL_REPO = "stabilityai/japanese-stable-diffusion-xl"
17
+ L_REPO = "ByteDance/SDXL-Lightning"
18
+
19
+
20
+ def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
21
+ file_extension = os.path.basename(checkpoint_file).split(".")[-1]
22
+ if file_extension == "safetensors":
23
+ return safetensors.torch.load_file(checkpoint_file, device=device)
24
+ else:
25
+ return torch.load(checkpoint_file, map_location=device)
26
+
27
+
28
+ def load_from_pretrained(
29
+ repo_id,
30
+ filename="diffusion_pytorch_model.fp16.safetensors",
31
+ subfolder="unet",
32
+ device="cuda",
33
+ ) -> Dict[str, torch.Tensor]:
34
+ return load_state_dict(
35
+ hf_hub_download(
36
+ repo_id=repo_id,
37
+ filename=filename,
38
+ subfolder=subfolder,
39
+ ),
40
+ device=device,
41
+ )
42
+
43
+
44
+ def reshape_weight_task_tensors(task_tensors, weights):
45
+ """
46
+ Reshapes `weights` to match the shape of `task_tensors` by unsqeezing in the remaining dimenions.
47
+
48
+ Args:
49
+ task_tensors (`torch.Tensor`): The tensors that will be used to reshape `weights`.
50
+ weights (`torch.Tensor`): The tensor to be reshaped.
51
+
52
+ Returns:
53
+ `torch.Tensor`: The reshaped tensor.
54
+ """
55
+ new_shape = weights.shape + (1,) * (task_tensors.dim() - weights.dim())
56
+ weights = weights.view(new_shape)
57
+ return weights
58
+
59
+
60
+ def linear(task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor:
61
+ """
62
+ Merge the task tensors using `linear`.
63
+
64
+ Args:
65
+ task_tensors(`List[torch.Tensor]`):The task tensors to merge.
66
+ weights (`torch.Tensor`):The weights of the task tensors.
67
+
68
+ Returns:
69
+ `torch.Tensor`: The merged tensor.
70
+ """
71
+ task_tensors = torch.stack(task_tensors, dim=0)
72
+ # weighted task tensors
73
+ weights = reshape_weight_task_tensors(task_tensors, weights)
74
+ weighted_task_tensors = task_tensors * weights
75
+ mixed_task_tensors = weighted_task_tensors.sum(dim=0)
76
+ return mixed_task_tensors
77
+
78
+
79
+ def merge_models(
80
+ task_tensors,
81
+ weights,
82
+ ):
83
+ keys = list(task_tensors[0].keys())
84
+ weights = torch.tensor(weights, device=task_tensors[0][keys[0]].device)
85
+ state_dict = {}
86
+ for key in tqdm(keys, desc="Merging"):
87
+ w_list = []
88
+ for i, sd in enumerate(task_tensors):
89
+ w = sd.pop(key)
90
+ w_list.append(w)
91
+ new_w = linear(task_tensors=w_list, weights=weights)
92
+ state_dict[key] = new_w
93
+ return state_dict
94
+
95
+
96
+ def split_conv_attn(weights):
97
+ attn_tensors = {}
98
+ conv_tensors = {}
99
+ for key in list(weights.keys()):
100
+ if any(k in key for k in ["to_k", "to_q", "to_v", "to_out.0"]):
101
+ attn_tensors[key] = weights.pop(key)
102
+ else:
103
+ conv_tensors[key] = weights.pop(key)
104
+ return {"conv": conv_tensors, "attn": attn_tensors}
105
+
106
+
107
+ def load_evosdxl_jp(device="cuda") -> StableDiffusionXLPipeline:
108
+ sdxl_weights = split_conv_attn(load_from_pretrained(SDXL_REPO, device=device))
109
+ dpo_weights = split_conv_attn(
110
+ load_from_pretrained(
111
+ "mhdang/dpo-sdxl-text2image-v1",
112
+ "diffusion_pytorch_model.safetensors",
113
+ device=device,
114
+ )
115
+ )
116
+ jn_weights = split_conv_attn(
117
+ load_from_pretrained("RunDiffusion/Juggernaut-XL-v9", device=device)
118
+ )
119
+ jsdxl_weights = split_conv_attn(load_from_pretrained(JSDXL_REPO, device=device))
120
+ tensors = [sdxl_weights, dpo_weights, jn_weights, jsdxl_weights]
121
+ new_conv = merge_models(
122
+ [sd["conv"] for sd in tensors],
123
+ [
124
+ 0.15928833971605916,
125
+ 0.1032449268871776,
126
+ 0.6503217149752791,
127
+ 0.08714501842148402,
128
+ ],
129
+ )
130
+ new_attn = merge_models(
131
+ [sd["attn"] for sd in tensors],
132
+ [
133
+ 0.1877279276437178,
134
+ 0.20014114603909822,
135
+ 0.3922685507065275,
136
+ 0.2198623756106564,
137
+ ],
138
+ )
139
+ del sdxl_weights, dpo_weights, jn_weights, jsdxl_weights
140
+ torch.cuda.empty_cache()
141
+ unet_config = UNet2DConditionModel.load_config(SDXL_REPO, subfolder="unet")
142
+ unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
143
+ unet.load_state_dict({**new_conv, **new_attn})
144
+ state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(
145
+ L_REPO, weight_name="sdxl_lightning_4step_lora.safetensors"
146
+ )
147
+ LoraLoaderMixin.load_lora_into_unet(state_dict, network_alphas, unet)
148
+ unet.fuse_lora(lora_scale=3.224682864579401)
149
+ new_weights = split_conv_attn(unet.state_dict())
150
+ l_weights = split_conv_attn(
151
+ load_from_pretrained(
152
+ L_REPO,
153
+ "sdxl_lightning_4step_unet.safetensors",
154
+ subfolder=None,
155
+ device=device,
156
+ )
157
+ )
158
+ jnl_weights = split_conv_attn(
159
+ load_from_pretrained(
160
+ "RunDiffusion/Juggernaut-XL-Lightning",
161
+ "diffusion_pytorch_model.bin",
162
+ device=device,
163
+ )
164
+ )
165
+ tensors = [l_weights, jnl_weights, new_weights]
166
+ new_conv = merge_models(
167
+ [sd["conv"] for sd in tensors],
168
+ [0.47222002022088533, 0.48419531030361584, 0.04358466947549889],
169
+ )
170
+ new_attn = merge_models(
171
+ [sd["attn"] for sd in tensors],
172
+ [0.023119324530758375, 0.04924981616469831, 0.9276308593045434],
173
+ )
174
+ new_weights = {**new_conv, **new_attn}
175
+ unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
176
+ unet.load_state_dict({**new_conv, **new_attn})
177
+
178
+ text_encoder = CLIPTextModelWithProjection.from_pretrained(
179
+ JSDXL_REPO, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16"
180
+ )
181
+ tokenizer = AutoTokenizer.from_pretrained(
182
+ JSDXL_REPO, subfolder="tokenizer", use_fast=False
183
+ )
184
+
185
+ pipe = StableDiffusionXLPipeline.from_pretrained(
186
+ SDXL_REPO,
187
+ unet=unet,
188
+ text_encoder=text_encoder,
189
+ tokenizer=tokenizer,
190
+ torch_dtype=torch.float16,
191
+ variant="fp16",
192
+ )
193
+ # Ensure sampler uses "trailing" timesteps.
194
+ pipe.scheduler = EulerDiscreteScheduler.from_config(
195
+ pipe.scheduler.config, timestep_spacing="trailing"
196
+ )
197
+ pipe = pipe.to(device, dtype=torch.float16)
198
+ return pipe
199
+
200
+
201
+ if __name__ == "__main__":
202
+ pipe: StableDiffusionXLPipeline = load_evosdxl_jp()
203
+ images = pipe("犬", num_inference_steps=4, guidance_scale=0).images
204
+ images[0].save("out.png")