Diffusers
Safetensors
StableDiffusionUpscalePipeline
stable-diffusion
yanis9351's picture
Upload handler.py
9b535c8 verified
raw
history blame
1.5 kB
import base64
import io
from typing import Dict, List, Any

import torch
from diffusers import StableDiffusionUpscalePipeline
from PIL import Image
class EndpointHandler:
    """Request handler that wraps a Stable Diffusion upscale pipeline.

    The pipeline is loaded once in ``__init__`` and reused by every call to
    the handler instance.
    """

    def __init__(self, path: str):
        """Load the upscaler pipeline stored at *path*.

        Args:
            path: Location of the pretrained pipeline; forwarded unchanged to
                ``StableDiffusionUpscalePipeline.from_pretrained``.
        """
        # Half precision is only requested together with a CUDA device.
        # Without a GPU the original hard-coded "cuda" move would crash at
        # start-up, so fall back to CPU with full precision instead.
        if torch.cuda.is_available():
            device, dtype = "cuda", torch.float16
        else:
            device, dtype = "cpu", torch.float32

        self.pipeline = StableDiffusionUpscalePipeline.from_pretrained(
            path,
            torch_dtype=dtype,
        )
        self.pipeline.to(device)

    @staticmethod
    def _decode_image_payload(payload: Any) -> bytes:
        """Normalize the request's ``image`` field to non-empty raw bytes.

        Args:
            payload: Either a bytes-like object holding the encoded image, or
                a base64 string of those bytes (the only form a JSON request
                body can carry).

        Returns:
            The encoded image as ``bytes``.

        Raises:
            ValueError: If the payload has an unsupported type, is not valid
                base64, or decodes to nothing.
        """
        if isinstance(payload, str):
            # Drop line breaks/spaces first so wrapped base64 is accepted
            # while any other stray character is still rejected.
            compact = "".join(payload.split())
            try:
                raw = base64.b64decode(compact, validate=True)
            except ValueError as exc:  # binascii.Error subclasses ValueError
                raise ValueError("Image is not valid base64 data") from exc
        elif isinstance(payload, (bytes, bytearray, memoryview)):
            raw = bytes(payload)
        else:
            raise ValueError("Image must be bytes or a base64-encoded string")

        if not raw:
            raise ValueError("Image payload is empty")
        return raw

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Upscale the image carried by a request payload.

        Args:
            data: Request payload. Recognized keys:
                ``inputs`` - text prompt for the upscaling; a missing or
                empty value is treated as an empty prompt.
                ``image`` - the low-resolution image, as encoded image bytes
                or as a base64 string of those bytes.

        Returns:
            A one-element list. On success the element is
            ``{"upscaled_image": <PNG-encoded bytes>}``. If the image is
            missing or cannot be decoded, the element is
            ``{"error": <message>}``.
        """
        prompt = data.get("inputs") or ""
        payload = data.get("image")
        if payload is None:
            return [{"error": "No image provided"}]

        try:
            image_bytes = self._decode_image_payload(payload)
        except ValueError as exc:
            return [{"error": str(exc)}]

        try:
            low_res_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except OSError:
            # PIL reports unreadable or truncated image data as OSError
            # (UnidentifiedImageError is a subclass); report it the same way
            # as a missing image instead of letting the request blow up.
            return [{"error": "Image data could not be decoded"}]

        # The pipeline returns a batch of images; one input yields one output.
        upscaled_image = self.pipeline(prompt=prompt, image=low_res_img).images[0]

        # Serialize the result as PNG. getvalue() returns the whole buffer
        # regardless of the stream position, so no seek is required.
        buffer = io.BytesIO()
        upscaled_image.save(buffer, format="PNG")
        return [{"upscaled_image": buffer.getvalue()}]