zR
update
040e362
raw
history blame
4.52 kB
from threading import Thread
import requests
from io import BytesIO
from PIL import Image
import re
import gradio as gr
import torch
import spaces
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoImageProcessor,
TextIteratorStreamer,
)
# Hugging Face Hub checkpoint shared by the tokenizer, the model and the image processor.
MODEL_ID = "THUDM/glm-edge-v-5b"

# Everything is loaded once, at import time, and reused by every call to predict().
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    device_map="auto",
).eval()
# NOTE(review): device_map is a model-placement option; confirm the image
# processor actually makes use of it (it is simply forwarded here).
processor = AutoImageProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    device_map="auto",
)
def get_image(image, timeout=10):
    """Load an image from a URL or a local file path and convert it to RGB.

    Args:
        image: an http(s):// or ftp(s):// URL, or a local file path. Falsy
            values (``None``, ``""``) mean "no image".
        timeout: seconds to wait on the remote server before giving up; only
            used when *image* is a URL.

    Returns:
        A ``PIL.Image.Image`` in RGB mode, or ``None`` when *image* is falsy.

    Raises:
        requests.RequestException: if the download fails, times out, or the
            server answers with an HTTP error status.
        OSError: if the file cannot be opened or is not a recognised image.
    """
    if not image:
        return None
    if is_url(image):
        # Without a timeout, requests can block forever on an unresponsive host.
        response = requests.get(image, timeout=timeout)
        # Fail on 4xx/5xx instead of handing an HTML error page to PIL.
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")
    return Image.open(image).convert("RGB")
def is_url(s):
    """Return True when *s* starts with an http://, https://, ftp:// or ftps:// scheme.

    The check is case-sensitive and only inspects the beginning of the string.
    """
    scheme_match = re.match(r'^(?:http|ftp)s?://', s)
    return scheme_match is not None
def preprocess_messages(history, image):
    """Turn Gradio chat history (plus an optional image) into chat-template messages.

    Args:
        history: sequence of ``[user_msg, model_msg]`` pairs as held by the
            ``gr.Chatbot`` component; the last pair is the turn awaiting a reply.
        image: optional image file path or URL to attach to the newest message.

    Returns:
        ``(messages, pixel_values)``. ``messages`` is a list of
        ``{"role": ..., "content": [...]}`` dicts suitable for
        ``tokenizer.apply_chat_template``. ``pixel_values`` is a tensor on
        ``model.device``, or ``None`` when no image was given or it could not
        be loaded.
    """
    messages = []
    last_idx = len(history) - 1
    for idx, (user_msg, model_msg) in enumerate(history):
        if idx == last_idx and not messages:
            # Reached the newest turn with nothing collected yet (e.g. a
            # single-turn history): send the user's text even if it is empty.
            messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
            break
        if user_msg:
            messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
        if model_msg:
            messages.append({"role": "assistant", "content": [{"type": "text", "text": model_msg}]})

    pixel_values = None
    # `messages` can be empty for an empty history; there is then nothing to attach to.
    if image and messages:
        try:
            image_input = get_image(image)
            pixel_values = torch.tensor(processor(image_input).pixel_values).to(model.device)
        except Exception:
            print("Invalid image path. Continuing with text conversation.")
        else:
            # Add the image placeholder only once pixel values exist; otherwise the
            # prompt would contain an image slot with nothing to fill it.
            messages[-1]['content'].append({"type": "image"})
    return messages, pixel_values
@spaces.GPU()
def predict(history, max_length, top_p, temperature, image=None):
    """Stream the model's reply to the newest turn in *history*.

    ``model.generate`` runs in a background thread and pushes decoded text
    into a ``TextIteratorStreamer``. Each chunk is appended to the reply slot
    of the last history pair, and the history is yielded so Gradio can refresh
    the chatbot incrementally.

    Args:
        history: list of ``[user_msg, model_msg]`` pairs; mutated in place.
        max_length: maximum number of new tokens to generate (slider value).
        top_p: nucleus-sampling probability mass.
        temperature: sampling temperature.
        image: optional image path/URL attached to the newest user turn.

    Yields:
        The same *history* list, once per item produced by the streamer.
    """
    messages, pixel_values = preprocess_messages(history, image)
    model_inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True
    )
    # With a timeout, the consumer loop stops waiting (raises) if no text arrives
    # for 60 s, e.g. when generate() fails inside its thread.
    streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": model_inputs["input_ids"].to(model.device),
        "attention_mask": model_inputs["attention_mask"].to(model.device),
        "streamer": streamer,
        # Slider values may arrive as floats; generate() expects an integer token count.
        "max_new_tokens": int(max_length),
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "repetition_penalty": 1.2,
        # NOTE(review): hard-coded stop-token ids for this checkpoint; confirm
        # against the tokenizer's special tokens.
        "eos_token_id": [59246, 59253, 59255],
    }
    if image and isinstance(pixel_values, torch.Tensor):
        generate_kwargs['pixel_values'] = pixel_values
    generation_thread = Thread(target=model.generate, kwargs=generate_kwargs)
    generation_thread.start()
    for new_token in streamer:
        if new_token:
            history[-1][1] += new_token
        yield history
    # The streamer is exhausted once generation ends; reap the worker thread.
    generation_thread.join()
def main():
    """Build the Gradio Blocks UI for the GLM-Edge-v demo and launch it."""
    with gr.Blocks() as demo:
        gr.HTML("""<h1 align="center">GLM-Edge-v Gradio Demo</h1>""")
        # Top row: chatbot and image upload
        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot()
            with gr.Column(scale=1):
                image_input = gr.Image(label="Upload an Image", type="filepath")
        # Bottom row: user input, action buttons and generation controls
        with gr.Row():
            with gr.Column(scale=2):
                user_input = gr.Textbox(show_label=True, placeholder="Input...", label="User Input")
                submitBtn = gr.Button("Submit")
                emptyBtn = gr.Button("Clear History")
            with gr.Column(scale=1):
                # Minimum of 1: generate() rejects max_new_tokens=0.
                max_length = gr.Slider(1, 8192, value=4096, step=1.0, label="Maximum length", interactive=True)
                top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
                temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)

        def user(query, history):
            """Clear the textbox and append a new ``[query, ""]`` turn to the chat history."""
            # A cleared Chatbot may hand back None rather than an empty list.
            return "", (history or []) + [[query, ""]]

        # Submit: record the user's turn immediately (unqueued), then stream the reply.
        submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
            predict, [chatbot, max_length, top_p, temperature, image_input], chatbot
        )
        # Clear: there is a single output component, so return a single value;
        # None empties the Chatbot.
        emptyBtn.click(lambda: None, None, [chatbot], queue=False)
    demo.queue()
    demo.launch()


if __name__ == "__main__":
    main()