import PIL.Image
from controlnet_aux import NormalBaeDetector


class Preprocessor:
    """Lazily load and apply a ControlNet annotator model.

    Only the ``"NormalBae"`` annotator is currently supported. Its weights are
    loaded from the ``MODEL_ID`` repository via ``from_pretrained`` the first
    time :meth:`load` is called with that name; calling :meth:`load` again
    with the same name while a model is loaded is a no-op.
    """

    # Repository the annotator weights are loaded from via ``from_pretrained``.
    MODEL_ID = "lllyasviel/Annotators"

    # Annotator names accepted by ``load`` (used in the error message).
    SUPPORTED_NAMES = ("NormalBae",)

    def __init__(self, device: str = "cuda") -> None:
        """Create an empty preprocessor; no model is loaded until ``load``.

        Args:
            device: Target passed to the model's ``.to(...)`` after loading.
                Defaults to ``"cuda"``, the previously hard-coded value.
        """
        self.model = None  # the loaded annotator, or None before ``load``
        self.name = ""  # name of the loaded annotator; "" when none loaded
        self.device = device

    def load(self, name: str) -> None:
        """Load the annotator called *name* unless it is already loaded.

        Args:
            name: Annotator to load; must be one of ``SUPPORTED_NAMES``.

        Raises:
            ValueError: If *name* is not supported. Any previously loaded
                model and its name are left untouched.
        """
        # Also require a model to be present: otherwise ``load("")`` on a
        # fresh instance would match the initial empty name and silently
        # return without loading anything.
        if name == self.name and self.model is not None:
            return
        if name == "NormalBae":
            print("Loading NormalBae")
            model = NormalBaeDetector.from_pretrained(self.MODEL_ID).to(self.device)
        else:
            raise ValueError(
                f"Unsupported preprocessor {name!r}; "
                f"expected one of {self.SUPPORTED_NAMES}"
            )
        # Update state only after a successful load so a failure cannot leave
        # ``name`` and ``model`` out of sync.
        self.model = model
        self.name = name

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
        """Run the loaded annotator on *image*.

        Extra keyword arguments are forwarded unchanged to the model call, and
        the model's result is returned as-is. NOTE(review): the
        ``PIL.Image.Image`` return annotation assumes the detector's default
        output type; kwargs that change the output format may return
        something else — confirm against callers.

        Raises:
            RuntimeError: If no annotator has been loaded with ``load`` yet.
        """
        if self.model is None:
            raise RuntimeError("No preprocessor loaded; call load() first.")
        return self.model(image, **kwargs)