# NOTE(review): removed scraped Hugging Face page header ("Spaces:" / "Runtime error")
# — web-extraction residue, not part of the source file.
from __future__ import annotations | |
import spaces | |
import gradio as gr | |
from threading import Thread | |
from transformers import TextIteratorStreamer | |
import hashlib | |
import os | |
from transformers import AutoModel, AutoProcessor | |
import torch | |
import sys | |
import subprocess | |
from PIL import Image | |
from cobra import load | |
import time | |
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'packaging']) | |
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'ninja']) | |
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'mamba-ssm']) | |
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'causal-conv1d']) | |
# Load the Cobra 3B vision-language model.
vlm = load("cobra+3b")

# Pick the best available device; bfloat16 on GPU, float32 on CPU.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.bfloat16
else:
    DEVICE, DTYPE = "cpu", torch.float32

vlm.to(DEVICE, dtype=DTYPE)

# Conversation state shared across chat turns; reset by the handler when a
# fresh chat starts (empty history).
prompt_builder = vlm.get_prompt_builder()
def bot_streaming(message, history, temperature, top_k, max_new_tokens):
    """Gradio chat handler: answer a (text, image) turn with the Cobra VLM.

    Args:
        message: multimodal gradio message dict with "text" and "files" keys.
        history: prior chat turns; image uploads appear as tuples in turn[0].
        temperature: sampling temperature for generation.
        top_k: top-k sampling cutoff.
        max_new_tokens: generation length cap.

    Yields:
        The generated reply text (single yield; not token-streamed).
    """
    # Fresh conversation: reset the module-level prompt builder state.
    if len(history) == 0:
        prompt_builder.prompt, prompt_builder.turn_count = "", 0
    print(message)

    image = None
    if message["files"]:
        image = message["files"][-1]["path"]
    else:
        # No image uploaded this turn: reuse the most recent image from past
        # turns (gradio keeps uploads as tuples inside history entries).
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]

    # Bug fix: previously `image` stayed undefined when no image had ever been
    # uploaded, raising NameError below. Fail gracefully instead.
    if image is None:
        yield "Please upload an image first so I can chat about it."
        return

    image = Image.open(image).convert("RGB")
    prompt_builder.add_turn(role="human", message=message['text'])
    prompt_text = prompt_builder.get_prompt()

    # Generate from the VLM.
    with torch.no_grad():
        generated_text = vlm.generate(
            image,
            prompt_text,
            cg=True,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            max_new_tokens=max_new_tokens,
        )

    prompt_builder.add_turn(role="gpt", message=generated_text)
    time.sleep(0.04)
    yield generated_text
# Generation-control sliders exposed as extra chat inputs (passed to
# bot_streaming after message/history, in this order).
_generation_controls = [
    gr.Slider(0, 1, value=0.2, label="Temperature"),
    gr.Slider(1, 3, value=1, step=1, label="Top k"),
    gr.Slider(1, 2048, value=256, step=1, label="Max New Tokens"),
]

demo = gr.ChatInterface(
    fn=bot_streaming,
    additional_inputs=_generation_controls,
    title="Cobra",
    description=(
        "Try [Cobra](https://huggingface.co/papers/2403.14520) in this demo. "
        "Upload an image and start chatting about it."
    ),
    stop_btn="Stop Generation",
    multimodal=True,
    examples=[{"text": "Describe this image", "files": ["./cobra.png"]}],
)

demo.launch(debug=True)