"""Gradio chat demo for the Cobra vision-language model.

On import this module installs two local wheels with pip, loads the
``cobra+3b`` model, moves it to GPU (bfloat16) when CUDA is available or CPU
(float32) otherwise, and launches a multimodal ``gr.ChatInterface``.
"""
from __future__ import annotations

import torch

# Make torch.jit.script a pass-through that returns the function unchanged.
# This runs before the remaining imports, so keep it directly after
# `import torch`.
# NOTE(review): presumably this stops later-imported packages from invoking
# TorchScript compilation -- confirm before moving or removing.
torch.jit.script = lambda f: f

import shlex
import spaces
import gradio as gr
from threading import Thread
from transformers import TextIteratorStreamer
import hashlib
import os
from transformers import AutoModel, AutoProcessor
import sys
import subprocess
from PIL import Image
import time

# Wheel files passed to `pip install` at start-up; the relative names mean
# they are resolved against the current working directory.
_MAMBA_WHEELS = (
    "causal_conv1d-1.2.0.post1-cp310-cp310-linux_x86_64.whl",
    "mamba_ssm-1.2.0.post1-cp310-cp310-linux_x86_64.whl",
)

# Text that has to appear in the stream before chunks are treated as the
# model's reply (everything up to and including that chunk is discarded).
_ASSISTANT_MARKER = "<|assistant|>\n"


# install packages for mamba
def install() -> None:
    """Install the local causal_conv1d and mamba_ssm wheels with pip.

    A failing install is reported on stdout but does not raise, so start-up
    continues exactly as it did when the return code was ignored.
    """
    print("Install personal packages", flush=True)
    for wheel in _MAMBA_WHEELS:
        result = subprocess.run(["pip", "install", wheel])
        if result.returncode != 0:
            print(
                f"pip install failed for {wheel} (exit code {result.returncode})",
                flush=True,
            )


install()

# Imported only after install() has run.
# NOTE(review): presumably cobra needs the wheels installed above -- confirm.
from cobra import load

vlm = load("cobra+3b")

if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.bfloat16
else:
    DEVICE = "cpu"
    DTYPE = torch.float32
vlm.to(DEVICE, dtype=DTYPE)

# Module-level prompt state: every call to bot_streaming reads and mutates
# this single object.
prompt_builder = vlm.get_prompt_builder()


def _latest_image_path(message, history):
    """Return the image path for this turn, or None if no image is available.

    Prefers the last file attached to ``message``; otherwise falls back to
    the most recent history entry whose first element is a tuple (its first
    item is used as the path).
    """
    if message["files"]:
        return message["files"][-1]["path"]
    # if there's no image uploaded for this turn, look for images in the past
    # turns kept inside tuples, take the last one
    for hist in reversed(history):
        if isinstance(hist[0], tuple):
            return hist[0][0]
    return None


@spaces.GPU(duration=20)
def bot_streaming(message, history, temperature, top_k, max_new_tokens):
    """Stream the model's reply for one chat turn.

    Args:
        message: Mapping with a ``"text"`` entry (the user prompt) and a
            ``"files"`` entry (a possibly empty sequence whose items are
            indexed with ``["path"]``).
        history: Sequence of previous turns. An empty history resets the
            shared ``prompt_builder``. Entries whose first element is a tuple
            are treated as image uploads.
        temperature: Forwarded to ``vlm.generate`` as ``temperature``.
        top_k: Forwarded to ``vlm.generate`` as ``top_k``.
        max_new_tokens: Forwarded to ``vlm.generate`` as ``max_new_tokens``.

    Yields:
        The cumulative reply text collected so far.

    Side effects:
        Adds the human turn and, once streaming ends, the model turn to the
        module-level ``prompt_builder``.
    """
    streamer = TextIteratorStreamer(vlm.llm_backbone.tokenizer, skip_special_tokens=True)

    # A fresh conversation starts from an empty prompt.
    if not history:
        prompt_builder.prompt, prompt_builder.turn_count = "", 0

    image = _latest_image_path(message, history)
    if image is not None:
        image = Image.open(image).convert("RGB")

    prompt_builder.add_turn(role="human", message=message['text'])
    prompt_text = prompt_builder.get_prompt()

    generation_kwargs = {
        "image": image,
        "prompt_text": prompt_text,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "use_cache": True,
        "temperature": temperature,
        "do_sample": True,
        "top_k": top_k,
    }

    # Generate from the VLM on a worker thread so the streamer can be
    # consumed here while generation is still running.
    thread = Thread(target=vlm.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    output_started = False
    for new_text in streamer:
        if not output_started:
            # Discard every chunk up to and including the one that carries
            # the assistant marker.
            if _ASSISTANT_MARKER in new_text:
                output_started = True
            continue
        buffer += new_text
        if len(buffer) > 1:
            yield buffer

    # The streamer is exhausted; wait for the worker before touching the
    # shared prompt state.
    thread.join()
    prompt_builder.add_turn(role="gpt", message=buffer)

    # The loop above never yields a buffer of length one, so a reply that
    # ends up being a single character would otherwise never be shown.
    if len(buffer) == 1:
        yield buffer
    return buffer


# NOTE(review): the temperature slider allows 0 while generation always runs
# with do_sample=True -- confirm that the backend accepts temperature == 0.
demo = gr.ChatInterface(
    fn=bot_streaming,
    additional_inputs=[
        gr.Slider(0, 1, value=0.2, label="Temperature"),
        gr.Slider(1, 3, value=1, step=1, label="Top k"),
        gr.Slider(1, 2048, value=256, step=1, label="Max New Tokens"),
    ],
    title="Cobra",
    description="Try [Cobra](https://huggingface.co/papers/2403.14520) in this demo. Upload an image and start chatting about it. Clear the history before asking questions related to new images",
    stop_btn="Stop Generation",
    multimodal=True,
)
demo.launch(debug=True)