gaparmar commited on
Commit
13ed5cd
·
1 Parent(s): d3864b2

sketch demo

Browse files
Files changed (2) hide show
  1. app.py +4 -8
  2. src/pix2pix_turbo.py +37 -24
app.py CHANGED
@@ -1,7 +1,3 @@
1
- """
2
- 3.43.1
3
- """
4
-
5
  import os
6
  import sys
7
  import pdb
@@ -78,7 +74,8 @@ def run(image, prompt, prompt_template, style_name, seed, val_r):
78
  print("sketch updated")
79
  if image is None:
80
  ones = Image.new("L", (512, 512), 255)
81
- return ones
 
82
  prompt = prompt_template.replace("{prompt}", prompt)
83
  image = image.convert("RGB")
84
  image_t = TF.to_tensor(image) > 0.5
@@ -234,8 +231,8 @@ with gr.Blocks(css="style.css") as demo:
234
  <div class="pad2"> <button href="TODO" download="image" id="my-button-down" onclick='return theSketchDownloadFunction()'></button> </div>
235
  </div>
236
  """)
237
- gr.Markdown("## Prompt", elem_id="tools_header")
238
- prompt = gr.Textbox(label=None, value="", show_label=False)
239
  with gr.Row():
240
  style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
241
  prompt_temp = gr.Textbox(label="Prompt Style Template", value=styles[DEFAULT_STYLE_NAME], scale=2, max_lines=1)
@@ -269,4 +266,3 @@ with gr.Blocks(css="style.css") as demo:
269
 
270
  if __name__ == "__main__":
271
  demo.queue().launch(debug=True)
272
-
 
 
 
 
 
1
  import os
2
  import sys
3
  import pdb
 
74
  print("sketch updated")
75
  if image is None:
76
  ones = Image.new("L", (512, 512), 255)
77
+ temp_uri = pil_image_to_data_uri(ones)
78
+ return ones, gr.update(link=temp_uri), gr.update(link=temp_uri)
79
  prompt = prompt_template.replace("{prompt}", prompt)
80
  image = image.convert("RGB")
81
  image_t = TF.to_tensor(image) > 0.5
 
231
  <div class="pad2"> <button href="TODO" download="image" id="my-button-down" onclick='return theSketchDownloadFunction()'></button> </div>
232
  </div>
233
  """)
234
+ # gr.Markdown("## Prompt", elem_id="tools_header")
235
+ prompt = gr.Textbox(label="Prompt", value="", show_label=True)
236
  with gr.Row():
237
  style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
238
  prompt_temp = gr.Textbox(label="Prompt Style Template", value=styles[DEFAULT_STYLE_NAME], scale=2, max_lines=1)
 
266
 
267
  if __name__ == "__main__":
268
  demo.queue().launch(debug=True)
 
src/pix2pix_turbo.py CHANGED
@@ -1,4 +1,6 @@
1
- import os, requests
 
 
2
  import pdb
3
  import copy
4
  from tqdm import tqdm
@@ -7,11 +9,13 @@ from transformers import AutoTokenizer, PretrainedConfig, CLIPTextModel
7
  from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
8
  from diffusers.utils.peft_utils import set_weights_and_activate_adapters
9
  from peft import LoraConfig
10
- from .model import make_1step_sched
 
 
11
 
12
 
 
13
  def my_vae_encoder_fwd(self, sample):
14
- r"""The forward method of the `Encoder` class."""
15
  sample = self.conv_in(sample)
16
  l_blocks = []
17
  # down
@@ -27,6 +31,7 @@ def my_vae_encoder_fwd(self, sample):
27
  return sample
28
 
29
 
 
30
  def my_vae_decoder_fwd(self,sample, latent_embeds = None):
31
  sample = self.conv_in(sample)
32
  upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
@@ -76,21 +81,33 @@ class Pix2Pix_Turbo(torch.nn.Module):
76
  vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
77
  unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")
78
 
