File size: 4,780 Bytes
ab382f0
 
 
 
 
 
 
 
 
e7dfb54
ab382f0
063dc40
ab382f0
 
489e9b6
ab382f0
65610f2
 
f0d4bd1
65610f2
3b4a08d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab382f0
 
 
2ae46d7
3b4a08d
2ae46d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b4a08d
063dc40
2ae46d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b4a08d
2ae46d7
3b4a08d
ab382f0
 
 
 
 
56c0ca0
ab382f0
 
 
 
 
 
 
3b4a08d
80d5294
 
 
 
f2fdba7
80d5294
 
063dc40
489e9b6
65610f2
80d5294
489e9b6
 
80d5294
 
063dc40
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import json
import os
import subprocess
import sys
from threading import Thread

import torch
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer

# Best-effort install/upgrade of flash-attn at startup (the model below is
# loaded with attn_implementation="flash_attention_2").
# - The current environment is extended rather than replaced: passing a bare
#   dict as `env` would drop PATH, HOME, proxy settings, etc. for the child.
# - `sys.executable -m pip` with an argument list installs into the interpreter
#   running this script and avoids going through a shell.
# NOTE(review): FLASH_ATTENTION_SKIP_CUDA_BUILD presumably makes the package
# skip compiling its CUDA kernels locally -- confirm against flash-attn's setup.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "--upgrade", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,  # a failed install is not fatal here, matching the original behaviour
)

# --- Model configuration -----------------------------------------------------
# Hub repository the weights are loaded from.
MODEL_ID = "AtAndDev/marco-qwq-7B"
# Prompt format used by predict(): "Auto", "ChatML" or "Mistral Instruct".
CHAT_TEMPLATE = "ChatML"
# Short display name: whatever follows the last "/" in MODEL_ID.
MODEL_NAME = MODEL_ID.rsplit("/", 1)[-1]
# Upper bound on the number of prompt tokens kept by predict().
CONTEXT_LENGTH = 48_000

# --- UI configuration --------------------------------------------------------
COLOR = "blue"  # primary hue of the Gradio theme
EMOJI = "🤖"  # prefix of the page title
DESCRIPTION = (
    "Finetuned Marco-O1 on ~60k QwQ responses to get this model. "
    'It can mimic "system-2" thinking in a small form factor.\n'
    "More r&d coming soon!"
)

# LaTeX delimiter pairs as {"left", "right", "display"} dicts: inline \( ... \),
# then the \begin{env} ... \end{env} display environments, then display \[ ... \].
# NOTE(review): nothing else in the visible part of this file references this list.
latex_delimiters_set = [
    {"left": "\\(", "right": "\\)", "display": False},
    *(
        {"left": f"\\begin{{{env}}}", "right": f"\\end{{{env}}}", "display": True}
        for env in ("equation", "align", "alignat", "gather", "CD")
    ),
    {"left": "\\[", "right": "\\]", "display": True},
]


