Spaces:
Runtime error
Runtime error
# Vision Search Assistant: Empower Vision-Language Models as Multimodal Search Engines | |
# Github source: https://github.com/cnzzx/VSA-dev | |
# Licensed under The Apache License 2.0 License [see LICENSE for details] | |
# Based on LLaVA and MindSearch code bases | |
# https://github.com/haotian-liu/LLaVA | |
# https://github.com/IDEA-Research/GroundingDINO | |
# https://github.com/InternLM/MindSearch | |
# -------------------------------------------------------- | |
import os | |
import copy | |
import torch | |
import numpy as np | |
from PIL import Image, ImageDraw, ImageFont | |
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection | |
from .vsa_prompt import COCO_CLASSES, get_caption_prompt, get_correlate_prompt, get_qa_prompt | |
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN | |
from llava.conversation import conv_templates, SeparatorStyle | |
from llava.model.builder import load_pretrained_model | |
from llava.utils import disable_torch_init | |
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path | |
from datetime import datetime | |
from lagent.actions import ActionExecutor, BingBrowser | |
from lagent.llms import INTERNLM2_META, LMDeployServer, LMDeployPipeline | |
from lagent.schema import AgentReturn, AgentStatusCode | |
from lagent.schema import AgentStatusCode | |
from .search_agent.mindsearch_agent import ( | |
MindSearchAgent, SimpleSearchAgent, MindSearchProtocol | |
) | |
from .search_agent.mindsearch_prompt import ( | |
FINAL_RESPONSE_CN, FINAL_RESPONSE_EN, GRAPH_PROMPT_CN, GRAPH_PROMPT_EN, | |
searcher_context_template_cn, searcher_context_template_en, | |
searcher_input_template_cn, searcher_input_template_en, | |
searcher_system_prompt_cn, searcher_system_prompt_en | |
) | |
from lmdeploy.messages import PytorchEngineConfig | |
from typing import List, Union | |
SEARCH_MODEL_NAMES = { | |
'internlm2_5-7b-chat': 'internlm2', | |
'internlm2_5-1_8b-chat': 'internlm2' | |
} | |
def render_bboxes(in_image: Image.Image, bboxes: np.ndarray, labels: List[str]): | |
out_image = copy.deepcopy(in_image) | |
draw = ImageDraw.Draw(out_image) | |
font = ImageFont.truetype(font = 'assets/Arial.ttf', size = min(in_image.width, in_image.height) // 30) | |
line_width = min(in_image.width, in_image.height) // 100 | |
for i in range(bboxes.shape[0]): | |
draw.rectangle((bboxes[i, 0], bboxes[i, 1], bboxes[i, 2], bboxes[i, 3]), outline=(0, 255, 0), width=line_width) | |
bbox = draw.textbbox((bboxes[i, 0], bboxes[i, 1]), '[Area {}] '.format(i), font=font) | |
draw.rectangle(bbox, fill='white') | |
draw.text((bboxes[i, 0], bboxes[i, 1]), '[Area {}] '.format(i), fill='black', font=font) | |
if bboxes.shape[0] == 0: | |
draw.rectangle((0, 0, in_image.width, in_image.height), outline=(0, 255, 0), width=line_width) | |
bbox = draw.textbbox((0, 0), '[Area {}] '.format(0), font=font) | |
draw.rectangle(bbox, fill='white') | |
draw.text((0, 0), '[Area {}] '.format(0), fill='black', font=font) | |
return out_image | |
class VisualGrounder: | |
def __init__( | |
self, | |
model_path: str = "IDEA-Research/grounding-dino-base", | |
device: str = "cuda:1", | |
box_threshold: float = 0.4, | |
text_threshold: float = 0.3, | |
): | |
self.processor = AutoProcessor.from_pretrained(model_path) | |
self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_path).to(device) | |
self.device = device | |
self.default_classes = COCO_CLASSES | |
self.box_threshold = box_threshold | |
self.text_threshold = text_threshold | |
def __call__( | |
self, | |
in_image: Image.Image, | |
classes: Union[List[str], None] = None, | |
): | |
# Save image. | |
in_image.save('temp/in_image.jpg') | |
# Preparation. | |
if classes is None: | |
classes = self.default_classes | |
text = ". ".join(classes) | |
inputs = self.processor(images=in_image, text=text, return_tensors="pt").to(self.device) | |
# Grounding. | |
with torch.no_grad(): | |
outputs = self.model(**inputs) | |
# Postprocess | |
results = self.processor.post_process_grounded_object_detection( | |
outputs, | |
inputs.input_ids, | |
box_threshold = self.box_threshold, | |
text_threshold = self.text_threshold, | |
target_sizes=[in_image.size[::-1]] | |
) | |
bboxes = results[0]['boxes'].cpu().numpy() | |
labels = results[0]['labels'] | |
print(results) | |
# Visualization. | |
out_image = render_bboxes(in_image, bboxes, labels) | |
out_image.save('temp/ground_bbox.jpg') | |
return bboxes, labels, out_image | |
class VLM: | |
def __init__( | |
self, | |
model_path: str = "liuhaotian/llava-v1.6-vicuna-7b", | |
device: str = "cuda:2", | |
load_8bit: bool = False, | |
load_4bit: bool = True, | |
temperature: float = 0.2, | |
max_new_tokens: int = 1024, | |
): | |
disable_torch_init() | |
model_name = get_model_name_from_path(model_path) | |
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model( | |
model_path, None, model_name, load_8bit, load_4bit, device=device | |
) | |
self.device = device | |
if "llama-2" in model_name.lower(): | |
self.conv_mode = "llava_llama_2" | |
elif "mistral" in model_name.lower(): | |
self.conv_mode = "mistral_instruct" | |
elif "v1.6-34b" in model_name.lower(): | |
self.conv_mode = "chatml_direct" | |
elif "v1" in model_name.lower(): | |
self.conv_mode = "llava_v1" | |
elif "mpt" in model_name.lower(): | |
self.conv_mode = "mpt" | |
else: | |
self.conv_mode = "llava_v0" | |
self.temperature = temperature | |
self.max_new_tokens = max_new_tokens | |
def __call__( | |
self, | |
image: Image.Image, | |
text: str, | |
): | |
image_size = image.size | |
image_tensor = process_images([image], self.image_processor, self.model.config) | |
if type(image_tensor) is list: | |
image_tensor = [image.to(self.device, dtype=torch.float16) for image in image_tensor] | |
else: | |
image_tensor = image_tensor.to(self.device, dtype=torch.float16) | |
if image is not None: | |
# first message | |
if self.model.config.mm_use_im_start_end: | |
text = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + text | |
else: | |
text = DEFAULT_IMAGE_TOKEN + '\n' + text | |
image = None | |
conv = conv_templates[self.conv_mode].copy() | |
conv.append_message(conv.roles[0], text) | |
conv.append_message(conv.roles[1], None) | |
prompt = conv.get_prompt() | |
input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device) | |
with torch.inference_mode(): | |
output_ids = self.model.generate( | |
input_ids, | |
images = image_tensor, | |
image_sizes = [image_size], | |
do_sample = True if self.temperature > 0 else False, | |
temperature = self.temperature, | |
max_new_tokens = self.max_new_tokens, | |
streamer = None, | |
use_cache = True) | |
outputs = self.tokenizer.decode(output_ids[0]).strip() | |
outputs = outputs.replace('<s>', '').replace('</s>', '').replace('"', "'") | |
return outputs | |
class WebSearcher: | |
def __init__( | |
self, | |
model_path: str = 'internlm/internlm2_5-7b-chat', | |
lang: str = 'cn', | |
top_p: float = 0.8, | |
top_k: int = 1, | |
temperature: float = 0, | |
max_new_tokens: int = 8192, | |
repetition_penalty: float = 1.02, | |
max_turn: int = 10, | |
): | |
model_name = get_model_name_from_path(model_path) | |
if model_name in SEARCH_MODEL_NAMES: | |
model_name = SEARCH_MODEL_NAMES[model_name] | |
else: | |
raise Exception('Unsupported model for web searcher.') | |
self.lang = lang | |
backend_config = PytorchEngineConfig( | |
max_batch_size = 1, | |
) | |
llm = LMDeployServer( | |
path = model_path, | |
model_name = model_name, | |
meta_template = INTERNLM2_META, | |
top_p = top_p, | |
top_k = top_k, | |
temperature = temperature, | |
max_new_tokens = max_new_tokens, | |
repetition_penalty = repetition_penalty, | |
stop_words = ['<|im_end|>'], | |
serve_cfg = dict( | |
backend_config = backend_config | |
) | |
) | |
# llm = LMDeployPipeline( | |
# path = model_path, | |
# model_name = model_name, | |
# meta_template = INTERNLM2_META, | |
# top_p = top_p, | |
# top_k = top_k, | |
# temperature = temperature, | |
# max_new_tokens = max_new_tokens, | |
# repetition_penalty = repetition_penalty, | |
# stop_words = ['<|im_end|>'], | |
# ) | |
self.agent = MindSearchAgent( | |
llm = llm, | |
protocol = MindSearchProtocol( | |
meta_prompt = datetime.now().strftime('The current date is %Y-%m-%d.'), | |
interpreter_prompt = GRAPH_PROMPT_CN if lang == 'cn' else GRAPH_PROMPT_EN, | |
response_prompt = FINAL_RESPONSE_CN if lang == 'cn' else FINAL_RESPONSE_EN | |
), | |
searcher_cfg=dict( | |
llm = llm, | |
plugin_executor = ActionExecutor( | |
BingBrowser(searcher_type='DuckDuckGoSearch', topk=6) | |
), | |
protocol = MindSearchProtocol( | |
meta_prompt=datetime.now().strftime('The current date is %Y-%m-%d.'), | |
plugin_prompt=searcher_system_prompt_cn if lang == 'cn' else searcher_system_prompt_en, | |
), | |
template = dict( | |
input=searcher_input_template_cn if lang == 'cn' else searcher_input_template_en, | |
context=searcher_context_template_cn if lang == 'cn' else searcher_context_template_en) | |
), | |
max_turn = max_turn | |
) | |
def __call__( | |
self, | |
queries: List[str] | |
): | |
results = [] | |
for qid, query in enumerate(queries): | |
result = None | |
for agent_return in self.agent.stream_chat(query): | |
if isinstance(agent_return, AgentReturn): | |
if agent_return.state == AgentStatusCode.END: | |
result = agent_return.response | |
assert result is not None | |
with open('temp/search_result_{}.txt'.format(qid), 'w', encoding='utf-8') as wf: | |
wf.write(result) | |
results.append(result) | |
# for qid, query in enumerate(queries): | |
# result = None | |
# agent_return = self.agent.generate(query) | |
# result = agent_return.response | |
# assert result is not None | |
# with open('temp/search_result_{}.txt'.format(qid), 'w', encoding='utf-8') as wf: | |
# wf.write(result) | |
# results.append(result) | |
return results | |
class VisionSearchAssistant: | |
""" | |
Vision Search Assistant: Empower Vision-Language Models as Multimodal Search Engines | |
This class implements all variants of Vision Search Assistant: | |
* search_model: Vision Search Assistant use this model for dealing with the search process, | |
it corresponds to the $\mathcal{F}_{llm}(cdot)$ in the paper. You can choose the model | |
according to your preference. | |
* ground_model: The vision foundation model used in the open-vocab detection process, | |
it's relevant to the specific contents of the classes in the image. | |
* vlm_model: The main vision-language model we used in our paper is LLaVA-1.6 baseline, | |
It can be further improved by using advanced models. And it corresponds to | |
the $\mathcal{F}_{vlm}(cdot)$ in the paper. | |
""" | |
def __init__( | |
self, | |
search_model: str = "internlm/internlm2_5-1_8b-chat", | |
ground_model: str = "IDEA-Research/grounding-dino-tiny", | |
ground_device: str = "cuda:1", | |
vlm_model: str = "liuhaotian/llava-v1.6-vicuna-7b", | |
vlm_device: str = "cuda:2", | |
vlm_load_4bit: bool = True, | |
vlm_load_8bit: bool = False, | |
): | |
self.search_model = search_model | |
self.ground_model = ground_model | |
self.ground_device = ground_device | |
self.vlm_model = vlm_model | |
self.vlm_device = vlm_device | |
self.vlm_load_4bit = vlm_load_4bit | |
self.vlm_load_8bit = vlm_load_8bit | |
self.use_correlate = True | |
self.searcher = WebSearcher( | |
model_path = self.search_model, | |
lang = 'en' | |
) | |
self.grounder = VisualGrounder( | |
model_path = self.ground_model, | |
device = self.ground_device, | |
) | |
self.vlm = VLM( | |
model_path = self.vlm_model, | |
device = self.vlm_device, | |
load_4bit = self.vlm_load_4bit, | |
load_8bit = self.vlm_load_8bit | |
) | |
def app_run( | |
self, | |
image: Union[str, Image.Image, np.ndarray], | |
text: str, | |
ground_classes: List[str] = COCO_CLASSES | |
): | |
# Create and clear the temporary directory. | |
if not os.access('temp', os.F_OK): | |
os.makedirs('temp') | |
for file in os.listdir('temp'): | |
os.remove(os.path.join('temp', file)) | |
with open('temp/text.txt', 'w', encoding='utf-8') as wf: | |
wf.write(text) | |
# Load Image | |
if isinstance(image, str): | |
in_image = Image.open(image) | |
elif isinstance(image, Image.Image): | |
in_image = image | |
elif isinstance(image, np.ndarray): | |
in_image = Image.fromarray(image.astype(np.uint8)) | |
else: | |
raise Exception('Unsupported input image format.') | |
# Visual Grounding | |
bboxes, labels, out_image = self.grounder(in_image, classes = ground_classes) | |
yield out_image, 'ground' | |
det_images = [] | |
for bid, bbox in enumerate(bboxes): | |
crop_box = (int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])) | |
det_image = in_image.crop(crop_box) | |
det_image.save('temp/debug_bbox_image_{}.jpg'.format(bid)) | |
det_images.append(det_image) | |
if len(det_images) == 0: # No object detected, use the full image. | |
det_images.append(in_image) | |
labels.append('image') | |
# Visual Captioning | |
captions = [] | |
for det_image, label in zip(det_images, labels): | |
inp = get_caption_prompt(label, text) | |
caption = self.vlm(det_image, inp) | |
captions.append(caption) | |
for cid, caption in enumerate(captions): | |
with open('temp/caption_{}.txt'.format(cid), 'w', encoding='utf-8') as wf: | |
wf.write(caption) | |
# Visual Correlation | |
if len(captions) >= 2 and self.use_correlate: | |
queries = [] | |
for mid, det_image in enumerate(det_images): | |
caption = captions[mid] | |
other_captions = [] | |
for cid in range(len(captions)): | |
if cid == mid: | |
continue | |
other_captions.append(captions[cid]) | |
inp = get_correlate_prompt(caption, other_captions) | |
query = self.vlm(det_image, inp) | |
queries.append(query) | |
else: | |
queries = captions | |
for qid, query in enumerate(queries): | |
with open('temp/query_{}.txt'.format(qid), 'w', encoding='utf-8') as wf: | |
wf.write(query) | |
yield queries, 'query' | |
queries = [text + " " + query for query in queries] | |
# Web Searching | |
contexts = self.searcher(queries) | |
yield contexts, 'search' | |
# QA | |
TOKEN_LIMIT = 3500 | |
max_length_per_context = TOKEN_LIMIT // len(contexts) | |
for cid, context in enumerate(contexts): | |
contexts[cid] = (queries[cid] + context)[:max_length_per_context] | |
inp = get_qa_prompt(text, contexts) | |
answer = self.vlm(in_image, inp) | |
with open('temp/answer.txt', 'w', encoding='utf-8') as wf: | |
wf.write(answer) | |
print(answer) | |
yield answer, 'answer' | |