han1997 committed
Commit 746855d · verified · 1 Parent(s): e533229

Create app.py

Files changed (1)
  1. app.py +82 -0
app.py ADDED
@@ -0,0 +1,82 @@
+ from __future__ import annotations
+
+ import subprocess
+ import sys
+
+ # mamba-ssm and causal-conv1d provide the CUDA kernels the Mamba-based Cobra
+ # model needs; install them up front, before cobra is imported and the model loads.
+ subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'mamba-ssm'])
+ subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'causal-conv1d'])
+
+ import gradio as gr
+ import spaces
+ import torch
+ from PIL import Image
+
+ from cobra import load
+
+ vlm = load("cobra+3b")
+
+ if torch.cuda.is_available():
+     DEVICE = "cuda"
+     DTYPE = torch.bfloat16
+ else:
+     DEVICE = "cpu"
+     DTYPE = torch.float32
+ vlm.to(DEVICE, dtype=DTYPE)
+
+ # NOTE: the prompt builder is module-level, so conversation turns accumulate
+ # across requests for the lifetime of the process.
+ prompt_builder = vlm.get_prompt_builder()
+ system_prompt = prompt_builder.system_prompt
+
+
+ @spaces.GPU
+ def bot_streaming(message, history):
+     print(message)
+     image = None
+     if message["files"]:
+         image = message["files"][-1]["path"]
+     else:
+         # If no image was uploaded this turn, fall back to the most recent
+         # image from past turns, where uploads are kept as (filepath,) tuples.
+         for hist in history:
+             if isinstance(hist[0], tuple):
+                 image = hist[0][0]
+     if image is None:
+         yield "Please upload an image to chat about."
+         return
+
+     image = Image.open(image).convert("RGB")
+
+     prompt_builder.add_turn(role="human", message=message["text"])
+     prompt_text = prompt_builder.get_prompt()
+
+     # Generate from the VLM. Token-by-token streaming is not wired up, so the
+     # full response is yielded in one step.
+     generated_text = vlm.generate(
+         image,
+         prompt_text,
+         cg=True,
+         do_sample=False,
+         temperature=1.0,
+         max_new_tokens=2048,
+     )
+     prompt_builder.add_turn(role="gpt", message=generated_text)
+
+     print(generated_text)
+     yield generated_text
+
+
+ demo = gr.ChatInterface(
+     fn=bot_streaming,
+     title="Cobra",
+     examples=[
+         {"text": "What is on the flower?", "files": ["./bee.jpg"]},
+         {"text": "How to make this pastry?", "files": ["./baklava.png"]},
+     ],
+     description="Try [Cobra](https://huggingface.co/papers/2403.14520) in this demo. Upload an image and start chatting about it, or simply try one of the examples below.",
+     stop_btn="Stop Generation",
+     multimodal=True,
+ )
+ demo.launch(debug=True)