@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    """Stream a model reply to ``message``, yielding the text generated so far.

    Args:
        message: The new user message.
        history: Iterable of ``(user, assistant)`` pairs from earlier turns.
        system_prompt: Text placed at the start of the prompt.
        temperature: Sampling temperature; ``0`` (or less) selects greedy
            decoding instead of sampling.
        max_new_tokens: Maximum number of tokens to generate (cast to ``int``).
        top_k: Top-k sampling cutoff (cast to ``int``); used only when sampling.
        repetition_penalty: Repetition penalty; omitted when not positive.
        top_p: Nucleus sampling cutoff; used only when sampling.

    Yields:
        The accumulated reply after each chunk produced by the streamer.

    Raises:
        ValueError: If ``CHAT_TEMPLATE`` is not one of the supported names.
    """
    # Format history with a given chat template
    if CHAT_TEMPLATE == "Auto":
        # The streamer yields decoded text, so the stop marker must be the EOS
        # *string*; the integer token id could never match a text chunk.
        stop_tokens = [tokenizer.eos_token]
        parts = [system_prompt + "\n\n"]
        for user, assistant in history:
            parts.append(f"User: {user}\nAssistant: {assistant}\n")
        parts.append(f"User: {message}\nAssistant:")
    elif CHAT_TEMPLATE == "ChatML":
        stop_tokens = ["<|endoftext|>", "<|im_end|>"]
        parts = ['<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n']
        for user, assistant in history:
            parts.append(f'<|im_start|>user\n{user}\n<|im_end|>\n<|im_start|>assistant\n{assistant}\n<|im_end|>\n')
        parts.append(f'<|im_start|>user\n{message}\n<|im_end|>\n<|im_start|>assistant\n')
    elif CHAT_TEMPLATE == "Mistral Instruct":
        stop_tokens = ["</s>", "[INST]", "[INST] ", "<s>", "[/INST]", "[/INST] "]
        parts = [f'<s>[INST] {system_prompt}\n']
        for user, assistant in history:
            parts.append(f'{user} [/INST] {assistant}</s>[INST]')
        parts.append(f' {message} [/INST]')
    else:
        raise ValueError("Incorrect chat template, select 'Auto', 'ChatML' or 'Mistral Instruct'")
    instruction = "".join(parts)
    print(instruction)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # NOTE(review): truncation=True applies the tokenizer's own length limit
    # before the CONTEXT_LENGTH cut below; which end it trims depends on the
    # tokenizer configuration -- confirm the newest turns are the ones kept.
    enc = tokenizer(instruction, return_tensors="pt", padding=True, truncation=True)
    input_ids, attention_mask = enc.input_ids, enc.attention_mask

    # Keep only the most recent CONTEXT_LENGTH prompt tokens.
    if input_ids.shape[1] > CONTEXT_LENGTH:
        input_ids = input_ids[:, -CONTEXT_LENGTH:]
        attention_mask = attention_mask[:, -CONTEXT_LENGTH:]

    # Sampling with a non-positive temperature is rejected by generate(), so a
    # temperature of 0 from the UI falls back to greedy decoding.
    do_sample = temperature > 0
    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        do_sample=do_sample,
        # UI sliders may deliver floats; token counts must be integers.
        max_new_tokens=int(max_new_tokens),
    )
    if do_sample:
        generate_kwargs.update(temperature=temperature, top_k=int(top_k), top_p=top_p)
    # The penalty must be strictly positive; a non-positive slider value is
    # treated as "no explicit penalty" and left to the model's default.
    if repetition_penalty > 0:
        generate_kwargs["repetition_penalty"] = repetition_penalty

    # Generation runs in a background thread so chunks can be consumed here
    # while the model is still producing tokens.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    generated = ""
    for chunk in streamer:
        # Stop before emitting a chunk that is exactly a stop marker.
        if chunk in stop_tokens:
            break
        generated += chunk
        yield generated

# Load model
# Inputs are moved to the GPU when torch reports one, otherwise to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# 4-bit bitsandbytes loading with bfloat16 as the compute dtype.
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

# The tokenizer comes from the "AIDC-AI/Marco-o1" repository, while the weights
# are loaded from MODEL_ID.
tokenizer = AutoTokenizer.from_pretrained("AIDC-AI/Marco-o1", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    attn_implementation="flash_attention_2",
    quantization_config=quantization_config,
    device_map="auto",
)

# Create Gradio interface
# Collapsed panel that holds the generation controls.
parameters_panel = gr.Accordion(label="⚙️ Parameters", open=False)

# Extra inputs, listed in the same order as predict()'s parameters that follow
# (message, history).
parameter_inputs = [
    gr.Textbox(value="You are a helpful assistant.", label="System prompt"),
    gr.Slider(minimum=0, maximum=1, value=0.2, label="Temperature"),
    gr.Slider(minimum=128, maximum=32000, value=32000, label="Max new tokens"),
    gr.Slider(minimum=1, maximum=80, value=40, label="Top K sampling"),
    gr.Slider(minimum=0, maximum=2, value=1.2, label="Repetition penalty"),
    gr.Slider(minimum=0, maximum=1, value=0.9, label="Top P sampling"),
]

chat_ui = gr.ChatInterface(
    fn=predict,
    title=f"{EMOJI} {MODEL_NAME}",
    description=DESCRIPTION,
    additional_inputs_accordion=parameters_panel,
    additional_inputs=parameter_inputs,
    theme=gr.themes.Soft(primary_hue=COLOR),
)

# Enable request queuing, then start serving the app.
chat_ui.queue().launch()