from __future__ import annotations import torch torch.jit.script = lambda f: f import shlex import spaces import gradio as gr from threading import Thread from transformers import TextIteratorStreamer import hashlib import os from transformers import AutoModel, AutoProcessor import sys import subprocess from PIL import Image import time # install packages for mamba def install(): print("Install personal packages", flush=True) subprocess.run(shlex.split("pip install causal_conv1d-1.2.0.post1-cp310-cp310-linux_x86_64.whl")) subprocess.run(shlex.split("pip install mamba_ssm-1.2.0.post1-cp310-cp310-linux_x86_64.whl")) install() from cobra import load vlm = load("cobra+3b") if torch.cuda.is_available(): DEVICE = "cuda" DTYPE = torch.bfloat16 else: DEVICE = "cpu" DTYPE = torch.float32 vlm.to(DEVICE, dtype=DTYPE) prompt_builder = vlm.get_prompt_builder() @spaces.GPU def bot_streaming(message, history, temperature, top_k, max_new_tokens): if len(history) == 0: prompt_builder.prompt, prompt_builder.turn_count = "", 0 image = None if message["files"]: image = message["files"][-1]["path"] else: # if there's no image uploaded for this turn, look for images in the past turns # kept inside tuples, take the last one for hist in history: if type(hist[0])==tuple: image = hist[0][0] if image is not None: image = Image.open(image).convert("RGB") prompt_builder.add_turn(role="human", message=message['text']) prompt_text = prompt_builder.get_prompt() # Generate from the VLM with torch.no_grad(): generated_text = vlm.generate( image, prompt_text, use_cache=True, do_sample=True, temperature=temperature, top_k=top_k, max_new_tokens=max_new_tokens, ) prompt_builder.add_turn(role="gpt", message=generated_text) time.sleep(0.04) yield generated_text demo = gr.ChatInterface(fn=bot_streaming, additional_inputs=[gr.Slider(0, 1, value=0.2, label="Temperature"), gr.Slider(1, 3, value=1, step=1, label="Top k"), gr.Slider(1, 2048, value=256, step=1, label="Max New Tokens")], title="Cobra", description="Try [Cobra](https://huggingface.co/papers/2403.14520) in this demo. Upload an image and start chatting about it.", stop_btn="Stop Generation", multimodal=True) demo.launch(debug=True)