m7mdal7aj commited on
Commit
75a53d9
·
verified ·
1 Parent(s): 311b9c4

adding captioning folder and files

Browse files
my_model/captioner/captioning_config.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ # Configuration parameters
4
+ MODEL_TYPE = "i_blip"
5
+ PROMPT = "describe this image in details"
6
+ MAX_IMAGE_SIZE = 1024
7
+ MIN_LENGTH = 20
8
+ MAX_NEW_TOKENS = 100
9
+ MODEL_PATH = "Salesforce/instructblip-vicuna-7b"
10
+ LOAD_IN_8BIT = True
11
+ TORCH_DTYPE = torch.float16
12
+ DEVICE_MAP = "auto"
13
+ LOW_CPU_MEM_USAGE = True
14
+ SKIP_SPECIAL_TOKENS = True
my_model/captioner/image_captioning.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import PIL
4
+ from PIL import Image
5
+ from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
6
+ import bitsandbytes
7
+ import accelerate
8
+ import captioning_config as config
9
+
10
+
11
+ class ImageCaptioningModel:
12
+ def __init__(self):
13
+ self.model_type = config.MODEL_TYPE
14
+ self.processor = None
15
+ self.model = None
16
+ self.prompt = config.PROMPT
17
+ self.max_image_size = config.MAX_IMAGE_SIZE
18
+ self.min_length = config.MIN_LENGTH
19
+ self.max_new_tokens = config.MAX_NEW_TOKENS
20
+ self.model_path = config.MODEL_PATH
21
+ self.device_map = config.DEVICE_MAP
22
+ self.torch_dtype = config.TORCH_DTYPE
23
+ self.load_in_8bit = config.LOAD_IN_8BIT
24
+ self.low_cpu_mem_usage = config.LOW_CPU_MEM_USAGE
25
+ self.skip_secial_tokens = config.SKIP_SPECIAL_TOKENS
26
+
27
+
28
+
29
+ def load_model(self):
30
+ if self.model_type == 'i_blip':
31
+ self.processor = InstructBlipProcessor.from_pretrained(self.model_path,
32
+ load_in_8bit=self.load_in_8bit,
33
+ torch_dtype=self.torch_dtype,
34
+ device_map=self.device_map
35
+ )
36
+
37
+ self.model = InstructBlipForConditionalGeneration.from_pretrained(self.model_path,
38
+ load_in_8bit=self.load_in_8bit,
39
+ torch_dtype=self.torch_dtype,
40
+ low_cpu_mem_usage=self.low_cpu_mem_usage,
41
+ device_map=self.device_map
42
+ )
43
+
44
+
45
+ def resize_image(self, image, max_image_size=None):
46
+ if max_image_size is None:
47
+ max_image_size = int(os.getenv("MAX_IMAGE_SIZE", "1024"))
48
+ h, w = image.size
49
+ scale = max_image_size / max(h, w)
50
+
51
+ if scale < 1:
52
+ new_w = int(w * scale)
53
+ new_h = int(h * scale)
54
+ image = image.resize((new_w, new_h), resample=PIL.Image.Resampling.LANCZOS)
55
+
56
+ return image
57
+
58
+
59
+ def generate_caption(self, image_path):
60
+
61
+ image = Image.open(image_path)
62
+ image = self.resize_image(image)
63
+ inputs = self.processor(image, self.prompt, return_tensors="pt").to("cuda", self.torch_dtype)
64
+ outputs = self.model.generate(**inputs, min_length=self.min_length, max_new_tokens=self.max_new_tokens)
65
+ caption = self.processor.decode(outputs[0], skip_special_tokens=self.skip_secial_tokens).strip()
66
+
67
+ return caption
68
+
69
+ def generate_captions_for_multiple_images(self, image_paths):
70
+
71
+ return [self.generate_caption(image_path) for image_path in image_paths]
72
+
73
+
74
+ if __name__ == "__main__":
75
+ pass