File size: 3,619 Bytes
75a53d9
 
 
 
 
 
 
a665e05
5f4a46b
 
75a53d9
 
 
 
 
 
 
 
 
 
 
 
 
 
178416a
75a53d9
 
 
 
 
 
398c0e8
 
 
 
75a53d9
 
 
1089b06
75a53d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f97cdd
 
 
 
 
 
 
5f4a46b
8f97cdd
75a53d9
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import torch
import PIL
from PIL import Image
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
import bitsandbytes
import accelerate
from my_model.captioner import captioning_config as config
from my_model.utilities import free_gpu_resources
    

class ImageCaptioningModel:
    """Caption images with an InstructBLIP model configured from ``captioning_config``.

    All settings are copied from the ``config`` module at construction time.
    The processor and model are loaded by :meth:`load_model`, or lazily on the
    first :meth:`generate_caption` call if it has not been called yet.
    """

    def __init__(self):
        self.model_type = config.MODEL_TYPE
        self.processor = None
        self.model = None
        self.prompt = config.PROMPT
        self.max_image_size = config.MAX_IMAGE_SIZE
        self.min_length = config.MIN_LENGTH
        self.max_new_tokens = config.MAX_NEW_TOKENS
        self.model_path = config.MODEL_PATH
        self.device_map = config.DEVICE_MAP
        self.torch_dtype = config.TORCH_DTYPE
        self.load_in_8bit = config.LOAD_IN_8BIT
        self.load_in_4bit = config.LOAD_IN_4BIT
        self.low_cpu_mem_usage = config.LOW_CPU_MEM_USAGE
        self.skip_special_tokens = config.SKIP_SPECIAL_TOKENS

    @property
    def skip_secial_tokens(self):
        """Deprecated misspelled alias of ``skip_special_tokens`` (kept for old callers)."""
        return self.skip_special_tokens

    @skip_secial_tokens.setter
    def skip_secial_tokens(self, value):
        self.skip_special_tokens = value

    def load_model(self):
        """Load the processor and model for ``self.model_type``.

        If both 8-bit and 4-bit loading are enabled, 8-bit wins and
        ``self.load_in_4bit`` is reset to ``False`` (with a warning).

        Raises:
            ValueError: if ``self.model_type`` is not ``'i_blip'``.
        """
        if self.model_type != 'i_blip':
            raise ValueError(f"Unsupported model type: {self.model_type!r}")

        if self.load_in_4bit and self.load_in_8bit:
            # Both quantization modes cannot be active at once; prefer 8-bit.
            import warnings
            warnings.warn(
                "Both LOAD_IN_8BIT and LOAD_IN_4BIT are set; falling back to 8-bit loading.",
                stacklevel=2,
            )
            self.load_in_4bit = False

        # The processor only does tokenization / image preprocessing, so the
        # model-loading options (quantization, dtype, device placement) are not passed to it.
        self.processor = InstructBlipProcessor.from_pretrained(self.model_path)

        self.model = InstructBlipForConditionalGeneration.from_pretrained(
            self.model_path,
            load_in_8bit=self.load_in_8bit,
            load_in_4bit=self.load_in_4bit,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=self.low_cpu_mem_usage,
            device_map=self.device_map,
        )

    def resize_image(self, image, max_image_size=None, resample=None):
        """Downscale *image* so its longest side is at most *max_image_size* pixels.

        The aspect ratio is preserved; images that already fit are returned unchanged.

        Args:
            image: A PIL image (anything exposing ``.size`` as ``(width, height)``
                and a ``.resize(size, resample=...)`` method).
            max_image_size: Longest-side limit in pixels. When ``None``, the
                ``MAX_IMAGE_SIZE`` environment variable is used if set, otherwise
                ``self.max_image_size`` (from the config), otherwise 1024.
            resample: Resampling filter for ``image.resize``; defaults to Lanczos.

        Returns:
            The resized image, or *image* itself when no downscaling is needed.
        """
        if max_image_size is None:
            env_size = os.getenv("MAX_IMAGE_SIZE")
            if env_size is not None:
                max_image_size = int(env_size)
            else:
                configured = getattr(self, "max_image_size", None)
                max_image_size = int(configured) if configured is not None else 1024

        # PIL reports size as (width, height); keep the names in that order so
        # the tuple handed to resize() is also (width, height).
        width, height = image.size
        longest = max(width, height)
        if longest <= 0:
            return image  # degenerate image: nothing to scale (avoids division by zero)

        scale = max_image_size / longest
        if scale < 1:
            if resample is None:
                resample = PIL.Image.Resampling.LANCZOS
            # max(1, ...) keeps extreme aspect ratios from collapsing a side to 0 px.
            new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
            image = image.resize(new_size, resample=resample)

        return image

    def generate_caption(self, image_path):
        """Return the generated caption for the image at *image_path*.

        Args:
            image_path: Anything ``PIL.Image.open`` accepts (path or binary file object).

        Returns:
            The decoded caption with surrounding whitespace stripped.
        """
        if self.processor is None or self.model is None:
            self.load_model()

        with Image.open(image_path) as opened:
            # copy() forces the pixel data to load into an independent image,
            # so the underlying file can be closed here without invalidating it.
            image = opened.copy()
        image = self.resize_image(image)

        # Send inputs to wherever device_map placed the model instead of assuming CUDA.
        inputs = self.processor(image, self.prompt, return_tensors="pt").to(self.model.device, self.torch_dtype)
        outputs = self.model.generate(**inputs, min_length=self.min_length, max_new_tokens=self.max_new_tokens)
        return self.processor.decode(outputs[0], skip_special_tokens=self.skip_special_tokens).strip()

    def generate_captions_for_multiple_images(self, image_paths):
        """Caption each image in *image_paths* in order and return the captions as a list."""
        return [self.generate_caption(image_path) for image_path in image_paths]
        

def get_caption(img):
    """Build a captioner, load its model, and return the caption for *img*.

    A new ``ImageCaptioningModel`` is constructed and loaded on every call;
    callers captioning many images should reuse one instance instead.
    """
    captioning_model = ImageCaptioningModel()
    captioning_model.load_model()
    return captioning_model.generate_caption(img)


if __name__ == "__main__":
    # Nothing runs when this module is executed directly; import it and call get_caption() instead.
    ...