# Convert Stable Diffusion v1.4 to OpenVINO IR and run it through a custom
# pipeline with an interactive Gradio demo.
from diffusers import StableDiffusionPipeline
import gc
import gradio as gr

# Load the PyTorch pipeline on CPU and keep only the three submodels that need
# conversion: the text encoder, the UNet, and the VAE.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cpu")
text_encoder = pipe.text_encoder
text_encoder.eval()
unet = pipe.unet
unet.eval()
vae = pipe.vae
vae.eval()

# Drop the rest of the pipeline to reduce peak memory during conversion.
del pipe
gc.collect()

from pathlib import Path
import torch
import openvino as pv

text_encoder_path = Path("text_encoder.xml")


def cleanup_cache():
    """Release TorchScript state accumulated while tracing models for conversion."""
    torch._C._jit_clear_class_registry()
    torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
    torch.jit._state._clear_class_state()

def convert_encoder(text_encoder: torch.nn.Module, ir_path: Path):
    """
    Convert the text encoder model to OpenVINO IR.
    Function accepts the text encoder model and prepares example inputs for conversion.
    Parameters:
        text_encoder (torch.nn.Module): text_encoder model from Stable Diffusion pipeline
        ir_path (Path): File for storing model
    Returns:
        None
    """
    input_ids = torch.ones((1, 77), dtype=torch.long)
    text_encoder.eval()

    with torch.no_grad():
        ov_model = pv.convert_model(text_encoder, example_input=input_ids, input=[(1, 77)])
    pv.save_model(ov_model, ir_path)
    del ov_model
    cleanup_cache()
    print(f"Text Encoder successfully converted to IR and saved to {ir_path}")

if not text_encoder_path.exists():
    convert_encoder(text_encoder, text_encoder_path)
else:
    print(f"Text encoder will be loaded from {text_encoder_path}")
del text_encoder
gc.collect()

import numpy as np

unet_path = Path("unet.xml")
# Map torch dtypes to OpenVINO element types for the conversion input spec.
dtype_mapping = {
    torch.float32: pv.Type.f32,
    torch.float64: pv.Type.f64,
}

def convert_unet(unet: torch.nn.Module, ir_path: Path):
    """
    Convert the U-Net model to IR format.
    Function accepts the unet model and prepares example inputs for conversion.
    Parameters:
        unet (torch.nn.Module): unet from Stable Diffusion pipeline
        ir_path (Path): File for storing model
    Returns:
        None
    """
    # Batch size 2 accounts for classifier-free guidance (unconditional + text).
    encoder_hidden_state = torch.ones((2, 77, 768))
    latents_shape = (2, 4, 512 // 8, 512 // 8)
    latents = torch.randn(latents_shape)
    t = torch.from_numpy(np.array(1, dtype=float))
    dummy_inputs = (latents, t, encoder_hidden_state)
    input_info = []
    for input_tensor in dummy_inputs:
        shape = pv.PartialShape(tuple(input_tensor.shape))
        element_type = dtype_mapping[input_tensor.dtype]
        input_info.append((shape, element_type))

    unet.eval()
    with torch.no_grad():
        pv_model = pv.convert_model(unet, example_input=dummy_inputs, input=input_info)
    pv.save_model(pv_model, ir_path)
    del pv_model
    cleanup_cache()
    print(f"UNet successfully converted to IR and saved to {ir_path}")

if not unet_path.exists():
    convert_unet(unet, unet_path)
    gc.collect()
else:
    print(f"unet will be loaded from {unet_path}")
del unet
gc.collect()

VAE_ENCODER_PATH = Path("vae_encoder.xml")

def convert_vae_encoder(vae: torch.nn.Module, ir_path: Path):
    """Convert the VAE encoder to IR, wrapped so forward() returns sampled latents."""
    class VAEEncoder(torch.nn.Module):
        def __init__(self, vae):
            super().__init__()
            self.vae = vae

        def forward(self, image):
            return self.vae.encode(x=image)["latent_dist"].sample()

    vae_encoder = VAEEncoder(vae)
    vae_encoder.eval()
    image = torch.zeros((1, 3, 512, 512))
    with torch.no_grad():
        ov_model = pv.convert_model(vae_encoder, example_input=image, input=[((1, 3, 512, 512),)])
    pv.save_model(ov_model, ir_path)
    del ov_model
    cleanup_cache()
    print(f"VAE encoder successfully converted to IR and saved to {ir_path}")


if not VAE_ENCODER_PATH.exists():
    convert_vae_encoder(vae, VAE_ENCODER_PATH)
else:
    print(f"VAE encoder will be loaded from {VAE_ENCODER_PATH}")

VAE_DECODER_PATH = Path('vae_decoder.xml')

def convert_vae_decoder(vae: torch.nn.Module, ir_path: Path):
    """Convert the VAE decoder to IR, wrapped so forward() maps latents to an image."""
    class VAEDecoder(torch.nn.Module):
        def __init__(self, vae):
            super().__init__()
            self.vae = vae

        def forward(self, latents):
            return self.vae.decode(latents)

    vae_decoder = VAEDecoder(vae)
    latents = torch.zeros((1, 4, 64, 64))

    vae_decoder.eval()
    with torch.no_grad():
        ov_model = pv.convert_model(vae_decoder, example_input=latents, input=[((1, 4, 64, 64),)])
    pv.save_model(ov_model, ir_path)
    del ov_model
    cleanup_cache()
    print(f"VAE decoder successfully converted to IR and saved to {ir_path}")


if not VAE_DECODER_PATH.exists():
    convert_vae_decoder(vae, VAE_DECODER_PATH)
else:
    print(f"VAE decoder will be loaded from {VAE_DECODER_PATH}")

del vae
gc.collect()


import inspect
from typing import List,Optional,Union,Dict
import PIL
import cv2

from transformers import CLIPTokenizer
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import DDIMScheduler,LMSDiscreteScheduler,PNDMScheduler
from openvino.runtime import Model

def scale_window(dst_width: int, dst_height: int, image_width: int, image_height: int):
    """Scale (image_width, image_height) to fit inside (dst_width, dst_height), preserving aspect ratio."""
    im_scale = min(dst_height / image_height, dst_width / image_width)
    return int(im_scale * image_width), int(im_scale * image_height)


def preprocess(image: PIL.Image.Image):
    """Resize an image to fit 512x512, pad to exactly 512x512, normalize to [-1, 1], and move channels first."""
    src_width, src_height = image.size
    dst_width, dst_height = scale_window(512, 512, src_width, src_height)
    image = np.array(image.resize((dst_width, dst_height), resample=PIL.Image.Resampling.LANCZOS))[None, :]
    pad_width = 512 - dst_width
    pad_height = 512 - dst_height
    pad = ((0, 0), (0, pad_height), (0, pad_width), (0, 0))
    image = np.pad(image, pad, mode="constant")
    image = image.astype(np.float32) / 255.0
    image = 2.0 * image - 1.0
    image = image.transpose(0, 3, 1, 2)
    return image, {"padding": pad, "src_width": src_width, "src_height": src_height}

class OVStableDiffusionPipeline(DiffusionPipeline):
    """Diffusion pipeline that runs the converted OpenVINO models for text encoding, denoising, and VAE steps."""

    def __init__(
        self,
        vae_decoder: Model,
        text_encoder: Model,
        tokenizer: CLIPTokenizer,
        unet: Model,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        vae_encoder: Model = None,
    ):
        super().__init__()
        self.scheduler = scheduler
        self.vae_decoder = vae_decoder
        self.vae_encoder = vae_encoder
        self.text_encoder = text_encoder
        self.unet = unet
        self._text_encoder_output = text_encoder.output(0)
        self._unet_output = unet.output(0)
        self._vae_d_output = vae_decoder.output(0)
        self._vae_e_output = vae_encoder.output(0) if vae_encoder is not None else None
        self.height = 512
        self.width = 512
        self.tokenizer = tokenizer
        
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: PIL.Image.Image = None,
        num_inference_steps: Optional[int] = 50,
        negative_prompt: Union[str, List[str]] = None,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        output_type: Optional[str] = "pil",
        seed: Optional[int] = None,
        strength: float = 1.0,
        gif: Optional[bool] = False,
        **kwargs,
    ):
        if seed is not None:
            np.random.seed(seed)

        img_buffer = []
        do_classifier_free_guidance = guidance_scale > 1.0
        # get prompt text embeddings
        text_embeddings = self._encode_prompt(prompt, do_classifier_free_guidance=do_classifier_free_guidance, negative_prompt=negative_prompt)
        
        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        if accepts_offset:
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
        latent_timestep = timesteps[:1]

        # get the initial random noise unless the user supplied it
        latents, meta = self.prepare_latents(image, latent_timestep)

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier-free guidance
            latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet([latent_model_input, t, text_embeddings])[self._unet_output]
            # perform classifier-free guidance:
            # noise = noise_uncond + guidance_scale * (noise_text - noise_uncond)
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred[0], noise_pred[1]
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs)["prev_sample"].numpy()
            if gif:
                image = self.vae_decoder(latents * (1 / 0.18215))[self._vae_d_output]
                image = self.postprocess_image(image, meta, output_type)
                img_buffer.extend(image)

        # scale and decode the image latents with vae
        image = self.vae_decoder(latents * (1 / 0.18215))[self._vae_d_output]

        image = self.postprocess_image(image, meta, output_type)
        return {"sample": image, 'iterations': img_buffer}
    
    def _encode_prompt(self, prompt:Union[str, List[str]], num_images_per_prompt:int = 1, do_classifier_free_guidance:bool = True, negative_prompt:Union[str, List[str]] = None):
        """
        Encodes the prompt into text encoder hidden states.

        Parameters:
            prompt (str or list(str)): prompt to be encoded
            num_images_per_prompt (int): number of images that should be generated per prompt
            do_classifier_free_guidance (bool): whether to use classifier free guidance or not
            negative_prompt (str or list(str)): negative prompt to be encoded
        Returns:
            text_embeddings (np.ndarray): text encoder hidden states
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        # tokenize input prompts
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        text_input_ids = text_inputs.input_ids

        text_embeddings = self.text_encoder(
            text_input_ids)[self._text_encoder_output]

        # duplicate text embeddings for each generation per prompt
        if num_images_per_prompt != 1:
            bs_embed, seq_len, _ = text_embeddings.shape
            text_embeddings = np.tile(
                text_embeddings, (1, num_images_per_prompt, 1))
            text_embeddings = np.reshape(
                text_embeddings, (bs_embed * num_images_per_prompt, seq_len, -1))

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            max_length = text_input_ids.shape[-1]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            else:
                uncond_tokens = negative_prompt
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="np",
            )

            uncond_embeddings = self.text_encoder(uncond_input.input_ids)[self._text_encoder_output]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = np.tile(uncond_embeddings, (1, num_images_per_prompt, 1))
            uncond_embeddings = np.reshape(uncond_embeddings, (batch_size * num_images_per_prompt, seq_len, -1))

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])

        return text_embeddings
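
    # (Sketch note) For a single prompt with guidance enabled, _encode_prompt
    # returns a (2, 77, 768) array: row 0 holds the unconditional embedding and
    # row 1 the text embedding, matching the batch-of-2 layout the UNet expects.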


    def prepare_latents(self, image:PIL.Image.Image = None, latent_timestep:torch.Tensor = None):
        """
        Function for getting initial latents for starting generation
        
        Parameters:
            image (PIL.Image.Image, *optional*, None):
                Input image for generation; if not provided, random noise will be used as the starting point
            latent_timestep (torch.Tensor, *optional*, None):
                Predicted by scheduler initial step for image generation, required for mixing the latent image with noise
        Returns:
            latents (np.ndarray):
                Image encoded in latent space
        """
        latents_shape = (1, 4, self.height // 8, self.width // 8)
        noise = np.random.randn(*latents_shape).astype(np.float32)
        if image is None:
            # if you use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                noise = noise * self.scheduler.sigmas[0].numpy()
            # without an input image, the noise itself is the starting latent
            return noise, {}
        input_image, meta = preprocess(image)
        latents = self.vae_encoder(input_image)[self._vae_e_output] * 0.18215
        latents = self.scheduler.add_noise(torch.from_numpy(latents), torch.from_numpy(noise), latent_timestep).numpy()
        return latents, meta
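
    # Note: 0.18215 is the latent scaling factor of the Stable Diffusion v1 VAE;
    # latents are multiplied by it after encoding and divided by it before decoding.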

    def postprocess_image(self, image:np.ndarray, meta:Dict, output_type:str = "pil"):
        """
        Postprocessing for the decoded image. Takes the image generated by the VAE decoder, unpads it to the initial image size (if required),
        normalizes and converts it to the [0, 255] pixel range. Optionally converts it from np.ndarray to PIL.Image format
        
        Parameters:
            image (np.ndarray):
                Generated image
            meta (Dict):
                Metadata obtained on latents preparing step, can be empty
            output_type (str, *optional*, pil):
                Output format for result, can be pil or numpy
        Returns:
            image (List of np.ndarray or PIL.Image.Image):
                Postprocessed images
        """
        if "padding" in meta:
            pad = meta["padding"]
            (_, end_h), (_, end_w) = pad[1:3]
            h, w = image.shape[2:]
            unpad_h = h - end_h
            unpad_w = w - end_w
            image = image[:, :, :unpad_h, :unpad_w]
        image = np.clip(image / 2 + 0.5, 0, 1)
        image = np.transpose(image, (0, 2, 3, 1))
        # convert to PIL images if requested
        if output_type == "pil":
            image = self.numpy_to_pil(image)
            if "src_height" in meta:
                orig_height, orig_width = meta["src_height"], meta["src_width"]
                image = [img.resize((orig_width, orig_height),
                                    PIL.Image.Resampling.LANCZOS) for img in image]
        else:
            if "src_height" in meta:
                orig_height, orig_width = meta["src_height"], meta["src_width"]
                image = [cv2.resize(img, (orig_width, orig_height))
                         for img in image]
        return image

    def get_timesteps(self, num_inference_steps:int, strength:float):
        """
        Helper function for getting scheduler timesteps for generation
        In case of image-to-image generation, it updates number of steps according to strength
        
        Parameters:
           num_inference_steps (int):
              number of inference steps for generation
           strength (float):
               value between 0.0 and 1.0, that controls the amount of noise that is added to the input image. 
               Values that approach 1.0 enable lots of variations but will also produce images that are not semantically consistent with the input.
        """
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]

        return timesteps, num_inference_steps - t_start 
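
# Note on get_timesteps (illustrative): with num_inference_steps=50 and
# strength=0.5, init_timestep = 25, so generation skips the first half of the
# schedule and denoises for the last 25 steps; strength=1.0 keeps all 50 steps.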


core = pv.Core()
import ipywidgets as widgets

# Choose an inference device for the compiled models.
device = widgets.Dropdown(
    options=core.available_devices + ["AUTO"],
    value="CPU",
    description="Device:",
    disabled=False,
)
device

text_enc = core.compile_model(text_encoder_path, device.value)
unet_model = core.compile_model(unet_path, device.value)

# Hint f32 inference for the VAE on non-CPU devices to avoid precision artifacts.
pv_config = {"INFERENCE_PRECISION_HINT": "f32"} if device.value != "CPU" else {}
vae_decoder = core.compile_model(VAE_DECODER_PATH, device.value, pv_config)
vae_encoder = core.compile_model(VAE_ENCODER_PATH, device.value, pv_config)
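
# Optional smoke test (a minimal sketch, not part of the original flow): run the
# compiled text encoder on dummy token ids to confirm the IR loads and infers.
# The (1, 77, 768) output shape assumes SD v1.4's CLIP ViT-L/14 text encoder.
dummy_ids = np.ones((1, 77), dtype=np.int64)
print(text_enc(dummy_ids)[text_enc.output(0)].shape)  # expected: (1, 77, 768)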

lms = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

pv_pipe = OVStableDiffusionPipeline(
    tokenizer=tokenizer,
    text_encoder=text_enc,
    unet=unet_model,
    vae_encoder=vae_encoder,
    vae_decoder=vae_decoder,
    scheduler=lms,
)

sample_text = "A dog wearing a rich golden men's necklace"
text_prompt = widgets.Text(value=sample_text, description="Prompt:")
num_steps = widgets.IntSlider(min=1, max=50, value=20, description="steps:")
seed = widgets.IntSlider(min=0, max=10000000, value=54, description="seed:")
widgets.VBox([text_prompt, seed, num_steps])

result = pv_pipe(text_prompt.value, num_inference_steps=num_steps.value, seed=seed.value)
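
# Show and save the generated image (follow-up sketch; "result.png" is an
# arbitrary file name, not from the original code).
final_image = result["sample"][0]
final_image.save("result.png")
final_image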





def generate_text(text, seed, num_steps, _=gr.Progress(track_tqdm=True)):
    """Text-to-image handler for the Gradio demo."""
    result = pv_pipe(text, num_inference_steps=num_steps, seed=seed)
    return result["sample"][0]


def generate_image(img, text, seed, num_steps, strength, _=gr.Progress(track_tqdm=True)):
    """Image-to-image handler, ready to wire up if an image tab is added."""
    result = pv_pipe(text, img, num_inference_steps=num_steps, seed=seed, strength=strength)
    return result["sample"][0]

with gr.Blocks() as demo:
    with gr.Tab("Zero-shot Text-to-Image Generation"):
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(lines=3, label="Text")
                seed_input = gr.Slider(0, 10000000, value=42, label="seed")
                steps_input = gr.Slider(1, 50, value=20, step=1, label="steps")
            out = gr.Image(label="Result", type="pil")
        btn = gr.Button()
        btn.click(generate_text, [text_input, seed_input, steps_input], out)
        gr.Examples([[sample_text, 42, 20]], [text_input, seed_input, steps_input])

try:
    demo.queue().launch(debug=True)
except Exception:
    # a plain local launch can fail in hosted environments; fall back to a share link
    demo.queue().launch(share=True, debug=True)