Upload 6 files
- .gitattributes +3 -35
- .gitignore +3 -0
- App.py +9 -0
- Model.py +279 -0
- Predict.py +35 -0
- last_checkpoint.pt +3 -0
.gitattributes
CHANGED
@@ -1,35 +1,3 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+.pt filter=lfs diff=lfs merge=lfs -text
+Checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,3 @@
/Main.ipynb/
/Dataset/
/Checkpoint/
App.py
ADDED
@@ -0,0 +1,9 @@
import gradio as gr
from Predict import generate_caption

interface = gr.Interface(
    fn=generate_caption,
    inputs=[gr.components.Image(), gr.components.Textbox(label="Question")],
    outputs=[gr.components.Textbox(label="Answer", lines=3)],
)
interface.launch(share=True, debug=True)
Model.py
ADDED
@@ -0,0 +1,279 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, AutoFeatureExtractor
import numpy as np
import math
import warnings
warnings.filterwarnings("ignore")


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vision_model_name = "google/vit-base-patch16-224-in21k"
language_model_name = "vinai/phobert-base"


def generate_padding_mask(sequences, padding_idx):
    if sequences is None:
        return None
    if len(sequences.shape) == 2:
        __seq = sequences.unsqueeze(dim=-1)
    else:
        __seq = sequences
    mask = (torch.sum(__seq, dim=-1) == (padding_idx * __seq.shape[-1])).long() * -10e4
    return mask.unsqueeze(1).unsqueeze(1)


class ScaledDotProduct(nn.Module):
    def __init__(self, d_model=512, h=8, d_k=64, d_v=64):
        super().__init__()

        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h

        self.init_weights()

    def init_weights(self):
        nn.init.xavier_uniform_(self.fc_q.weight)
        nn.init.xavier_uniform_(self.fc_k.weight)
        nn.init.xavier_uniform_(self.fc_v.weight)
        nn.init.xavier_uniform_(self.fc_o.weight)
        nn.init.constant_(self.fc_q.bias, 0)
        nn.init.constant_(self.fc_k.bias, 0)
        nn.init.constant_(self.fc_v.bias, 0)
        nn.init.constant_(self.fc_o.bias, 0)

    def forward(self, queries, keys, values, attention_mask=None, **kwargs):
        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]
        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)
        if attention_mask is not None:
            att += attention_mask
        att = torch.softmax(att, dim=-1)
        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)

        return out, att


class MultiheadAttention(nn.Module):

    def __init__(self, d_model=512, dropout=0.1, use_aoa=True):
        super().__init__()
        self.d_model = d_model
        self.use_aoa = use_aoa  # Attention-on-Attention gating

        self.attention = ScaledDotProduct()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        if self.use_aoa:
            self.infomative_attention = nn.Linear(2 * self.d_model, self.d_model)
            self.gated_attention = nn.Linear(2 * self.d_model, self.d_model)

    def forward(self, q, k, v, mask=None):
        out, _ = self.attention(q, k, v, mask)
        if self.use_aoa:
            aoa_input = torch.cat([q, out], dim=-1)
            i = self.infomative_attention(aoa_input)
            g = torch.sigmoid(self.gated_attention(aoa_input))
            out = i * g
        return out


class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model=512, d_ff=2048, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, input):
        out = self.fc1(input)
        out = self.fc2(self.relu(out))
        return out


class AddNorm(nn.Module):
    def __init__(self, dim=512, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, y):
        return self.norm(x + self.dropout(y))


class SinusoidPositionalEmbedding(nn.Module):
    def __init__(self, num_pos_feats=512, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, mask=None):
        if mask is None:
            mask = torch.zeros(x.shape[:-1], dtype=torch.bool, device=x.device)
        not_mask = ~mask
        embed = not_mask.cumsum(1, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            embed = embed / (embed[:, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (torch.div(dim_t, 2, rounding_mode="floor")) / self.num_pos_feats)

        pos = embed[:, :, None] / dim_t
        pos = torch.stack((pos[:, :, 0::2].sin(), pos[:, :, 1::2].cos()), dim=-1).flatten(-2)

        return pos


class GuidedEncoderLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.self_mhatt = MultiheadAttention()
        self.guided_mhatt = MultiheadAttention()
        self.pwff = PositionWiseFeedForward()
        self.first_norm = AddNorm()
        self.second_norm = AddNorm()
        self.third_norm = AddNorm()

    def forward(self, q, k, v, self_mask, guided_mask):
        self_att = self.self_mhatt(q, q, q, self_mask)
        self_att = self.first_norm(self_att, q)
        guided_att = self.guided_mhatt(self_att, k, v, guided_mask)
        guided_att = self.second_norm(guided_att, self_att)
        out = self.pwff(guided_att)
        out = self.third_norm(out, guided_att)
        return out


class GuidedAttentionEncoder(nn.Module):
    def __init__(self, num_layers=2, d_model=512):
        super().__init__()
        self.pos_embedding = SinusoidPositionalEmbedding()
        self.layer_norm = nn.LayerNorm(d_model)

        self.guided_layers = nn.ModuleList([GuidedEncoderLayer() for _ in range(num_layers)])
        self.language_layers = nn.ModuleList([GuidedEncoderLayer() for _ in range(num_layers)])

    def forward(self, vision_features, vision_mask, language_features, language_mask):
        vision_features = self.layer_norm(vision_features) + self.pos_embedding(vision_features)
        language_features = self.layer_norm(language_features) + self.pos_embedding(language_features)

        # guided co-attention: vision attends over language and language attends over vision
        for guided_layer, language_layer in zip(self.guided_layers, self.language_layers):
            vision_features = guided_layer(q=vision_features,
                                           k=language_features,
                                           v=language_features,
                                           self_mask=vision_mask,
                                           guided_mask=language_mask)
            language_features = language_layer(q=language_features,
                                               k=vision_features,
                                               v=vision_features,
                                               self_mask=language_mask,
                                               guided_mask=vision_mask)

        return vision_features, language_features


class VisionEmbedding(nn.Module):
    def __init__(self, out_dim=768, hidden_dim=512, dropout=0.1):
        super().__init__()
        self.prep = AutoFeatureExtractor.from_pretrained(vision_model_name)
        self.backbone = AutoModel.from_pretrained(vision_model_name)
        for param in self.backbone.parameters():  # keep the ViT backbone frozen
            param.requires_grad = False

        self.proj = nn.Linear(out_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.gelu = nn.GELU()

    def forward(self, images):
        inputs = self.prep(images=images, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = self.backbone(**inputs)
        features = outputs.last_hidden_state
        vision_mask = generate_padding_mask(features, padding_idx=0)
        out = self.proj(features)
        out = self.gelu(out)
        out = self.dropout(out)
        return out, vision_mask


class LanguageEmbedding(nn.Module):
    def __init__(self, out_dim=768, hidden_dim=512, dropout=0.1):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(language_model_name)
        self.embeding = AutoModel.from_pretrained(language_model_name)
        for param in self.embeding.parameters():  # keep PhoBERT frozen
            param.requires_grad = False
        self.proj = nn.Linear(out_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.gelu = nn.GELU()

    def forward(self, questions):
        inputs = self.tokenizer(questions,
                                padding='max_length',
                                max_length=30,
                                truncation=True,
                                return_tensors='pt',
                                return_token_type_ids=True,
                                return_attention_mask=True).to(device)

        features = self.embeding(**inputs).last_hidden_state
        language_mask = generate_padding_mask(inputs.input_ids, padding_idx=self.tokenizer.pad_token_id)
        out = self.proj(features)
        out = self.gelu(out)
        out = self.dropout(out)
        return out, language_mask


class BaseModel(nn.Module):
    def __init__(self, num_classes=353, d_model=512):
        super().__init__()
        self.vision_embedding = VisionEmbedding()
        self.language_embedding = LanguageEmbedding()
        self.encoder = GuidedAttentionEncoder()
        self.fusion = nn.Sequential(nn.Linear(2 * d_model, d_model),
                                    nn.ReLU(),
                                    nn.Dropout(0.2))
        self.classify = nn.Linear(d_model, num_classes)
        self.attention_weights = nn.Linear(d_model, 1)

    def forward(self, images, questions):
        embedded_text, text_mask = self.language_embedding(questions)
        embedded_vision, vision_mask = self.vision_embedding(images)

        encoded_image, encoded_text = self.encoder(embedded_vision, vision_mask, embedded_text, text_mask)
        text_attended = self.attention_weights(torch.tanh(encoded_text))
        image_attended = self.attention_weights(torch.tanh(encoded_image))

        attention_weights = torch.softmax(torch.cat([text_attended, image_attended], dim=1), dim=1)

        attended_text = torch.sum(attention_weights[:, 0].unsqueeze(-1) * encoded_text, dim=1)
        attended_image = torch.sum(attention_weights[:, 1].unsqueeze(-1) * encoded_image, dim=1)

        fused_output = self.fusion(torch.cat([attended_text, attended_image], dim=1))
        logits = self.classify(fused_output)
        logits = F.log_softmax(logits, dim=-1)
        return logits


if __name__ == "__main__":
    model = BaseModel().to(device)
    print(model.eval())
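Model.py ships no usage notes, so here is a minimal smoke-test sketch that is not part of the commit: it assumes the two pretrained backbones can be downloaded, feeds a random dummy image, and borrows the example question from Predict.py purely for illustration.

import numpy as np
import torch
from PIL import Image

from Model import BaseModel, device

model = BaseModel().to(device)
model.eval()  # turn off dropout for a deterministic forward pass

# Dummy 224x224 RGB image; the question string is copied from Predict.py.
image = Image.fromarray((np.random.rand(224, 224, 3) * 255).astype(np.uint8))
question = "màu của chiếc bình là gì"

with torch.no_grad():
    logits = model([image], [question])  # log-probabilities over the answer classes

print(logits.shape)  # expected: torch.Size([1, 353])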
Predict.py
ADDED
@@ -0,0 +1,35 @@
from Model import BaseModel
import json
import numpy as np
from PIL import Image
from torchvision import transforms as T
import torch


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# map_location keeps the load working on CPU-only machines
checkpoint = torch.load('Checkpoint/checkpoint.pt', map_location=device)
with open('Dataset/answer.json', 'r', encoding='utf8') as f:
    answer_space = json.load(f)
swap_space = {v: k for k, v in answer_space.items()}  # class index -> answer string

model = BaseModel().to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # disable dropout for inference


def generate_caption(image, question):
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif isinstance(image, str):
        image = Image.open(image).convert("RGB")
    transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
    image = transform(image).unsqueeze(0)
    with torch.no_grad():
        logits = model(image, question)
    idx = torch.argmax(logits)
    return swap_space[idx.item()]


if __name__ == "__main__":
    image = 'Dataset/train/68857.jpg'
    question = 'màu của chiếc bình là gì'  # "what is the colour of the vase?"
    pred = generate_caption(image, question)
    print(pred)
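Predict.py assumes Dataset/answer.json holds the answer vocabulary as a mapping from answer string to class index (swap_space inverts it to index -> answer, matching num_classes = 353 in Model.py). A hypothetical sketch of how such a file could be written; the entries below are made up:

import json
import os

# Made-up entries for illustration; the real file has 353 answers.
answer_space = {"màu xanh": 0, "màu đỏ": 1, "hai": 2}

os.makedirs("Dataset", exist_ok=True)
with open("Dataset/answer.json", "w", encoding="utf8") as f:
    json.dump(answer_space, f, ensure_ascii=False, indent=2)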
last_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8ccdd750d65a5f4c508d8b0702ee792f9e27f9be52ffa5c42b98c15420c1c9d1
size 1105547492