|
import base64
import io
from typing import Any, Dict, List, Optional

import torch
from diffusers import StableDiffusionUpscalePipeline
from PIL import Image
|
|
|
|
class EndpointHandler:
    """Inference handler that upscales an image with a Stable Diffusion upscale pipeline.

    The pipeline is loaded once at construction time. Each call decodes one
    low-resolution image, upscales it (optionally guided by a text prompt) and
    returns the result encoded as PNG bytes.
    """

    def __init__(self, path: str, device: Optional[str] = None):
        """Load the upscale pipeline from ``path`` and move it to ``device``.

        Args:
            path: Model location passed straight to
                ``StableDiffusionUpscalePipeline.from_pretrained``.
            device: Torch device string. Defaults to ``"cuda"`` when a GPU is
                available and ``"cpu"`` otherwise (the original handler always
                used ``"cuda"`` and failed on CPU-only hosts).
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision is a GPU optimisation; on CPU many float16 ops are
        # unsupported or very slow, so fall back to float32 there.
        dtype = torch.float16 if device.startswith("cuda") else torch.float32
        self.pipeline = StableDiffusionUpscalePipeline.from_pretrained(
            path,
            torch_dtype=dtype,
        )
        self.pipeline.to(device)

    @staticmethod
    def _to_image_bytes(image_input: Any) -> bytes:
        """Normalise the ``image`` payload to raw encoded image bytes.

        Accepts a bytes-like object (returned unchanged as ``bytes``) or a
        base64 string, optionally carrying a ``data:<mime>;base64,`` prefix,
        since raw bytes cannot be carried inside a JSON request body.

        Raises:
            ValueError: if a string payload is not valid base64
                (``binascii.Error`` is a ``ValueError`` subclass).
            TypeError: for any other payload type.
        """
        if isinstance(image_input, (bytes, bytearray, memoryview)):
            return bytes(image_input)
        if isinstance(image_input, str):
            # Base64 never contains ",", so this strips an optional data-URI
            # header and is a no-op for a plain base64 string.
            _, _, encoded = image_input.rpartition(",")
            return base64.b64decode(encoded, validate=True)
        raise TypeError(
            f"unsupported image payload type: {type(image_input).__name__}"
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Upscale the image contained in ``data``.

        data args:
            inputs: str - Optional text prompt guiding the upscaling
                (missing or ``None`` means an empty prompt).
            image: bytes | str - The low-resolution image, either as raw
                encoded bytes or as a base64 string (optionally a data URI).

        Return:
            A one-element list: ``[{"upscaled_image": <PNG bytes>}]`` on
            success, or ``[{"error": <message>}]`` when the image is missing,
            empty, or cannot be decoded.
        """
        prompt = data.get("inputs")
        if prompt is None:
            prompt = ""

        image_input = data.get("image")
        if image_input is None:
            return [{"error": "No image provided"}]

        try:
            image_bytes = self._to_image_bytes(image_input)
        except (TypeError, ValueError) as err:
            return [{"error": f"Invalid image: {err}"}]
        if not image_bytes:
            return [{"error": "No image provided"}]

        try:
            # Image.open is lazy; convert() forces the decode, so both belong
            # in the try. PIL.UnidentifiedImageError subclasses OSError.
            low_res_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except OSError as err:
            return [{"error": f"Invalid image: {err}"}]

        upscaled_image = self.pipeline(prompt=prompt, image=low_res_img).images[0]

        # getvalue() returns the whole buffer regardless of position, so no
        # seek(0) is needed.
        with io.BytesIO() as buffer:
            upscaled_image.save(buffer, format="PNG")
            png_bytes = buffer.getvalue()
        return [{"upscaled_image": png_bytes}]
|
|
|