decomp-diffusion / upsampling.py
jsu27's picture
added upsampling
a6049d1
raw
history blame
1.07 kB
import torch as th
from diffusers import IFImg2ImgSuperResolutionPipeline
from transformers import T5EncoderModel
from PIL import Image
import numpy as np
def get_pipeline():
    """Build the DeepFloyd IF stage-II img2img super-resolution pipeline.

    The T5 text encoder is loaded separately from the stage-I checkpoint
    ("DeepFloyd/IF-I-XL-v1.0", subfolder "text_encoder") with
    ``load_in_8bit=True`` and then handed to the stage-II pipeline
    ("DeepFloyd/IF-II-L-v1.0"), which is loaded from its fp16 variant
    with ``torch_dtype=float16``.

    Returns:
        The configured ``IFImg2ImgSuperResolutionPipeline`` instance.
    """
    # Options for the 8-bit text encoder; device placement is delegated
    # to the loader via device_map="auto".
    encoder_options = dict(
        subfolder="text_encoder",
        device_map="auto",
        load_in_8bit=True,
        variant="8bit",
    )
    t5_encoder = T5EncoderModel.from_pretrained(
        "DeepFloyd/IF-I-XL-v1.0", **encoder_options
    )

    # watermarker=None: no watermarker component is supplied to the pipeline
    # (presumably so outputs are not watermarked -- confirm against diffusers).
    return IFImg2ImgSuperResolutionPipeline.from_pretrained(
        "DeepFloyd/IF-II-L-v1.0",
        text_encoder=t5_encoder,
        variant="fp16",
        torch_dtype=th.float16,
        device_map="auto",
        watermarker=None,
    )
def upscale_image(im, pipe):
    """Upscale ``im`` with the given super-resolution pipeline.

    Args:
        im: Input image; expected to be a 64x64 PIL image.
        pipe: Pipeline object exposing ``encode_prompt(prompt)`` and a
            ``__call__`` that returns an object with an ``images`` list
            (e.g. the pipeline returned by ``get_pipeline``).

    Returns:
        The first image produced by the pipeline.
    """
    # An empty prompt is used: the upscaling is not steered by any text.
    prompt = ''
    prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)

    # Fixed seed so repeated calls draw the same noise from this generator.
    generator = th.Generator().manual_seed(0)

    # The same input serves both as the image to upscale and as the
    # img2img reference. (The original code referenced an undefined name
    # `original_image` here, which raised NameError on every call.)
    result = pipe(
        image=im,
        original_image=im,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_embeds,
        generator=generator,
    )
    return result.images[0]