import os
import torch
import PIL
from PIL import Image
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
import bitsandbytes
import accelerate
from my_model.captioner import captioning_config as config
from my_model.utilities import free_gpu_resources


class ImageCaptioningModel:
    """Generate image captions with an InstructBLIP model.

    All settings are read from ``my_model.captioner.captioning_config`` when the
    instance is created. The processor and model are not loaded until
    :meth:`load_model` is called.
    """

    def __init__(self):
        self.model_type = config.MODEL_TYPE
        self.processor = None
        self.model = None
        self.prompt = config.PROMPT
        self.max_image_size = config.MAX_IMAGE_SIZE
        self.min_length = config.MIN_LENGTH
        self.max_new_tokens = config.MAX_NEW_TOKENS
        self.model_path = config.MODEL_PATH
        self.device_map = config.DEVICE_MAP
        self.torch_dtype = config.TORCH_DTYPE
        self.load_in_8bit = config.LOAD_IN_8BIT
        self.load_in_4bit = config.LOAD_IN_4BIT
        self.low_cpu_mem_usage = config.LOW_CPU_MEM_USAGE
        self.skip_special_tokens = config.SKIP_SPECIAL_TOKENS

    @property
    def skip_secial_tokens(self):
        """Backward-compatible (misspelled) alias of ``skip_special_tokens``."""
        return self.skip_special_tokens

    @skip_secial_tokens.setter
    def skip_secial_tokens(self, value):
        self.skip_special_tokens = value

    def load_model(self):
        """Load the processor and model for ``self.model_type`` from ``self.model_path``.

        Raises:
            ValueError: If ``self.model_type`` is not a supported model type.
        """
        if self.load_in_4bit and self.load_in_8bit:
            # Both quantization modes cannot be active at once (set by mistake); keep 8-bit.
            self.load_in_4bit = False

        if self.model_type != 'i_blip':
            raise ValueError(f"Unsupported model type: {self.model_type!r}")

        # Quantization, dtype and device placement concern the model weights only,
        # so they are passed to the model loader and not to the processor.
        self.processor = InstructBlipProcessor.from_pretrained(self.model_path)
        self.model = InstructBlipForConditionalGeneration.from_pretrained(
            self.model_path,
            load_in_8bit=self.load_in_8bit,
            load_in_4bit=self.load_in_4bit,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=self.low_cpu_mem_usage,
            device_map=self.device_map,
        )

    def resize_image(self, image, max_image_size=None):
        """Downscale ``image`` so its longer side is at most ``max_image_size`` pixels.

        The aspect ratio is preserved. Images already within the limit are
        returned unchanged.

        Args:
            image: A PIL image.
            max_image_size: Maximum length of the longer side. If ``None``, the
                ``MAX_IMAGE_SIZE`` environment variable is used when set; otherwise
                ``self.max_image_size`` (from the config) is used, falling back to 1024.

        Returns:
            The (possibly resized) PIL image.
        """
        if max_image_size is None:
            env_size = os.getenv("MAX_IMAGE_SIZE")
            if env_size is not None:
                max_image_size = int(env_size)
            else:
                max_image_size = getattr(self, "max_image_size", None) or 1024

        # PIL reports size as (width, height), and resize() expects the same order.
        width, height = image.size
        scale = max_image_size / max(width, height)
        if scale < 1:
            # Clamp to 1 px so extreme aspect ratios never produce a zero-sized side.
            new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
            image = image.resize(new_size, resample=Image.Resampling.LANCZOS)
        return image

    def generate_caption(self, image_path):
        """Generate a caption for a single image.

        Args:
            image_path: Path to an image file, or an already-opened ``PIL.Image.Image``.

        Returns:
            The decoded caption, stripped of surrounding whitespace.

        Raises:
            RuntimeError: If :meth:`load_model` has not been called successfully.
        """
        if self.processor is None or self.model is None:
            raise RuntimeError("Model is not loaded; call load_model() first.")

        if isinstance(image_path, Image.Image):
            image = image_path.convert("RGB")
        else:
            # Decode into memory, then release the underlying file handle.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")
        image = self.resize_image(image)

        # Move inputs to wherever the model was placed instead of assuming "cuda".
        inputs = self.processor(image, self.prompt, return_tensors="pt").to(
            self.model.device, self.torch_dtype
        )
        outputs = self.model.generate(
            **inputs,
            min_length=self.min_length,
            max_new_tokens=self.max_new_tokens,
        )
        caption = self.processor.decode(
            outputs[0], skip_special_tokens=self.skip_special_tokens
        ).strip()
        return caption

    def generate_captions_for_multiple_images(self, image_paths):
        """Return one caption per entry of ``image_paths``, in the same order."""
        return [self.generate_caption(image_path) for image_path in image_paths]


def get_caption(img):
    """Load a fresh captioner and caption a single image.

    Note: this loads the model on every call. For many images, create one
    ``ImageCaptioningModel``, call ``load_model()`` once, and reuse it.
    """
    captioner = ImageCaptioningModel()
    captioner.load_model()
    caption = captioner.generate_caption(img)
    return caption


if __name__ == "__main__":
    pass