Delete generate_model.py
generate_model.py  +0  -730
generate_model.py
DELETED
@@ -1,730 +0,0 @@
import argparse
import time
import logging
import requests
import os
from PIL import Image
from io import BytesIO

from PIL import Image
import torch
from transformers import AutoTokenizer

from transformers import AutoTokenizer, AutoModelForCausalLM

from PIL import Image
from io import BytesIO
import base64

import torch
from transformers import StoppingCriteria

import math
import ast

# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"
import dataclasses
from enum import auto, Enum
from typing import List, Tuple


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()
    TINY_LLAMA = auto()
    QWEN_2 = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            if 'mmtag' in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0: message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.TINY_LLAMA:
            sep = "</s>"
            wrap_sys = lambda msg: f"<|system|>\n{msg}\n"
            wrap_user = lambda msg: f"<|user|>\n{msg}\n"
            wrap_assistant = lambda msg: f"<|assistant|>\n{msg}"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i % 2 == 0:
                        message = wrap_user(message)
                        if i == 0:
                            message = wrap_sys(self.system) + message
                        ret += self.sep + message
                    else:
                        message = wrap_assistant(message) + self.sep2
                        ret += message
                else:
                    ret += "<|assistant|>\n"
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.QWEN_2:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    from PIL import Image
                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":
                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(pil_img.mode, (width, width), background_color)
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(pil_img.mode, (height, height), background_color)
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result
                        image = expand2square(image)
                    elif image_process_mode in ["Default", "Crop"]:
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((336, 336))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if longest_edge != max(image.size):
                        if H > W:
                            H, W = longest_edge, shortest_edge
                        else:
                            H, W = shortest_edge, longest_edge
                        image = image.resize((W, H))
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<image>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }




conv_phi_v0 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="phi",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="<|endoftext|>",
)



def select_best_resolution(original_size, possible_resolutions):
    """
    Selects the best resolution from a list of possible resolutions based on the original size.

    Args:
        original_size (tuple): The original size of the image in the format (width, height).
        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The best fit resolution in the format (width, height).
    """
    original_width, original_height = original_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float('inf')

    for width, height in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)

    return best_fit


## added by llava-1.6
def resize_and_pad_image(image, target_resolution):
    """
    Resize and pad an image to a target resolution while maintaining aspect ratio.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target resolution (width, height) of the image.

    Returns:
        PIL.Image.Image: The resized and padded image.
    """
    original_width, original_height = image.size
    target_width, target_height = target_resolution

    scale_w = target_width / original_width
    scale_h = target_height / original_height

    if scale_w < scale_h:
        new_width = target_width
        new_height = min(math.ceil(original_height * scale_w), target_height)
    else:
        new_height = target_height
        new_width = min(math.ceil(original_width * scale_h), target_width)

    # Resize the image
    resized_image = image.resize((new_width, new_height))

    new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
    paste_x = (target_width - new_width) // 2
    paste_y = (target_height - new_height) // 2
    new_image.paste(resized_image, (paste_x, paste_y))

    return new_image


## added by llava-1.6
def divide_to_patches(image, patch_size):
    """
    Divides an image into patches of a specified size.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): The size of each patch.

    Returns:
        list: A list of PIL.Image.Image objects representing the patches.
    """
    patches = []
    width, height = image.size
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            box = (j, i, j + patch_size, i + patch_size)
            patch = image.crop(box)
            patches.append(patch)

    return patches


## added by llava-1.6
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (tuple): The size of the input image in the format (width, height).
        grid_pinpoints (str): A string representation of a list of possible resolutions.
        patch_size (int): The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (width, height).
    """
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size


## added by llava-1.6
def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object.
        grid_pinpoints (str): A string representation of a list of possible resolutions.

    Returns:
        torch.Tensor: A tensor containing the processed image patches.
    """
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)

    patches = divide_to_patches(image_padded, processor.crop_size['height'])

    image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))

    image_patches = [image_original_resize] + patches
    image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
                     for image_patch in image_patches]
    return torch.stack(image_patches, dim=0)


def load_image_from_base64(image):
    return Image.open(BytesIO(base64.b64decode(image)))


def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


def process_images(images, image_processor, model_cfg):
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    new_images = []
    if image_aspect_ratio == 'pad':
        for image in images:
            image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            new_images.append(image)
    elif image_aspect_ratio == "anyres":
        for image in images:
            image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
            new_images.append(image)
    else:
        return image_processor(images, return_tensors='pt')['pixel_values']
    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images


def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids


def get_model_name_from_path(model_path):
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith('checkpoint-'):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]


class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            if len(cur_keyword_ids) > self.max_keyword_len:
                self.max_keyword_len = len(cur_keyword_ids)
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        outputs = []
        for i in range(output_ids.shape[0]):
            outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
        return all(outputs)



def load_image(image_file):
    if image_file.startswith("http") or image_file.startswith("https"):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")
    return image


def generate(
    prompt: str,
    model: str,
    tokenizer = None,
    image: str = None,
    device: str = None,
    max_new_tokens: int = 1024,
    num_beams = 1,
    top_p=None,
    temperature=0.2
):
    if not device:
        if torch.cuda.is_available() and torch.cuda.device_count():
            device = "cuda:0"
            logging.warning(
                'inference device is not set, using cuda:0, %s',
                torch.cuda.get_device_name(0)
            )
        else:
            device = 'cpu'
            logging.warning(
                (
                    'No CUDA device detected, using cpu, '
                    'expect slower speeds.'
                )
            )

    if 'cuda' in device and not torch.cuda.is_available():
        raise ValueError('CUDA device requested but no CUDA device detected.')

    if isinstance(model, str):
        checkpoint_path = model
        # print(f'loading model from {checkpoint_path}...')
        model = AutoModelForCausalLM.from_pretrained(
            checkpoint_path,
            trust_remote_code=True
        )
        # print('model load over')
    config = model.config
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, model_max_length = config.tokenizer_model_max_length,
                                                  padding_side = config.tokenizer_padding_side)
    image_processor = model.vision_tower._image_processor
    context_len = getattr(config, 'max_sequence_length', 2048)
    model.to(device).eval()


    if image is not None:
        prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
    conv = conv_phi_v0.copy()
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    if image is not None:
        # print('loading image...')
        image = load_image(image)
        # print('load image over')
        image_tensor = process_images(image, image_processor, config).to(model.device, dtype=torch.float16)

    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device, dtype=torch.float16)
    )
    # Generate
    stime = time.time()
    # stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    # keywords = [stop_str]
    # stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    # print('start inference...')
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True if temperature > 0 else False,
            temperature=temperature,
            top_p=top_p,
            num_beams=num_beams,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            # stopping_criteria=[stopping_criteria],
        )

    # print('inference over')
    generation_time = time.time() - stime
    outputs = tokenizer.batch_decode(
        output_ids, skip_special_tokens=True
    )[0]
    # outputs = outputs.strip()
    # if outputs.endswith(stop_str):
    #     outputs = outputs[: -len(stop_str)]
    outputs = outputs.strip()

    return outputs, generation_time
def tinyllava_elm_generate_parser():
    """Argument Parser"""

    class KwargsParser(argparse.Action):
        """Parser action class to parse kwargs of form key=value"""
        def __call__(self, parser, namespace, values, option_string=None):
            setattr(namespace, self.dest, dict())
            for val in values:
                if '=' not in val:
                    raise ValueError(
                        (
                            'Argument parsing error, kwargs are expected in'
                            ' the form of key=value.'
                        )
                    )
                kwarg_k, kwarg_v = val.split('=')
                try:
                    converted_v = int(kwarg_v)
                except ValueError:
                    try:
                        converted_v = float(kwarg_v)
                    except ValueError:
                        converted_v = kwarg_v
                getattr(namespace, self.dest)[kwarg_k] = converted_v

    parser = argparse.ArgumentParser('TinyLLaVA-OpenELM Generate Module')
    parser.add_argument(
        '--model',
        dest='model',
        help='Path to the hf converted model.',
        required=True,
        type=str,
    )
    parser.add_argument(
        '--prompt',
        dest='prompt',
        help='Prompt for LLM call.',
        default='',
        type=str,
    )
    parser.add_argument(
        '--device',
        dest='device',
        help='Device used for inference.',
        type=str,
    )
    parser.add_argument("--image", type=str, default=None)
    parser.add_argument("--temperature", type=float, default=0)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--max_new_tokens", type=int, default=512)
    return parser.parse_args()


if __name__ == '__main__':
    args = tinyllava_elm_generate_parser()

    output_text, genertaion_time = generate(
        prompt=args.prompt,
        image=args.image,
        model=args.model,
        device=args.device,
        max_new_tokens = args.max_new_tokens,
        num_beams = args.num_beams,
        top_p=args.top_p,
        temperature=args.temperature
    )

    print_txt = (
        f'\r\n{"=" * os.get_terminal_size().columns}\r\n'
        '\033[1m Prompt + Generated Output\033[0m\r\n'
        f'{"-" * os.get_terminal_size().columns}\r\n'
        f'{output_text}\r\n'
        f'{"-" * os.get_terminal_size().columns}\r\n'
        '\r\nGeneration took'
        f'\033[1m\033[92m {round(genertaion_time, 2)} \033[0m'
        'seconds.\r\n'
    )
    print(print_txt)
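
For reference, the deleted module exposed both the command-line entry point above and an importable generate() helper. The sketch below is illustrative only: it assumes generate_model.py is still on the Python path, and the checkpoint and image paths are placeholders, not files from this repository.

# Hypothetical usage sketch of the removed generate() helper; all paths are placeholders.
from generate_model import generate  # assumes generate_model.py is importable

output_text, generation_time = generate(
    prompt="Describe this image.",
    model="path/to/hf-converted-checkpoint",  # placeholder path to an hf-converted model (required)
    image="path/to/image.jpg",                # placeholder image path; the script prepends the <image> token for it
    device=None,                              # auto-selects cuda:0 when available, otherwise cpu
    max_new_tokens=512,
    temperature=0,                            # 0 disables sampling (greedy decoding)
)
print(output_text, generation_time)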