Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import torch | |
torch.jit.script = lambda f: f | |
import shlex | |
import spaces | |
import gradio as gr | |
from threading import Thread | |
from transformers import TextIteratorStreamer | |
import hashlib | |
import os | |
from transformers import AutoModel, AutoProcessor | |
import sys | |
import subprocess | |
from PIL import Image | |
import time | |
# install packages for mamba | |
def install(wheels=(
    "causal_conv1d-1.2.0.post1-cp310-cp310-linux_x86_64.whl",
    "mamba_ssm-1.2.0.post1-cp310-cp310-linux_x86_64.whl",
)):
    """Install local wheel files with pip, one after another.

    Args:
        wheels: Paths of the wheel files to install, in order. Relative
            paths are resolved against the current working directory.
            Defaults to the two prebuilt wheels this script has always
            installed.

    Installation is best-effort: a failing pip run is reported on stdout
    but does not raise, so the caller decides what to do when a later
    import fails.
    """
    print("Install personal packages", flush=True)
    for wheel in wheels:
        # Invoke pip through the running interpreter so the package lands in
        # the same environment that will import it (a bare "pip" on PATH may
        # belong to a different Python installation).
        result = subprocess.run([sys.executable, "-m", "pip", "install", wheel])
        if result.returncode != 0:
            print(
                f"pip install failed for {wheel} (exit code {result.returncode})",
                flush=True,
            )
# The wheels must be installed before cobra is imported, which is why this
# import sits below the install() call instead of at the top of the file
# (presumably cobra needs the mamba packages at import time -- confirm).
install()
from cobra import load

vlm = load("cobra+3b")

# Use bfloat16 on a CUDA device when one is available, float32 on CPU otherwise.
DEVICE, DTYPE = (
    ("cuda", torch.bfloat16) if torch.cuda.is_available() else ("cpu", torch.float32)
)
vlm.to(DEVICE, dtype=DTYPE)

# A single module-level prompt builder; bot_streaming reads and mutates it.
prompt_builder = vlm.get_prompt_builder()
def _latest_image_path(message, history):
    """Return the image path to use for this turn, or None if there is none."""
    uploads = message["files"]
    if uploads:
        newest = uploads[-1]
        # Accept both {"path": ...} entries and plain path strings.
        return newest["path"] if isinstance(newest, dict) else newest
    # No upload in this turn: fall back to the most recent past turn whose
    # user message is a tuple (that is how earlier uploads are kept in history).
    for past_turn in reversed(history):
        if isinstance(past_turn[0], tuple):
            return past_turn[0][0]
    return None


def bot_streaming(message, history, temperature, top_k, max_new_tokens):
    """Generate a reply with the VLM and stream it back incrementally.

    Args:
        message: Dict with a "text" entry (the user's prompt) and a "files"
            entry (uploads for this turn; the last one is used).
        history: Previous chat turns; an empty history resets the shared
            prompt builder.
        temperature: Sampling temperature forwarded to ``vlm.generate``.
        top_k: Top-k value forwarded to ``vlm.generate``.
        max_new_tokens: Generation length limit forwarded to ``vlm.generate``.

    Yields:
        The reply text accumulated so far, once it is longer than one
        character.

    Returns:
        The complete reply text (the generator's return value).
    """
    streamer = TextIteratorStreamer(vlm.llm_backbone.tokenizer, skip_special_tokens=True)

    # NOTE(review): prompt_builder is module-level state shared by every call;
    # a fresh conversation is detected solely by an empty history.
    if not history:
        prompt_builder.prompt, prompt_builder.turn_count = "", 0

    image = _latest_image_path(message, history)
    if image is not None:
        image = Image.open(image).convert("RGB")

    prompt_builder.add_turn(role="human", message=message["text"])
    prompt_text = prompt_builder.get_prompt()

    generation_kwargs = {
        "image": image,
        "prompt_text": prompt_text,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "use_cache": True,
        "temperature": temperature,
        "do_sample": True,
        "top_k": top_k,
    }

    # Generate from the VLM on a worker thread so the streamer can be consumed here.
    thread = Thread(target=vlm.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    output_started = False
    for new_text in streamer:
        if not output_started:
            # Drop every chunk up to and including the one that carries the
            # assistant marker; only text after it belongs to the reply.
            if "<|assistant|>\n" in new_text:
                output_started = True
            continue
        buffer += new_text
        if len(buffer) > 1:
            yield buffer

    # The streamer is exhausted, so generation is over; wait for the worker
    # to exit before recording the reply in the shared prompt builder.
    thread.join()
    prompt_builder.add_turn(role="gpt", message=buffer)
    return buffer
# Extra controls shown next to the chat box. Their order matches the extra
# parameters of bot_streaming: temperature, top_k, max_new_tokens.
_generation_controls = [
    gr.Slider(0, 1, value=0.2, label="Temperature"),
    gr.Slider(1, 3, value=1, step=1, label="Top k"),
    gr.Slider(1, 2048, value=256, step=1, label="Max New Tokens"),
]

# Multimodal chat UI backed by bot_streaming.
demo = gr.ChatInterface(
    fn=bot_streaming,
    additional_inputs=_generation_controls,
    title="Cobra",
    description="Try [Cobra](https://huggingface.co/papers/2403.14520) in this demo. Upload an image and start chatting about it. Clear the history before asking questions related to new images",
    stop_btn="Stop Generation",
    multimodal=True,
)

demo.launch(debug=True)