# API_SDXLLightning / handler.py
# Author: UAI-Software. Uploaded (folder) using huggingface_hub, commit 2c86436 (verified).
import os
from typing import Dict, List, Any
import sys
rootDir = os.path.abspath(os.path.dirname(__file__))
sys.path.append(rootDir)
from imageRequest import ImageRequest
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import torch
class EndpointHandler:
    """Inference endpoint serving an SDXL pipeline whose UNet is SDXL-Lightning.

    One pipeline is kept on the GPU and reused across requests. It is rebuilt
    only when a request asks for a different base model.
    """

    # Base model loaded at start-up and used when a request asks for "default".
    DEFAULT_MODEL = "SG161222/RealVisXL_V4.0"
    # Hub repo holding the distilled Lightning UNet weights.
    LIGHTNING_REPO = "ByteDance/SDXL-Lightning"
    # Use the checkpoint that matches the step count requested at inference time!
    LIGHTNING_CKPT = "sdxl_lightning_8step_unet.safetensors"

    def __init__(self, path=""):
        """Preload the default pipeline so the first request does not pay the load cost.

        Args:
            path: Model directory path. Unused here: models are fetched from the
                Hugging Face Hub by name.
        """
        self.pipe = None
        # Name of the base model currently held in self.pipe ("" = none loaded).
        self.modelName = ""
        baseReq = ImageRequest()
        baseReq.model = self.DEFAULT_MODEL
        self.LoadModel(baseReq)

    def LoadModel(self, request):
        """Ensure ``self.pipe`` holds a pipeline built for ``request.model``.

        ``"default"`` (or an empty/``None`` model name) is resolved to
        ``DEFAULT_MODEL`` and written back to ``request.model``. The pipeline is
        rebuilt only when nothing is loaded yet or a different model is requested.
        """
        if not request.model or request.model == "default":
            request.model = self.DEFAULT_MODEL
        base = request.model
        if self.pipe is not None and self.modelName == base:
            # Already loaded: reuse it instead of reloading several GB of weights.
            return
        # Drop the previous pipeline first so its GPU tensors can be released
        # before the new weights are allocated (avoids holding two models at once).
        self.pipe = None
        self.modelName = ""
        self.pipe = self._build_pipeline(base)
        # Record what is loaded; the cache check above depends on it.
        self.modelName = base

    def _build_pipeline(self, base):
        """Build an fp16 CUDA SDXL pipeline for ``base`` using the Lightning UNet."""
        # from_config only instantiates the UNet architecture from ``base``'s
        # config; its weights then come entirely from the Lightning checkpoint.
        # The text encoders and VAE are loaded from ``base``.
        unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
        unet.load_state_dict(
            load_file(hf_hub_download(self.LIGHTNING_REPO, self.LIGHTNING_CKPT), device="cuda")
        )
        pipe = StableDiffusionXLPipeline.from_pretrained(
            base, unet=unet, torch_dtype=torch.float16, variant="fp16"
        ).to("cuda")
        # Ensure sampler uses "trailing" timesteps, as SDXL-Lightning requires.
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config, timestep_spacing="trailing"
        )
        return pipe

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Args:
            data: Request payload. If it has an ``"inputs"`` key, that value
                (popped from ``data``) is the request. Otherwise ``data`` itself
                is. The request is parsed by ``ImageRequest.FromDict``. Documented
                fields: input, seed, prompt, negative_prompt, steps,
                guidance_scale, width, height, model, plus extra kwargs.
                NOTE(review): only ``model``, ``prompt``, ``negative_prompt`` and
                ``steps`` are used by this handler. seed, width, height and
                guidance_scale are currently ignored (CFG is fixed at 0).

        Returns:
            ``{"media": [{"media": <base64 PNG>}, ...]}``, one entry per image.
        """
        inputs = data.pop("inputs", data)
        request = ImageRequest.FromDict(inputs)
        return self.__runProcess__(request)

    def ImageToBase64(self, image):
        """Encode ``image`` as PNG and return the bytes as a base64 ``str``.

        ``image`` is presumably a PIL image from the pipeline. Any object with a
        ``save(fp, format=...)`` method works.
        """
        import base64
        import io

        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()

    def __runProcess__(self, request: "ImageRequest") -> Dict[str, Any]:
        """Generate images for ``request`` with the SDXL-Lightning pipeline.

        Loads or switches the pipeline if needed, then runs it with
        ``guidance_scale=0`` (CFG disabled).

        Returns:
            ``{"media": [{"media": <base64 PNG>}, ...]}``.
        """
        self.LoadModel(request)
        # NOTE(review): the UNet is the 8-step Lightning checkpoint
        # (LIGHTNING_CKPT). request.steps presumably should be 8; confirm
        # ImageRequest's default. With guidance_scale=0, diffusers skips
        # classifier-free guidance, so negative_prompt has no effect.
        images = self.pipe(
            request.prompt,
            negative_prompt=request.negative_prompt,
            num_inference_steps=request.steps,
            guidance_scale=0,
        ).images
        return {"media": [{"media": self.ImageToBase64(img)} for img in images]}