Ravi21 commited on
Commit
1abc402
·
verified ·
1 Parent(s): 0f2f6a9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +472 -0
app.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import StableDiffusionPipeline
2
+ import gc
3
+
4
+ pipe = StableDiffusionPipeline.from_pretrained("prompthero/openjourney-v4").to("cpu")
5
+ text_encoder = pipe.text_encoder
6
+ text_encoder.eval()
7
+ unet = pipe.unet
8
+ unet.eval()
9
+ vae = pipe.vae
10
+ vae.eval()
11
+
12
+ del pipe
13
+ gc.collect()
14
+
15
+ from pathlib import Path
16
+ import torch
17
+ import openvino as pv
18
+
19
+ text_encoder_path=Path("text_encoder.xml")
20
+
21
+ def cleanup_cache():
22
+ torch._C._jit_clear_class_registry()
23
+ torch.jit._recursive.concrete_type_store=torch.jit._recursive.ConcreteTypeStore()
24
+ torch.jit._state._clear_class_state()
25
+
26
+ def convert_encoder(text_encoder:torch.nn.Module,ir_path:Path):
27
+ """
28
+ Convert Text Encoder mode.
29
+ Function accepts text encoder model, and prepares example inputs for conversion,
30
+ Parameters:
31
+ text_encoder (torch.nn.Module): text_encoder model from Stable Diffusion pipeline
32
+ ir_path (Path): File for storing model
33
+ Returns:
34
+ None
35
+ """
36
+
37
+ input_ids=torch.ones((1,77),dtype=torch.long)
38
+ text_encoder.eval()
39
+
40
+ with torch.no_grad():
41
+ ov_model=pv.convert_model(text_encoder,example_input=input_ids,input=[(1,77),])
42
+ pv.save_model(ov_model,ir_path)
43
+ del ov_model
44
+ cleanup_cache()
45
+ print(f"Text Encoder successfully converted to TR and saved to {ir_path}")
46
+
47
+ if not text_encoder_path.exists():
48
+ convert_encoder(text_encoder,text_encoder_path)
49
+ else:
50
+ print(f"Text encoder will be loaded from {text_encoder_path}")
51
+ del text_encoder
52
+ gc.collect()
53
+
54
+ import numpy as np
55
+ unet_path=Path("unet.xml")
56
+ dtype_mapping={
57
+ torch.float32: pv.Type.f32,
58
+ torch.float64: pv.Type.f64
59
+ }
60
+
61
+ def convert_unet(unet:torch.nn.Module,ir_path:Path):
62
+ """
63
+ Convert U-net model to IR format.
64
+ Function accepts unet model, prepares example inputs for conversion,
65
+ Parameters:
66
+ unet (StableDiffusionPipeline): unet from Stable Diffusion pipeline
67
+ ir_path (Path): File for storing model
68
+ Returns:
69
+ None
70
+ """
71
+
72
+ encoder_hidden_state=torch.ones((2,77,768))
73
+ latents_shape=(2,4,512 // 8,512 // 8)
74
+ latents=torch.randn(latents_shape)
75
+ t=torch.from_numpy(np.array(1,dtype=float))
76
+ dummy_inputs=(latents,t,encoder_hidden_state)
77
+ input_info=[]
78
+ for input_tensor in dummy_inputs:
79
+ shape=pv.PartialShape(tuple(input_tensor.shape))
80
+ element_type=dtype_mapping[input_tensor.dtype]
81
+ input_info.append((shape,element_type))
82
+
83
+ unet.eval()
84
+ with torch.no_grad():
85
+ pv_model=pv.convert_model(unet,example_input=dummy_inputs,input=input_info)
86
+ pv.save_model(pv_model,ir_path)
87
+ del pv_model
88
+ cleanup_cache()
89
+ print(f"Unet successfully converted to IR and saved to {ir_path}")
90
+
91
+ if not unet_path.exists():
92
+ convert_unet(unet,unet_path)
93
+ gc.collect()
94
+ else:
95
+ print(f"unet will be loaded from {unet_path}")
96
+ del unet
97
+ gc.collect()
98
+
99
+ VAE_ENCODER_PATH = Path("vae_encoder.xml")
100
+
101
+ def convert_vae_encoder(vae: torch.nn.Module, ir_path: Path):
102
+ class VAEEncoder(torch.nn.Module):
103
+ def __init__(self, vae):
104
+ super().__init__()
105
+ self.vae = vae
106
+
107
+ def forward(self, image):
108
+ return self.vae.encode(x=image)["latent_dist"].sample()
109
+ vae_encoder = VAEEncoder(vae)
110
+ vae_encoder.eval()
111
+ image = torch.zeros((1, 3, 512, 512))
112
+ with torch.no_grad():
113
+ ov_model = pv.convert_model(vae_encoder, example_input=image, input=[((1,3,512,512),)])
114
+ pv.save_model(ov_model, ir_path)
115
+ del ov_model
116
+ cleanup_cache()
117
+ print(f'VAE encoder successfully converted to IR and saved to {ir_path}')
118
+
119
+
120
+ if not VAE_ENCODER_PATH.exists():
121
+ convert_vae_encoder(vae, VAE_ENCODER_PATH)
122
+ else:
123
+ print(f"VAE encoder will be loaded from {VAE_ENCODER_PATH}")
124
+
125
+ VAE_DECODER_PATH = Path('vae_decoder.xml')
126
+
127
+ def convert_vae_decoder(vae: torch.nn.Module, ir_path: Path):
128
+ class VAEDecoder(torch.nn.Module):
129
+ def __init__(self, vae):
130
+ super().__init__()
131
+ self.vae = vae
132
+
133
+ def forward(self, latents):
134
+ return self.vae.decode(latents)
135
+
136
+ vae_decoder = VAEDecoder(vae)
137
+ latents = torch.zeros((1, 4, 64, 64))
138
+
139
+ vae_decoder.eval()
140
+ with torch.no_grad():
141
+ ov_model = pv.convert_model(vae_decoder, example_input=latents, input=[((1,4,64,64),)])
142
+ pv.save_model(ov_model, ir_path)
143
+ del ov_model
144
+ cleanup_cache()
145
+ print(f'VAE decoder successfully converted to IR and saved to {ir_path}')
146
+
147
+
148
+ if not VAE_DECODER_PATH.exists():
149
+ convert_vae_decoder(vae, VAE_DECODER_PATH)
150
+ else:
151
+ print(f"VAE decoder will be loaded from {VAE_DECODER_PATH}")
152
+
153
+ del vae
154
+ gc.collect()
155
+
156
+
157
+ import inspect
158
+ from typing import List,Optional,Union,Dict
159
+ import PIL
160
+ import cv2
161
+
162
+ from transformers import CLIPTokenizer
163
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
164
+ from diffusers.schedulers import DDIMScheduler,LMSDiscreteScheduler,PNDMScheduler
165
+ from openvino.runtime import Model
166
+
167
+ def scale_window(dst_width:int,dst_height:int,image_width:int,image_height:int):
168
+ im_scale=min(dst_height / image_height,dst_width / image_width)
169
+ return int(im_scale * image_width), int(im_scale * image_height)
170
+ def preprocess(image:PIL.Image.Image):
171
+ src_width,src_height=image.size
172
+ dst_width,dst_height=scale_window(512,512,src_width,src_height)
173
+ image=np.array(image.resize((dst_width,dst_height),resample=PIL.Image.Resampling.LANCZOS))[None,:]
174
+ pad_width=512-dst_width
175
+ pad_height=512-dst_height
176
+ pad=((0,0),(0,pad_height),(0,pad_width),(0,0))
177
+ image=np.pad(image,pad,mode="constant")
178
+ image=image.astype(np.float32) / 255.0
179
+ image=2.0* image - 1.0
180
+ image=image.transpose(0,3,1,2)
181
+ return image, {"padding":pad,"src_width":src_width,"src_height":src_height}
182
+
183
+ class OVStableDiffusionPipeline(DiffusionPipeline):
184
+ def __init__(
185
+ self,
186
+ vae_decoder: Model,
187
+ text_encoder: Model,
188
+ tokenizer: CLIPTokenizer,
189
+ unet: Model,
190
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
191
+ vae_encoder: Model = None,
192
+ ):
193
+ super().__init__()
194
+ self.scheduler = scheduler
195
+ self.vae_decoder = vae_decoder
196
+ self.vae_encoder = vae_encoder
197
+ self.text_encoder = text_encoder
198
+ self.unet = unet
199
+ self._text_encoder_output = text_encoder.output(0)
200
+ self._unet_output = unet.output(0)
201
+ self._vae_d_output = vae_decoder.output(0)
202
+ self._vae_e_output = vae_encoder.output(0) if vae_encoder is not None else None
203
+ self.height = 512
204
+ self.width = 512
205
+ self.tokenizer = tokenizer
206
+
207
+ def __call__(
208
+ self,
209
+ prompt: Union[str, List[str]],
210
+ image: PIL.Image.Image = None,
211
+ num_inference_steps: Optional[int] = 50,
212
+ negative_prompt: Union[str, List[str]] = None,
213
+ guidance_scale: Optional[float] = 7.5,
214
+ eta: Optional[float] = 0.0,
215
+ output_type: Optional[str] = "pil",
216
+ seed: Optional[int] = None,
217
+ strength: float = 1.0,
218
+ gif: Optional[bool] = False,
219
+ **kwargs,
220
+ ):
221
+ if seed is not None:
222
+ np.random.seed(seed)
223
+
224
+ img_buffer = []
225
+ do_classifier_free_guidance = guidance_scale > 1.0
226
+ # get prompt text embeddings
227
+ text_embeddings = self._encode_prompt(prompt, do_classifier_free_guidance=do_classifier_free_guidance, negative_prompt=negative_prompt)
228
+
229
+ # set timesteps
230
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
231
+ extra_set_kwargs = {}
232
+ if accepts_offset:
233
+ extra_set_kwargs["offset"] = 1
234
+
235
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
236
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
237
+ latent_timestep = timesteps[:1]
238
+
239
+ # get the initial random noise unless the user supplied it
240
+ latents, meta = self.prepare_latents(image, latent_timestep)
241
+
242
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
243
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
244
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
245
+ # and should be between [0, 1]
246
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
247
+ extra_step_kwargs = {}
248
+ if accepts_eta:
249
+ extra_step_kwargs["eta"] = eta
250
+
251
+ for i, t in enumerate(self.progress_bar(timesteps)):
252
+ # expand the latents if you are doing classifier free guidance
253
+ latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
254
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
255
+
256
+ # predict the noise residual
257
+ noise_pred = self.unet([latent_model_input, t, text_embeddings])[self._unet_output]
258
+ # perform guidance
259
+ if do_classifier_free_guidance:
260
+ noise_pred_uncond, noise_pred_text = noise_pred[0], noise_pred[1]
261
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
262
+
263
+ # compute the previous noisy sample x_t -> x_t-1
264
+ latents = self.scheduler.step(torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs)["prev_sample"].numpy()
265
+ if gif:
266
+ image = self.vae_decoder(latents * (1 / 0.18215))[self._vae_d_output]
267
+ image = self.postprocess_image(image, meta, output_type)
268
+ img_buffer.extend(image)
269
+
270
+ # scale and decode the image latents with vae
271
+ image = self.vae_decoder(latents * (1 / 0.18215))[self._vae_d_output]
272
+
273
+ image = self.postprocess_image(image, meta, output_type)
274
+ return {"sample": image, 'iterations': img_buffer}
275
+
276
+ def _encode_prompt(self, prompt:Union[str, List[str]], num_images_per_prompt:int = 1, do_classifier_free_guidance:bool = True, negative_prompt:Union[str, List[str]] = None):
277
+ """
278
+ Encodes the prompt into text encoder hidden states.
279
+
280
+ Parameters:
281
+ prompt (str or list(str)): prompt to be encoded
282
+ num_images_per_prompt (int): number of images that should be generated per prompt
283
+ do_classifier_free_guidance (bool): whether to use classifier free guidance or not
284
+ negative_prompt (str or list(str)): negative prompt to be encoded
285
+ Returns:
286
+ text_embeddings (np.ndarray): text encoder hidden states
287
+ """
288
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
289
+
290
+ # tokenize input prompts
291
+ text_inputs = self.tokenizer(
292
+ prompt,
293
+ padding="max_length",
294
+ max_length=self.tokenizer.model_max_length,
295
+ truncation=True,
296
+ return_tensors="np",
297
+ )
298
+ text_input_ids = text_inputs.input_ids
299
+
300
+ text_embeddings = self.text_encoder(
301
+ text_input_ids)[self._text_encoder_output]
302
+
303
+ # duplicate text embeddings for each generation per prompt
304
+ if num_images_per_prompt != 1:
305
+ bs_embed, seq_len, _ = text_embeddings.shape
306
+ text_embeddings = np.tile(
307
+ text_embeddings, (1, num_images_per_prompt, 1))
308
+ text_embeddings = np.reshape(
309
+ text_embeddings, (bs_embed * num_images_per_prompt, seq_len, -1))
310
+
311
+ # get unconditional embeddings for classifier free guidance
312
+ if do_classifier_free_guidance:
313
+ uncond_tokens: List[str]
314
+ max_length = text_input_ids.shape[-1]
315
+ if negative_prompt is None:
316
+ uncond_tokens = [""] * batch_size
317
+ elif isinstance(negative_prompt, str):
318
+ uncond_tokens = [negative_prompt]
319
+ else:
320
+ uncond_tokens = negative_prompt
321
+ uncond_input = self.tokenizer(
322
+ uncond_tokens,
323
+ padding="max_length",
324
+ max_length=max_length,
325
+ truncation=True,
326
+ return_tensors="np",
327
+ )
328
+
329
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids)[self._text_encoder_output]
330
+
331
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
332
+ seq_len = uncond_embeddings.shape[1]
333
+ uncond_embeddings = np.tile(uncond_embeddings, (1, num_images_per_prompt, 1))
334
+ uncond_embeddings = np.reshape(uncond_embeddings, (batch_size * num_images_per_prompt, seq_len, -1))
335
+
336
+ # For classifier free guidance, we need to do two forward passes.
337
+ # Here we concatenate the unconditional and text embeddings into a single batch
338
+ # to avoid doing two forward passes
339
+ text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
340
+
341
+ return text_embeddings
342
+
343
+
344
+ def prepare_latents(self, image:PIL.Image.Image = None, latent_timestep:torch.Tensor = None):
345
+ """
346
+ Function for getting initial latents for starting generation
347
+
348
+ Parameters:
349
+ image (PIL.Image.Image, *optional*, None):
350
+ Input image for generation, if not provided randon noise will be used as starting point
351
+ latent_timestep (torch.Tensor, *optional*, None):
352
+ Predicted by scheduler initial step for image generation, required for latent image mixing with nosie
353
+ Returns:
354
+ latents (np.ndarray):
355
+ Image encoded in latent space
356
+ """
357
+ latents_shape = (1, 4, self.height // 8, self.width // 8)
358
+ noise = np.random.randn(*latents_shape).astype(np.float32)
359
+ if image is None:
360
+ # if you use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
361
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
362
+ noise = noise * self.scheduler.sigmas[0].numpy()
363
+ return noise, {}
364
+ input_image, meta = preprocess(image)
365
+ latents = self.vae_encoder(input_image)[self._vae_e_output] * 0.18215
366
+ latents = self.scheduler.add_noise(torch.from_numpy(latents), torch.from_numpy(noise), latent_timestep).numpy()
367
+ return latents, meta
368
+
369
+ def postprocess_image(self, image:np.ndarray, meta:Dict, output_type:str = "pil"):
370
+ """
371
+ Postprocessing for decoded image. Takes generated image decoded by VAE decoder, unpad it to initila image size (if required),
372
+ normalize and convert to [0, 255] pixels range. Optionally, convertes it from np.ndarray to PIL.Image format
373
+
374
+ Parameters:
375
+ image (np.ndarray):
376
+ Generated image
377
+ meta (Dict):
378
+ Metadata obtained on latents preparing step, can be empty
379
+ output_type (str, *optional*, pil):
380
+ Output format for result, can be pil or numpy
381
+ Returns:
382
+ image (List of np.ndarray or PIL.Image.Image):
383
+ Postprocessed images
384
+ """
385
+ if "padding" in meta:
386
+ pad = meta["padding"]
387
+ (_, end_h), (_, end_w) = pad[1:3]
388
+ h, w = image.shape[2:]
389
+ unpad_h = h - end_h
390
+ unpad_w = w - end_w
391
+ image = image[:, :, :unpad_h, :unpad_w]
392
+ image = np.clip(image / 2 + 0.5, 0, 1)
393
+ image = np.transpose(image, (0, 2, 3, 1))
394
+ # 9. Convert to PIL
395
+ if output_type == "pil":
396
+ image = self.numpy_to_pil(image)
397
+ if "src_height" in meta:
398
+ orig_height, orig_width = meta["src_height"], meta["src_width"]
399
+ image = [img.resize((orig_width, orig_height),
400
+ PIL.Image.Resampling.LANCZOS) for img in image]
401
+ else:
402
+ if "src_height" in meta:
403
+ orig_height, orig_width = meta["src_height"], meta["src_width"]
404
+ image = [cv2.resize(img, (orig_width, orig_width))
405
+ for img in image]
406
+ return image
407
+
408
+ def get_timesteps(self, num_inference_steps:int, strength:float):
409
+ """
410
+ Helper function for getting scheduler timesteps for generation
411
+ In case of image-to-image generation, it updates number of steps according to strength
412
+
413
+ Parameters:
414
+ num_inference_steps (int):
415
+ number of inference steps for generation
416
+ strength (float):
417
+ value between 0.0 and 1.0, that controls the amount of noise that is added to the input image.
418
+ Values that approach 1.0 enable lots of variations but will also produce images that are not semantically consistent with the input.
419
+ """
420
+ # get the original timestep using init_timestep
421
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
422
+
423
+ t_start = max(num_inference_steps - init_timestep, 0)
424
+ timesteps = self.scheduler.timesteps[t_start:]
425
+
426
+ return timesteps, num_inference_steps - t_start
427
+
428
+ core=pv.Core()
429
+
430
+ text_enc=core.compile_model(text_encoder_path,device.value)
431
+
432
+ unet_model=core.compile_model(unet_path,device.value)
433
+ from transformers import CLIPTokenizer
434
+ from diffusers.schedulers import LMSDiscreteScheduler
435
+ lms=LMSDiscreteScheduler(
436
+ beta_start=0.00085,
437
+ beta_end=0.012,
438
+ beta_schedule="scaled_linear")
439
+ tokenizer=CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
440
+
441
+ pv_pipe=OVStableDiffusionPipeline(
442
+ tokenizer=tokenizer,
443
+ text_encoder=text_enc,
444
+ unet=unet_model,
445
+ vae_encoder=vae_encoder,
446
+ vae_decoder=vae_decoder,
447
+ scheduler=lms)
448
+
449
+ import gradio as gr
450
+
451
+ def generate_text(text,seed,num_steps,strength,_=gr.Progress(track_tqdm=True)):
452
+ result=pv_pipe(text,num_inference_steps=num_steps,seed=seed)
453
+ return result["sample"][0]
454
+ def generate_image(img,text,seed,num_steps,strength,_=gr.Progress(track_tqdm=True)):
455
+ result=pv_pipe(text,img,num_inference_steps=num_steps,seed=seed,strength=strength)
456
+ return result["sample"][0]
457
+
458
+ with gr.Blocks() as demo:
459
+ with gr.Tab("Zero-shot Text-to-Image Generation"):
460
+ with gr.Row():
461
+ with gr.Column():
462
+ text_input=gr.Textbox(lines=3,label="Text")
463
+ seed_input=gr.Slider(0,10000000,value=42,label="seed")
464
+ steps_input=gr.Slider(1,50,value=20,step=1,label="steps")
465
+ out=gr.Image(label="Result",type="pil")
466
+ btn=gr.Button()
467
+ btn.click(generate_text,[text_input,seed_input,steps_input],out)
468
+ gr.Examples([[sample_text,42,20]],[text_input,seed_input,steps_input])
469
+ try:
470
+ demo.queue().launch(debug=True)
471
+ except Exception:
472
+ demo.queue().launch(share=True,debug=True)