Spaces:
Runtime error
Runtime error
chats-bug
commited on
Commit
·
a95ba86
1
Parent(s):
1d4f82c
Added fine-tuning options
Browse files
app.py
CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
|
|
2 |
import torch
|
3 |
from PIL import Image
|
4 |
|
5 |
-
from model import
|
6 |
|
7 |
MODELS = {
|
8 |
"Git-Base-COCO": GitBaseCocoModel,
|
@@ -12,33 +12,38 @@ MODELS = {
|
|
12 |
def generate_captions(
|
13 |
image,
|
14 |
num_captions,
|
|
|
15 |
max_length,
|
16 |
temperature,
|
17 |
top_k,
|
18 |
top_p,
|
19 |
repetition_penalty,
|
20 |
diversity_penalty,
|
21 |
-
model_name,
|
22 |
):
|
23 |
"""
|
24 |
Generates captions for the given image.
|
25 |
-
|
26 |
-----
|
27 |
Parameters:
|
28 |
image: PIL.Image
|
29 |
The image to generate captions for.
|
30 |
-
max_len: int
|
31 |
-
The maximum length of the caption.
|
32 |
num_captions: int
|
33 |
The number of captions to generate.
|
34 |
-
|
35 |
-----
|
36 |
Returns:
|
37 |
list[str]
|
38 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
41 |
-
|
42 |
model = MODELS[model_name](device)
|
43 |
|
44 |
captions = model.generate(
|
@@ -56,32 +61,34 @@ def generate_captions(
|
|
56 |
captions = "\n".join(captions)
|
57 |
return captions
|
58 |
|
59 |
-
title = "
|
60 |
-
description = "
|
61 |
|
62 |
interface = gr.Interface(
|
63 |
fn=generate_captions,
|
64 |
inputs=[
|
65 |
-
gr.
|
66 |
-
gr.
|
67 |
-
gr.
|
68 |
-
gr.
|
69 |
-
gr.
|
70 |
-
gr.
|
71 |
-
gr.
|
72 |
-
gr.
|
73 |
-
gr.
|
74 |
],
|
75 |
outputs=[
|
76 |
-
gr.
|
77 |
],
|
78 |
title=title,
|
79 |
description=description,
|
80 |
-
|
|
|
81 |
|
82 |
|
83 |
if __name__ == "__main__":
|
|
|
84 |
interface.launch(
|
85 |
enable_queue=True,
|
86 |
-
debug=True
|
87 |
)
|
|
|
2 |
import torch
|
3 |
from PIL import Image
|
4 |
|
5 |
+
from model import BlipBaseModel, GitBaseCocoModel
|
6 |
|
7 |
MODELS = {
|
8 |
"Git-Base-COCO": GitBaseCocoModel,
|
|
|
12 |
def generate_captions(
|
13 |
image,
|
14 |
num_captions,
|
15 |
+
model_name,
|
16 |
max_length,
|
17 |
temperature,
|
18 |
top_k,
|
19 |
top_p,
|
20 |
repetition_penalty,
|
21 |
diversity_penalty,
|
|
|
22 |
):
|
23 |
"""
|
24 |
Generates captions for the given image.
|
25 |
+
|
26 |
-----
|
27 |
Parameters:
|
28 |
image: PIL.Image
|
29 |
The image to generate captions for.
|
|
|
|
|
30 |
num_captions: int
|
31 |
The number of captions to generate.
|
32 |
+
** Rest of the parameters are the same as in the model.generate method. **
|
33 |
-----
|
34 |
Returns:
|
35 |
list[str]
|
36 |
"""
|
37 |
+
# Convert the numerical values to their corresponding types.
|
38 |
+
# Gradio Slider returns values as floats: except when the value is a whole number, in which case it returns an int.
|
39 |
+
# Only float values suffer from this issue.
|
40 |
+
temperature = float(temperature)
|
41 |
+
top_p = float(top_p)
|
42 |
+
repetition_penalty = float(repetition_penalty)
|
43 |
+
diversity_penalty = float(diversity_penalty)
|
44 |
|
45 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
46 |
+
|
47 |
model = MODELS[model_name](device)
|
48 |
|
49 |
captions = model.generate(
|
|
|
61 |
captions = "\n".join(captions)
|
62 |
return captions
|
63 |
|
64 |
+
title = "AI tool for generating captions for images"
|
65 |
+
description = "This tool uses pretrained models to generate captions for images."
|
66 |
|
67 |
interface = gr.Interface(
|
68 |
fn=generate_captions,
|
69 |
inputs=[
|
70 |
+
gr.components.Image(type="pil", label="Image"),
|
71 |
+
gr.components.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Captions to Generate"),
|
72 |
+
gr.components.Dropdown(MODELS.keys(), label="Model", value=list(MODELS.keys())[1]), # Default to Blip Base
|
73 |
+
gr.components.Slider(minimum=20, maximum=100, step=5, value=50, label="Maximum Caption Length"),
|
74 |
+
gr.components.Slider(minimum=0.1, maximum=10.0, step=0.1, value=1.0, label="Temperature"),
|
75 |
+
gr.components.Slider(minimum=1, maximum=100, step=1, value=50, label="Top K"),
|
76 |
+
gr.components.Slider(minimum=0.1, maximum=5.0, step=0.1, value=1.0, label="Top P"),
|
77 |
+
gr.components.Slider(minimum=1.0, maximum=10.0, step=0.1, value=2.0, label="Repetition Penalty"),
|
78 |
+
gr.components.Slider(minimum=0.0, maximum=10.0, step=0.1, value=2.0, label="Diversity Penalty"),
|
79 |
],
|
80 |
outputs=[
|
81 |
+
gr.components.Textbox(label="Caption"),
|
82 |
],
|
83 |
title=title,
|
84 |
description=description,
|
85 |
+
allow_flagging="never",
|
86 |
+
)
|
87 |
|
88 |
|
89 |
if __name__ == "__main__":
|
90 |
+
# Launch the interface.
|
91 |
interface.launch(
|
92 |
enable_queue=True,
|
93 |
+
debug=True,
|
94 |
)
|
model.py
CHANGED
@@ -7,26 +7,41 @@ class ImageCaptionModel:
|
|
7 |
processor,
|
8 |
model,
|
9 |
) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
self.device = device
|
11 |
self.processor = processor
|
12 |
self.model = model
|
13 |
self.model.to(self.device)
|
14 |
-
|
15 |
def generate(
|
16 |
self,
|
17 |
image,
|
18 |
-
num_captions=1,
|
19 |
-
max_length=50,
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
diversity_penalty=0.0,
|
26 |
):
|
27 |
"""
|
28 |
Generates captions for the given image.
|
29 |
-
|
30 |
-----
|
31 |
Parameters:
|
32 |
preprocessor: transformers.PreTrainedTokenizerFast
|
@@ -37,8 +52,6 @@ class ImageCaptionModel:
|
|
37 |
The image to generate captions for.
|
38 |
num_captions: int
|
39 |
The number of captions to generate.
|
40 |
-
num_beam_groups: int
|
41 |
-
The number of beam groups to use for beam search in order to maintain diversity. Must be between 1 and num_beams. 1 means no group_beam_search..
|
42 |
temperature: float
|
43 |
The temperature to use for sampling. The value used to module the next token probabilities that will be used by default in the generate method of the model. Must be strictly positive. Defaults to 1.0.
|
44 |
top_k: int
|
@@ -49,25 +62,45 @@ class ImageCaptionModel:
|
|
49 |
The parameter for repetition penalty. 1.0 means no penalty. Defaults to 1.0.
|
50 |
diversity_penalty: float
|
51 |
The parameter for diversity penalty. 0.0 means no penalty. Defaults to 0.0.
|
52 |
-
|
53 |
"""
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
temperature=temperature,
|
63 |
-
top_k=top_k,
|
64 |
-
top_p=top_p,
|
65 |
-
repetition_penalty=repetition_penalty,
|
66 |
-
diversity_penalty=diversity_penalty,
|
67 |
-
)
|
68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
generated_caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
|
70 |
-
generated_caption = [generated_caption[i] for i in range(0, num_captions*2, 2)]
|
71 |
|
72 |
return generated_caption
|
73 |
|
@@ -79,8 +112,8 @@ class GitBaseCocoModel(ImageCaptionModel):
|
|
79 |
|
80 |
-----
|
81 |
Parameters:
|
82 |
-
device:
|
83 |
-
The device to run the model on.
|
84 |
checkpoint: str
|
85 |
The checkpoint to load the model from.
|
86 |
|
@@ -93,42 +126,24 @@ class GitBaseCocoModel(ImageCaptionModel):
|
|
93 |
model = AutoModelForCausalLM.from_pretrained(checkpoint)
|
94 |
super().__init__(device, processor, model)
|
95 |
|
96 |
-
def generate(self, image, max_length=50, num_captions=1, **kwargs):
|
97 |
-
"""
|
98 |
-
Generates captions for the given image.
|
99 |
-
|
100 |
-
-----
|
101 |
-
Parameters:
|
102 |
-
image: PIL.Image
|
103 |
-
The image to generate captions for.
|
104 |
-
max_len: int
|
105 |
-
The maximum length of the caption.
|
106 |
-
num_captions: int
|
107 |
-
The number of captions to generate.
|
108 |
-
"""
|
109 |
-
captions = super().generate(image, max_length, num_captions, **kwargs)
|
110 |
-
return captions
|
111 |
-
|
112 |
|
113 |
class BlipBaseModel(ImageCaptionModel):
|
114 |
def __init__(self, device):
|
115 |
-
self.checkpoint = "Salesforce/blip-image-captioning-base"
|
116 |
-
processor = AutoProcessor.from_pretrained(self.checkpoint)
|
117 |
-
model = BlipForConditionalGeneration.from_pretrained(self.checkpoint)
|
118 |
-
super().__init__(device, processor, model)
|
119 |
-
|
120 |
-
def generate(self, image, max_length=50, num_captions=1, **kwargs):
|
121 |
"""
|
122 |
-
|
123 |
|
124 |
-----
|
125 |
Parameters:
|
126 |
-
|
127 |
-
The
|
128 |
-
|
129 |
-
The
|
130 |
-
|
131 |
-
|
|
|
|
|
132 |
"""
|
133 |
-
|
134 |
-
|
|
|
|
|
|
7 |
processor,
|
8 |
model,
|
9 |
) -> None:
|
10 |
+
"""
|
11 |
+
Initializes the model for generating captions for images.
|
12 |
+
|
13 |
+
-----
|
14 |
+
Parameters:
|
15 |
+
device: str
|
16 |
+
The device to use for the model. Must be either "cpu" or "cuda".
|
17 |
+
processor: transformers.AutoProcessor
|
18 |
+
The preprocessor to use for the model.
|
19 |
+
model: transformers.AutoModelForCausalLM or transformers.BlipForConditionalGeneration
|
20 |
+
The model to use for generating captions.
|
21 |
+
|
22 |
+
-----
|
23 |
+
Returns:
|
24 |
+
None
|
25 |
+
"""
|
26 |
self.device = device
|
27 |
self.processor = processor
|
28 |
self.model = model
|
29 |
self.model.to(self.device)
|
30 |
+
|
31 |
def generate(
|
32 |
self,
|
33 |
image,
|
34 |
+
num_captions: int = 1,
|
35 |
+
max_length: int = 50,
|
36 |
+
temperature: float = 1.0,
|
37 |
+
top_k: int = 50,
|
38 |
+
top_p: float = 1.0,
|
39 |
+
repetition_penalty: float = 1.0,
|
40 |
+
diversity_penalty: float = 0.0,
|
|
|
41 |
):
|
42 |
"""
|
43 |
Generates captions for the given image.
|
44 |
+
|
45 |
-----
|
46 |
Parameters:
|
47 |
preprocessor: transformers.PreTrainedTokenizerFast
|
|
|
52 |
The image to generate captions for.
|
53 |
num_captions: int
|
54 |
The number of captions to generate.
|
|
|
|
|
55 |
temperature: float
|
56 |
The temperature to use for sampling. The value used to module the next token probabilities that will be used by default in the generate method of the model. Must be strictly positive. Defaults to 1.0.
|
57 |
top_k: int
|
|
|
62 |
The parameter for repetition penalty. 1.0 means no penalty. Defaults to 1.0.
|
63 |
diversity_penalty: float
|
64 |
The parameter for diversity penalty. 0.0 means no penalty. Defaults to 0.0.
|
65 |
+
|
66 |
"""
|
67 |
+
# Type checking and making sure the values are valid.
|
68 |
+
assert type(num_captions) == int and num_captions > 0, "num_captions must be a positive integer."
|
69 |
+
assert type(max_length) == int and max_length > 0, "max_length must be a positive integer."
|
70 |
+
assert type(temperature) == float and temperature > 0.0, "temperature must be a positive float."
|
71 |
+
assert type(top_k) == int and top_k > 0, "top_k must be a positive integer."
|
72 |
+
assert type(top_p) == float and top_p > 0.0, "top_p must be a positive float."
|
73 |
+
assert type(repetition_penalty) == float and repetition_penalty >= 1.0, "repetition_penalty must be a positive float greater than or equal to 1."
|
74 |
+
assert type(diversity_penalty) == float and diversity_penalty >= 0.0, "diversity_penalty must be a non negative float."
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
+
pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device) # Convert the image to pixel values.
|
77 |
+
|
78 |
+
# Generate captions ids.
|
79 |
+
if num_captions == 1:
|
80 |
+
generated_ids = self.model.generate(
|
81 |
+
pixel_values=pixel_values,
|
82 |
+
max_length=max_length,
|
83 |
+
num_return_sequences=1,
|
84 |
+
temperature=temperature,
|
85 |
+
top_k=top_k,
|
86 |
+
top_p=top_p,
|
87 |
+
)
|
88 |
+
else:
|
89 |
+
generated_ids = self.model.generate(
|
90 |
+
pixel_values=pixel_values,
|
91 |
+
max_length=max_length,
|
92 |
+
num_beams=num_captions, # num_beams must be greater than or equal to num_captions and must be divisible by num_beam_groups.
|
93 |
+
num_beam_groups=num_captions, # num_beam_groups is set to equal to num_captions so that all the captions are diverse
|
94 |
+
num_return_sequences=num_captions, # generate multiple captions which are very similar to each other due to the grouping effect of beam search.
|
95 |
+
temperature=temperature,
|
96 |
+
top_k=top_k,
|
97 |
+
top_p=top_p,
|
98 |
+
repetition_penalty=repetition_penalty,
|
99 |
+
diversity_penalty=diversity_penalty,
|
100 |
+
)
|
101 |
+
|
102 |
+
# Decode the generated ids to get the captions.
|
103 |
generated_caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
|
|
|
104 |
|
105 |
return generated_caption
|
106 |
|
|
|
112 |
|
113 |
-----
|
114 |
Parameters:
|
115 |
+
device: str
|
116 |
+
The device to run the model on, either "cpu" or "cuda".
|
117 |
checkpoint: str
|
118 |
The checkpoint to load the model from.
|
119 |
|
|
|
126 |
model = AutoModelForCausalLM.from_pretrained(checkpoint)
|
127 |
super().__init__(device, processor, model)
|
128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
|
130 |
class BlipBaseModel(ImageCaptionModel):
|
131 |
def __init__(self, device):
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
"""
|
133 |
+
A wrapper class for the Blip-Base model. It is a pretrained model for image captioning.
|
134 |
|
135 |
-----
|
136 |
Parameters:
|
137 |
+
device: str
|
138 |
+
The device to run the model on, either "cpu" or "cuda".
|
139 |
+
checkpoint: str
|
140 |
+
The checkpoint to load the model from.
|
141 |
+
|
142 |
+
-----
|
143 |
+
Returns:
|
144 |
+
None
|
145 |
"""
|
146 |
+
self.checkpoint = "Salesforce/blip-image-captioning-base"
|
147 |
+
processor = AutoProcessor.from_pretrained(self.checkpoint)
|
148 |
+
model = BlipForConditionalGeneration.from_pretrained(self.checkpoint)
|
149 |
+
super().__init__(device, processor, model)
|