79
- if name=="canny_to_image":
80
- lora_rank = 8
81
- P_UNET_SD="/home/gparmar/code/single_step_translation/output/paired/canny_canny_midjourney_512_512/sd21_turbo_direct_edge_withskip_opt_lora_8_proj/l2_lpips_gan_vagan_clip_224_patch_multilevel_sigmoid/lr_5e-5_l2_0.25_lpips_1_0.1_CLIPSIM_1.0/1node_8gpu_no_BS_1_GRAD_ACC_2/checkpoint-7501/unet_sd.pkl"
82
- P_VAE_ENC_SD="/home/gparmar/code/single_step_translation/output/paired/canny_canny_midjourney_512_512/sd21_turbo_direct_edge_withskip_opt_lora_8_proj/l2_lpips_gan_vagan_clip_224_patch_multilevel_sigmoid/lr_5e-5_l2_0.25_lpips_1_0.1_CLIPSIM_1.0/1node_8gpu_no_BS_1_GRAD_ACC_2/checkpoint-7501/sd_vae_enc.pkl"
83
- P_VAE_DEC_SD="/home/gparmar/code/single_step_translation/output/paired/canny_canny_midjourney_512_512/sd21_turbo_direct_edge_withskip_opt_lora_8_proj/l2_lpips_gan_vagan_clip_224_patch_multilevel_sigmoid/lr_5e-5_l2_0.25_lpips_1_0.1_CLIPSIM_1.0/1node_8gpu_no_BS_1_GRAD_ACC_2/checkpoint-7501/sd_vae_dec.pkl"
84
- unet_lora_config = LoraConfig(r=lora_rank, init_lora_weights="gaussian", target_modules=[
85
- "to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_shortcut", "conv_out",
86
- "proj_in", "proj_out", "ff.net.2", "ff.net.0.proj"]
87
- )
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  if name=="sketch_to_image_stochastic":
90
  # download from url
91
- url = "https://www.cs.cmu.edu/~clean-fid/tmp/img2img_turbo/ckpt/sketch_to_image_stochastic.pkl"
92
  os.makedirs(ckpt_folder, exist_ok=True)
93
- outf = os.path.join(ckpt_folder, "sketch_to_image_stochastic.pkl")
94
  if not os.path.exists(outf):
95
  print(f"Downloading checkpoint to {outf}")
96
  response = requests.get(url, stream=True)
@@ -105,7 +122,6 @@ class Pix2Pix_Turbo(torch.nn.Module):
105
  if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
106
  print("ERROR, something went wrong")
107
  print(f"Downloaded successfully to {outf}")
108
- # p_ckpt = "/home/gparmar/code/img2img-turbo/single_step_translation/notebooks/DEMO/sketch_to_image_stochastic.pkl"
109
  p_ckpt = outf
110
  sd = torch.load(p_ckpt, map_location="cpu")
111
  unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"])
@@ -123,15 +139,17 @@ class Pix2Pix_Turbo(torch.nn.Module):
123
  vae.decoder.ignore_skip = False
124
  vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
125
  unet.add_adapter(unet_lora_config)
126
- unet.load_state_dict(sd["state_dict_unet"])
 
 
127
  unet.enable_xformers_memory_efficient_attention()
128
-
129
- vae.load_state_dict(sd["state_dict_vae"])
 
130
  unet.to("cuda")
131
  vae.to("cuda")
132
  unet.eval()
133
  vae.eval()
134
-
135
  self.unet, self.vae = unet, vae
136
  self.timesteps = torch.tensor([999], device="cuda").long()
137
 
@@ -141,7 +159,6 @@ class Pix2Pix_Turbo(torch.nn.Module):
141
  caption_tokens = self.tokenizer(prompt, max_length=self.tokenizer.model_max_length,
142
  padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
143
  caption_enc = self.text_encoder(caption_tokens)[0]
144
-
145
  if deterministic:
146
  encoded_control = self.vae.encode(c_t).latent_dist.sample()*self.vae.config.scaling_factor
147
  model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=caption_enc,).sample
@@ -161,8 +178,4 @@ class Pix2Pix_Turbo(torch.nn.Module):
161
  x_denoised = self.sched.step(unet_output, self.timesteps, unet_input, return_dict=True).prev_sample
162
  self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
163
  output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor ).sample).clamp(-1,1)
164
-
165
  return output_image
166
-
167
-
168
-
 
