from __future__ import annotations

import torch

# Monkey-patch torch.jit.script to a no-op so decorated functions stay in eager mode
# (avoids JIT scripting issues when loading the Mamba-based model on this Space).
torch.jit.script = lambda f: f

import shlex
import subprocess
from threading import Thread

import gradio as gr
import spaces
from PIL import Image
from transformers import TextIteratorStreamer

# Install the prebuilt Mamba kernel wheels (causal_conv1d and mamba_ssm) bundled
# with this Space; the wheels target CPython 3.10 on linux_x86_64.
def install():
    print("Installing bundled Mamba kernel wheels", flush=True)
    subprocess.run(shlex.split("pip install causal_conv1d-1.2.0.post1-cp310-cp310-linux_x86_64.whl"), check=True)
    subprocess.run(shlex.split("pip install mamba_ssm-1.2.0.post1-cp310-cp310-linux_x86_64.whl"), check=True)

install()

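# Load the Cobra VLM; this must happen after the Mamba kernels above are installed.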
from cobra import load
vlm = load("cobra+3b")

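# Use the GPU in bfloat16 when available, otherwise fall back to CPU in float32.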
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.bfloat16
else:
    DEVICE = "cpu"
    DTYPE = torch.float32
vlm.to(DEVICE, dtype=DTYPE)

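# A single module-level prompt builder holds the running conversation; it is reset
# whenever the Gradio chat history is empty (i.e. a new chat).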
prompt_builder = vlm.get_prompt_builder()

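# Each call requests a ZeroGPU allocation for up to 20 seconds.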
@spaces.GPU(duration=20)
def bot_streaming(message, history, temperature, top_k, max_new_tokens):
    streamer = TextIteratorStreamer(vlm.llm_backbone.tokenizer, skip_special_tokens=True)
    
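    # New chat: clear any conversation state left over from the previous session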
    if len(history) == 0:
        prompt_builder.prompt, prompt_builder.turn_count = "", 0

    image = None
    
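    # Prefer an image uploaded with this message, if any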
    if message["files"]:
        image = message["files"][-1]["path"]
    else:
        # No image uploaded this turn: fall back to the most recent image from earlier
        # turns (Gradio stores uploaded files in the history as tuples).
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]
  
    if image is not None:
        image = Image.open(image).convert("RGB")

    
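    # Append the user turn and render the full prompt string for the model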
    prompt_builder.add_turn(role="human", message=message['text'])
    prompt_text = prompt_builder.get_prompt()

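    # Sampling settings come straight from the UI sliders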
    generation_kwargs = {
        "image": image,
        "prompt_text": prompt_text,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "use_cache": True,
        "temperature": temperature,
        "do_sample": True,
        "top_k": top_k,
    }
    
    # Run vlm.generate on a background thread so the streamer can yield tokens as they arrive
    thread = Thread(target=vlm.generate, kwargs=generation_kwargs)
    thread.start()

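    # The streamer also replays the prompt tokens; skip everything up to the
    # assistant marker, then accumulate and yield the reply as it streams in.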
    buffer = ""
    output_started = False
    for new_text in streamer:
        if not output_started:
            if "<|assistant|>\n" in new_text:
                output_started = True
            continue
        buffer += new_text
        if len(buffer) > 1:
            yield buffer

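    # Record the assistant reply so follow-up questions keep the full context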
    prompt_builder.add_turn(role="gpt", message=buffer)

    return buffer
    
    
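# Multimodal chat UI; the extra sliders feed temperature, top-k and max-new-tokens
# into bot_streaming on every turn.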
demo = gr.ChatInterface(fn=bot_streaming, 
                        additional_inputs=[gr.Slider(0, 1, value=0.2, label="Temperature"),
                                           gr.Slider(1, 3, value=1, step=1, label="Top k"),
                                           gr.Slider(1, 2048, value=256, step=1, label="Max New Tokens")],
                        title="Cobra", 
                        description="Try [Cobra](https://huggingface.co/papers/2403.14520) in this demo. Upload an image and start chatting about it. Clear the history before asking questions related to new images",
                        stop_btn="Stop Generation", multimodal=True)
demo.launch(debug=True)