import base64
import io
from typing import Any, Dict, List

import torch
from diffusers import StableDiffusionUpscalePipeline
from PIL import Image


class EndpointHandler:
    """Endpoint handler that upscales an image with a Stable Diffusion pipeline.

    The pipeline is loaded once at construction time and reused for every
    request handled through ``__call__``.
    """

    def __init__(self, path: str):
        """Load the upscaling pipeline from ``path`` and move it to the GPU.

        Args:
            path: Location passed to
                ``StableDiffusionUpscalePipeline.from_pretrained``.
        """
        # Load the Stable Diffusion x4 upscaler model in half precision.
        self.pipeline = StableDiffusionUpscalePipeline.from_pretrained(
            path, torch_dtype=torch.float16
        )
        self.pipeline.to("cuda")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Upscale the image carried by a request payload.

        Args:
            data: Request payload with the keys:
                ``inputs`` - text prompt for the upscaling (defaults to "").
                ``image`` - the low-resolution image, either as raw bytes or
                as a base64-encoded string.

        Returns:
            A single-element list. On success the element is
            ``{"upscaled_image": <PNG bytes>}``; when the image is missing or
            cannot be decoded it is ``{"error": <message>}``.
        """
        prompt = data.get("inputs", "")
        image_payload = data.get("image", None)

        if image_payload is None:
            return [{"error": "No image provided"}]

        # Report undecodable input the same way as a missing image instead of
        # letting the exception escape the handler.
        try:
            low_res_img = self._decode_image(image_payload)
        except (OSError, ValueError, TypeError) as exc:
            return [{"error": f"Invalid image data: {exc}"}]

        upscaled_image = self.pipeline(prompt=prompt, image=low_res_img).images[0]

        return [{"upscaled_image": self._encode_png(upscaled_image)}]

    @staticmethod
    def _decode_image(image_payload: Any) -> Image.Image:
        """Turn raw bytes or a base64 string into an RGB PIL image."""
        if isinstance(image_payload, str):
            # JSON bodies cannot carry raw bytes, so a string payload is
            # treated as base64 text.
            image_payload = base64.b64decode(image_payload)
        return Image.open(io.BytesIO(image_payload)).convert("RGB")

    @staticmethod
    def _encode_png(image: Image.Image) -> bytes:
        """Serialize a PIL image to PNG and return the encoded bytes."""
        with io.BytesIO() as buffer:
            image.save(buffer, format="PNG")
            # getvalue() returns the whole buffer regardless of the current
            # stream position, so no seek is needed.
            return buffer.getvalue()