from __future__ import annotations

import shlex
import subprocess
from threading import Thread

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import TextIteratorStreamer

# Replace torch.jit.script with a no-op so any model code decorated with it
# runs eagerly instead of being compiled by TorchScript.
torch.jit.script = lambda f: f
# Install the prebuilt causal-conv1d and mamba-ssm wheels bundled with the
# Space, avoiding a CUDA kernel compilation at startup.
def install():
    print("Install personal packages", flush=True)
    subprocess.run(shlex.split("pip install causal_conv1d-1.2.0.post1-cp310-cp310-linux_x86_64.whl"))
    subprocess.run(shlex.split("pip install mamba_ssm-1.2.0.post1-cp310-cp310-linux_x86_64.whl"))

install()
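
# cobra depends on the Mamba packages installed above, so this import must
# come after install().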
from cobra import load
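# Load the pretrained Cobra-3B vision-language model (the checkpoint is
# fetched on first run if not already cached).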
vlm = load("cobra+3b")
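
# Run in bfloat16 on GPU; fall back to float32 on CPU.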
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.bfloat16
else:
    DEVICE = "cpu"
    DTYPE = torch.float32
vlm.to(DEVICE, dtype=DTYPE)
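
# A single module-level prompt builder carries the running conversation
# across turns; it is reset whenever the chat history is empty.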
prompt_builder = vlm.get_prompt_builder()
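
# spaces.GPU requests a ZeroGPU device for each call; `duration` is the
# expected per-call runtime budget in seconds.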
@spaces.GPU(duration=20)
def bot_streaming(message, history, temperature, top_k, max_new_tokens):
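    # Invoked by gr.ChatInterface on every user turn; yields partial replies
    # so the UI can stream the response as it is generated.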
    streamer = TextIteratorStreamer(vlm.llm_backbone.tokenizer, skip_special_tokens=True)
    # Fresh conversation: reset the shared prompt builder's state.
    if len(history) == 0:
        prompt_builder.prompt, prompt_builder.turn_count = "", 0
    # Use the image uploaded this turn, if any; otherwise fall back to the
    # most recent image from past turns (Gradio keeps uploads inside tuples).
    image = None
    if message["files"]:
        image = message["files"][-1]["path"]
    else:
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]
    if image is not None:
        image = Image.open(image).convert("RGB")

    prompt_builder.add_turn(role="human", message=message["text"])
    prompt_text = prompt_builder.get_prompt()
    generation_kwargs = {
        "image": image,
        "prompt_text": prompt_text,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "use_cache": True,
        "temperature": temperature,
        "do_sample": True,
        "top_k": top_k,
    }
    # Run generation on a background thread so tokens can be streamed to the
    # UI as they are produced.
    thread = Thread(target=vlm.generate, kwargs=generation_kwargs)
    thread.start()
    # The streamer echoes the prompt first; skip everything until the
    # assistant tag appears, then yield the growing reply.
    buffer = ""
    output_started = False
    for new_text in streamer:
        if not output_started:
            if "<|assistant|>\n" in new_text:
                output_started = True
            continue
        buffer += new_text
        if len(buffer) > 1:
            yield buffer

    # Record the reply so the next turn continues the same conversation.
    prompt_builder.add_turn(role="gpt", message=buffer)
    return buffer
demo = gr.ChatInterface(
    fn=bot_streaming,
    additional_inputs=[
        gr.Slider(0, 1, value=0.2, label="Temperature"),
        gr.Slider(1, 3, value=1, step=1, label="Top k"),
        gr.Slider(1, 2048, value=256, step=1, label="Max New Tokens"),
    ],
    title="Cobra",
    description="Try [Cobra](https://huggingface.co/papers/2403.14520) in this demo. Upload an image and start chatting about it. Clear the history before asking questions about a new image.",
    stop_btn="Stop Generation",
    multimodal=True,
)
demo.launch(debug=True)