SDXL_NIJI_SIX / handler.py
import base64
import io
from typing import Any, Dict, List

import numpy as np
import torch
from diffusers import StableDiffusionXLPipeline
from PIL import Image


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Preload everything needed at inference: the SDXL base checkpoint
        # plus the Niji LoRA adapter.
        model_id = "selectmixer/SDXL_NIJI_SIX"
        # SDXL checkpoints require the XL pipeline class; the plain
        # StableDiffusionPipeline cannot load them. fp16 halves GPU memory
        # on the CUDA device this handler assumes.
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            model_id, torch_dtype=torch.float16
        )
        self.pipe.to("cuda")
        self.pipe.load_lora_weights(
            "selectmixer/LORA_NIJI_DLCV6_SDXL",
            weight_name="SDXL_Niji_V6_DLC_LoRa_V2.safetensors",
            adapter_name="test",
        )
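
        # Note (assumption: a diffusers version with the PEFT backend
        # enabled): the LoRA weight could also be fixed once at load time
        # instead of per call, e.g.
        #   self.pipe.set_adapters(["test"], adapter_weights=[0.5])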

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`str`): text prompt for text-to-image generation
            kwargs (:obj:`dict`, optional): extra pipeline arguments such as
                num_inference_steps or guidance_scale, plus the handler-level
                lora_scale (LoRA strength) and seed (reproducibility) keys
        Return:
            A :obj:`list` of :obj:`dict` with one base64-encoded PNG per
            generated image, ready for JSON serialization
        """
        inputs = data.get("inputs")
        kwargs = data.get("kwargs", {})

        # StableDiffusionXLPipeline is text-to-image only, so the input must
        # be a prompt string; PIL or NumPy images would need a separate
        # img2img pipeline and are rejected explicitly.
        if isinstance(inputs, str):
            prompt = inputs
        elif isinstance(inputs, (Image.Image, np.ndarray)):
            raise ValueError(
                "Image inputs are not supported by this text-to-image "
                "handler; send a prompt string instead."
            )
        else:
            raise ValueError(f"Unsupported input type: {type(inputs)}")
        # Pop handler-level keys so they are not forwarded to the pipeline,
        # which would reject them as unknown keyword arguments.
        lora_scale = kwargs.pop("lora_scale", 0.5)
        seed = kwargs.pop("seed", None)
        # Optional fixed seed for reproducible output (from the template's
        # pseudo-example above).
        generator = None
        if seed is not None:
            generator = torch.Generator("cuda").manual_seed(seed)

        # Generate images; cross_attention_kwargs={"scale": ...} controls
        # the LoRA influence for this call.
        results = self.pipe(
            prompt,
            cross_attention_kwargs={"scale": lora_scale},
            generator=generator,
            **kwargs,
        )

        # PIL images are not JSON-serializable, so encode each one as a
        # base64 PNG string.
        output = []
        for image in results.images:
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
            output.append({"generated_image": encoded})
        return output
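

# Minimal local smoke test (assumes a CUDA GPU and access to the Hugging
# Face Hub; the prompt and kwargs below are only illustrative, with the
# seed taken from the template's pseudo-example).
if __name__ == "__main__":
    handler = EndpointHandler()
    outputs = handler(
        {
            "inputs": "astronaut in a jungle, muted colors, detailed",
            "kwargs": {"lora_scale": 0.7, "num_inference_steps": 30, "seed": 31},
        }
    )
    # Decode the first base64 PNG back into a PIL image and save it locally.
    png_bytes = base64.b64decode(outputs[0]["generated_image"])
    Image.open(io.BytesIO(png_bytes)).save("sample.png")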