1
+ import os
2
+ import requests
3
+ import sys
4
  import pdb
5
  import copy
6
  from tqdm import tqdm
 
9
  from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
10
  from diffusers.utils.peft_utils import set_weights_and_activate_adapters
11
  from peft import LoraConfig
12
+ p = "src/"
13
+ sys.path.append(p)
14
+ from model import make_1step_sched
15
 
16
 
17
+ """The forward method of the `Encoder` class."""
18
  def my_vae_encoder_fwd(self, sample):
 
19
  sample = self.conv_in(sample)
20
  l_blocks = []
21
  # down
 
31
  return sample
32
 
33
 
34
+ """The forward method of the `Decoder` class."""
35
  def my_vae_decoder_fwd(self,sample, latent_embeds = None):
36
  sample = self.conv_in(sample)
37
  upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
 
81
  vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
82
  unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")
83
 
84
+ if name=="edge_to_image":
85
+ url = "https://www.cs.cmu.edu/~img2img-turbo/models/edge_to_image_loras.pkl"
86
+ os.makedirs(ckpt_folder, exist_ok=True)
87
+ outf = os.path.join(ckpt_folder, "edge_to_image_loras.pkl")
88
+ if not os.path.exists(outf):
89
+ print(f"Downloading checkpoint to {outf}")
90
+ response = requests.get(url, stream=True)
91
+ total_size_in_bytes= int(response.headers.get('content-length', 0))
92
+ block_size = 1024 # 1 Kibibyte
93
+ progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
94
+ with open(outf, 'wb') as file:
95
+ for data in response.iter_content(block_size):
96
+ progress_bar.update(len(data))
97
+ file.write(data)
98
+ progress_bar.close()
99
+ if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
100
+ print("ERROR, something went wrong")
101
+ print(f"Downloaded successfully to {outf}")
102
+ p_ckpt = outf
103
+ sd = torch.load(p_ckpt, map_location="cpu")
104
+ unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"])
105
 
106
  if name=="sketch_to_image_stochastic":
107
  # download from url
108
+ url = "https://www.cs.cmu.edu/~img2img-turbo/models/sketch_to_image_stochastic_lora.pkl"
109
  os.makedirs(ckpt_folder, exist_ok=True)
110
+ outf = os.path.join(ckpt_folder, "sketch_to_image_stochastic_lora.pkl")
111
  if not os.path.exists(outf):
112
  print(f"Downloading checkpoint to {outf}")
113
  response = requests.get(url, stream=True)
 
122
  if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
123
  print("ERROR, something went wrong")
124
  print(f"Downloaded successfully to {outf}")
 
125
  p_ckpt = outf
126
  sd = torch.load(p_ckpt, map_location="cpu")
127
  unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"])
 
139
  vae.decoder.ignore_skip = False
140
  vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
141
  unet.add_adapter(unet_lora_config)
142
+ _sd_unet = unet.state_dict()
143
+ for k in sd["state_dict_unet"]: _sd_unet[k] = sd["state_dict_unet"][k]
144
+ unet.load_state_dict(_sd_unet)
145
  unet.enable_xformers_memory_efficient_attention()
146
+ _sd_vae = vae.state_dict()
147
+ for k in sd["state_dict_vae"]: _sd_vae[k] = sd["state_dict_vae"][k]
148
+ vae.load_state_dict(_sd_vae)
149
  unet.to("cuda")
150
  vae.to("cuda")
151
  unet.eval()
152
  vae.eval()
 
153
  self.unet, self.vae = unet, vae
154
  self.timesteps = torch.tensor([999], device="cuda").long()
155
 
 
159
  caption_tokens = self.tokenizer(prompt, max_length=self.tokenizer.model_max_length,
160
  padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
161
  caption_enc = self.text_encoder(caption_tokens)[0]
 
162
  if deterministic:
163
  encoded_control = self.vae.encode(c_t).latent_dist.sample()*self.vae.config.scaling_factor
164
  model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=caption_enc,).sample
 
178
  x_denoised = self.sched.step(unet_output, self.timesteps, unet_input, return_dict=True).prev_sample
179
  self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
180
  output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor ).sample).clamp(-1,1)
 
181
  return output_image