cobra / app.py
han1997's picture
Update app.py
40c6f6d verified
raw
history blame
2.66 kB
from __future__ import annotations
import torch
torch.jit.script = lambda f: f
import shlex
import spaces
import gradio as gr
from threading import Thread
from transformers import TextIteratorStreamer
import hashlib
import os
from transformers import AutoModel, AutoProcessor
import sys
import subprocess
from PIL import Image
import time
# install packages for mamba
def install():
    """Install the prebuilt causal_conv1d and mamba_ssm wheels with pip.

    The wheel files are referenced by relative name, so they are expected to
    sit in the current working directory. Installation is best-effort: a
    failing pip invocation is reported on stdout but does not raise, so the
    caller keeps running exactly as it did before.
    """
    print("Install personal packages", flush=True)
    wheels = (
        "causal_conv1d-1.2.0.post1-cp310-cp310-linux_x86_64.whl",
        "mamba_ssm-1.2.0.post1-cp310-cp310-linux_x86_64.whl",
    )
    for wheel in wheels:
        # Invoke pip through the running interpreter so the wheel lands in the
        # same environment this process imports from; a bare `pip` on PATH may
        # belong to a different Python installation.
        result = subprocess.run([sys.executable, "-m", "pip", "install", wheel])
        if result.returncode != 0:
            # Surface the failure instead of silently ignoring it.
            print(
                f"pip install failed for {wheel} (exit code {result.returncode})",
                flush=True,
            )
# Install the local wheels before importing cobra.
# NOTE(review): presumably cobra depends on the mamba packages installed here - confirm.
install()
from cobra import load

# Load the pretrained "cobra+3b" vision-language model.
vlm = load("cobra+3b")

# Use bfloat16 on a CUDA device when one is available; otherwise float32 on CPU.
DEVICE, DTYPE = (
    ("cuda", torch.bfloat16)
    if torch.cuda.is_available()
    else ("cpu", torch.float32)
)
vlm.to(DEVICE, dtype=DTYPE)

# Module-level prompt builder reused by every call to the chat handler.
prompt_builder = vlm.get_prompt_builder()
@spaces.GPU
def bot_streaming(message, history, temperature, top_k, max_new_tokens):
    """Answer one chat turn with the Cobra VLM and yield the reply.

    Args:
        message: Multimodal message mapping with a ``"text"`` entry and a
            ``"files"`` list. Each file entry may be a mapping with a
            ``"path"`` key or a plain path.
        history: Previous turns. An empty history resets the module-level
            prompt builder. Turns whose first element is a tuple are treated
            as image uploads, with the path at index 0 of that tuple.
        temperature: Sampling temperature forwarded to ``vlm.generate``.
        top_k: Top-k value forwarded to ``vlm.generate``.
        max_new_tokens: Generation length limit forwarded to ``vlm.generate``.

    Yields:
        The generated text, once, after generation has finished.
    """
    # NOTE(review): prompt_builder is module-level state shared by every call;
    # concurrent sessions would presumably interleave turns - confirm.
    if not history:
        prompt_builder.prompt, prompt_builder.turn_count = "", 0

    image_path = None
    if message["files"]:
        latest_file = message["files"][-1]
        # Accept both file shapes: a mapping carrying "path", or the path itself.
        image_path = latest_file["path"] if isinstance(latest_file, dict) else latest_file
    else:
        # No image uploaded this turn: reuse the most recent one from history.
        # Past uploads are kept inside tuples, so scan backwards and stop at
        # the first match.
        for past_turn in reversed(history):
            if isinstance(past_turn[0], tuple):
                image_path = past_turn[0][0]
                break

    image = None
    if image_path is not None:
        # convert() returns an independent, fully loaded image, so the file
        # handle can be released right away instead of waiting for GC.
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")

    prompt_builder.add_turn(role="human", message=message['text'])
    prompt_text = prompt_builder.get_prompt()

    # Generate from the VLM
    with torch.no_grad():
        generated_text = vlm.generate(
            image,
            prompt_text,
            use_cache=True,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            max_new_tokens=max_new_tokens,
        )

    # Record the reply so the next turn's prompt contains the whole dialogue.
    prompt_builder.add_turn(role="gpt", message=generated_text)
    yield generated_text
# Extra generation controls shown in the UI; their order matches the
# temperature, top_k and max_new_tokens parameters of bot_streaming.
generation_controls = [
    gr.Slider(0, 1, value=0.2, label="Temperature"),
    gr.Slider(1, 3, value=1, step=1, label="Top k"),
    gr.Slider(1, 2048, value=256, step=1, label="Max New Tokens"),
]

# Multimodal chat UI wired to the Cobra handler.
demo = gr.ChatInterface(
    fn=bot_streaming,
    additional_inputs=generation_controls,
    title="Cobra",
    description="Try [Cobra](https://huggingface.co/papers/2403.14520) in this demo. Upload an image and start chatting about it.",
    stop_btn="Stop Generation",
    multimodal=True,
)
demo.launch(debug=True)