import gc

import PIL.Image
import torch

from controlnet_aux_local import NormalBaeDetector  # , CannyDetector


class Preprocessor:
    """Hold at most one image-annotator model and apply it to images.

    Call ``load(name)`` to select a model, then call the instance on an image.
    Only the ``"NormalBae"`` annotator is currently supported; it is loaded
    from the ``MODEL_ID`` checkpoint and moved to ``self.device``.
    """

    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self, device: str = "cuda"):
        """Create an empty preprocessor.

        Args:
            device: Torch device the loaded model is moved to. Defaults to
                ``"cuda"``, matching the previous hard-coded behavior.
        """
        self.model = None  # currently loaded annotator, or None
        self.name = ""  # name of the loaded annotator; "" when none is loaded
        self.device = device

    def _release(self) -> None:
        """Drop the current model and return its cached GPU memory to the driver."""
        # Clear the name together with the model so a failed reload never leaves
        # `name` claiming a model that is no longer held.
        self.model = None
        self.name = ""
        # Collect first so the model's tensors are actually freed; only then can
        # empty_cache() hand their blocks back.
        gc.collect()
        torch.cuda.empty_cache()

    def load(self, name: str) -> None:
        """Load the annotator called *name*, replacing any previously loaded one.

        Re-loading the currently loaded name is a no-op.

        Raises:
            ValueError: if *name* is not a supported preprocessor. In that case
                the currently loaded model is left untouched.
        """
        if name == self.name:
            return
        if name != "NormalBae":
            raise ValueError(f"Unsupported preprocessor: {name!r}")

        # Free the previous model *before* allocating the new one, so that both
        # do not have to fit in device memory at the same time.
        self._release()

        print("Loading NormalBae")
        self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID).to(self.device)
        self.name = name

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
        """Run the loaded annotator on *image*, forwarding **kwargs to it.

        Raises:
            RuntimeError: if no model has been loaded yet.
        """
        if self.model is None:
            raise RuntimeError("No preprocessor loaded; call load() first")
        return self.model(image, **kwargs)