from threading import Thread
import requests
from io import BytesIO
from PIL import Image
import re
import gradio as gr
import torch
import spaces
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoImageProcessor,
    TextIteratorStreamer,
)

# Load the GLM-Edge-V tokenizer, model, and image processor. The repo ships
# custom modeling code, hence trust_remote_code=True; device_map="auto" places
# the model weights on the available GPU(s).
tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-edge-v-5b", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("THUDM/glm-edge-v-5b", trust_remote_code=True, device_map="auto").eval()
processor = AutoImageProcessor.from_pretrained("THUDM/glm-edge-v-5b", trust_remote_code=True)
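
# Note: weights are loaded in full precision here; on memory-constrained GPUs,
# passing torch_dtype=torch.bfloat16 to from_pretrained is a possible tweak
# (optional; not used in this demo).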

def get_image(image):
    """Load an image from a URL or a local file path; returns None if no image."""
    if image is None:
        return None
    if is_url(image):
        response = requests.get(image)
        return Image.open(BytesIO(response.content)).convert("RGB")
    return Image.open(image).convert("RGB")

def is_url(s):
    return bool(re.match(r'^(?:http|ftp)s?://', s))
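
# Illustrative usage (hypothetical paths/URLs):
#   get_image("https://example.com/cat.png")  # fetched over HTTP
#   get_image("/tmp/cat.png")                 # opened from disk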

def preprocess_messages(history, image):
    """Convert Gradio chat history into chat-template messages plus image pixels."""
    messages = []
    pixel_values = None

    for idx, (user_msg, model_msg) in enumerate(history):
        # The final turn has no model reply yet; add only the user message.
        if idx == len(history) - 1 and not model_msg:
            messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
            break
        if user_msg:
            messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
        if model_msg:
            messages.append({"role": "assistant", "content": [{"type": "text", "text": model_msg}]})

    if image:
        # Attach an image placeholder to the latest user message and preprocess
        # the pixels for the vision tower.
        messages[-1]['content'].append({"type": "image"})
        try:
            image_input = get_image(image)
            pixel_values = torch.tensor(processor(image_input).pixel_values).to(model.device)
        except Exception as e:
            print(f"Failed to load image ({e}). Continuing with text-only conversation.")
    return messages, pixel_values
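
# For a two-turn history like [["Hi", "Hello!"], ["Describe this", ""]] with an
# image attached, preprocess_messages produces (illustrative only):
# [
#     {"role": "user", "content": [{"type": "text", "text": "Hi"}]},
#     {"role": "assistant", "content": [{"type": "text", "text": "Hello!"}]},
#     {"role": "user", "content": [{"type": "text", "text": "Describe this"},
#                                  {"type": "image"}]},
# ]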

# spaces.GPU() requests GPU time for this call when running on a Hugging Face
# ZeroGPU Space; it is a no-op elsewhere.
@spaces.GPU()
def predict(history, max_length, top_p, temperature, image=None):
    messages, pixel_values = preprocess_messages(history, image)

    # Render the conversation into input ids with the model's chat template.
    model_inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True
    )

    # Stream decoded tokens as they are generated so the UI can update live.
    streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": model_inputs["input_ids"].to(model.device),
        "attention_mask": model_inputs["attention_mask"].to(model.device),
        "streamer": streamer,
        "max_new_tokens": max_length,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "repetition_penalty": 1.2,
        "eos_token_id": [59246, 59253, 59255],

    }
    if image and isinstance(pixel_values, torch.Tensor):
        generate_kwargs['pixel_values'] = pixel_values
    # Generate in a background thread; the streamer yields tokens as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    for new_token in streamer:
        if new_token:
            history[-1][1] += new_token
        yield history
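
# predict is a generator: each yielded history re-renders the Chatbot component,
# so the reply appears token by token in the UI.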

def main():
    with gr.Blocks() as demo:
        gr.HTML("""<h1 align="center">GLM-Edge-v Gradio Demo</h1>""")

        # Top row: Chatbot and Image upload
        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot()
            with gr.Column(scale=1):
                image_input = gr.Image(label="Upload an Image", type="filepath")

        # Bottom row: System prompt, user input, and controls
        with gr.Row():
            with gr.Column(scale=2):
                user_input = gr.Textbox(show_label=True, placeholder="Input...", label="User Input")
                submitBtn = gr.Button("Submit")
                emptyBtn = gr.Button("Clear History")
            with gr.Column(scale=1):
                max_length = gr.Slider(0, 8192, value=4096, step=1.0, label="Maximum length", interactive=True)
                top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
                temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)

        # `user` appends the new query to the history with an empty reply slot
        # and clears the textbox.
        def user(query, history):
            return "", history + [[query, ""]]
        
        # Submitting first records the user turn, then streams the model reply.
        submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
            predict, [chatbot, max_length, top_p, temperature, image_input], chatbot
        )
        # Clearing resets both the chat history and the uploaded image.
        emptyBtn.click(lambda: (None, None), None, [chatbot, image_input], queue=False)

    # queue() is required for generator (streaming) callbacks; launch() starts the app.
    demo.queue()
    demo.launch()


if __name__ == "__main__":
    main()
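
# To run locally (assuming this file is saved as app.py and that gradio, torch,
# transformers, pillow, requests, and spaces are installed):
#   python app.py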