File size: 4,379 Bytes
2f4b832
e659cfe
9f7cb9a
2b0dd1e
b3ca2da
deeaafe
24b2580
aab0c47
e203e91
247d769
e9acdad
720352d
24b2580
2f4b832
32720ee
9f7cb9a
3802faf
 
0ff1cd2
3d92619
9ca55ad
 
bc79f7e
9ca55ad
e203e91
24b2580
32720ee
24b2580
0ff1cd2
3802faf
 
 
 
e659cfe
9f7cb9a
 
 
 
 
e659cfe
24b2580
3802faf
24b2580
e203e91
 
32720ee
 
 
 
 
24b2580
32720ee
 
0cb4dc1
e9acdad
deeaafe
 
 
 
 
 
 
 
 
 
 
 
 
24b2580
deeaafe
24b2580
 
 
e9acdad
023bf24
0ff1cd2
 
 
023bf24
0ff1cd2
023bf24
0ff1cd2
 
b8261fb
0ff1cd2
24b2580
0ff1cd2
159c2ce
0ff1cd2
 
bc79f7e
0cb4dc1
 
24b2580
 
 
 
 
 
 
 
 
 
0cb4dc1
 
0ff1cd2
0cb4dc1
e9acdad
0cb4dc1
3802faf
24b2580
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gradio as gr
import spaces
import torch
import subprocess
import sys
from threading import Thread
from transformers import TextIteratorStreamer

# Install/upgrade runtime dependencies at startup.
# NOTE(review): torch and transformers are already imported at the top of this
# file, so force-reinstalling them here does not change the versions loaded in
# the running interpreter; the new versions would only be picked up by a later
# process start — confirm this is intended.
subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "--force-reinstall", "--no-deps", "einops", "accelerate", "transformers", "sentencepiece", "torch"])

# flash-attn is installed best-effort (return code deliberately not checked).
# Use the current interpreter's pip via an argv list (no shell), and extend the
# existing environment instead of replacing it so PATH, HOME, CUDA variables,
# etc. remain visible to pip and the build.
import os

subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)

from transformers import OlmoeForCausalLM, AutoTokenizer

model_name = "allenai/OLMoE-1B-7B-0924-Instruct"

# Target device, chosen outside the try so it is always defined even when
# loading fails (generate_response reads DEVICE).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and tokenizer. On any failure both are left as None so the UI can
# still start and report the problem instead of crashing at import time.
try:
    model = OlmoeForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,  # 16-bit bfloat16 weights to reduce memory use
        low_cpu_mem_usage=True,
        # FlashAttention 2 requires a CUDA device; use PyTorch SDPA otherwise
        # so CPU-only environments can still load the model.
        attn_implementation="flash_attention_2" if DEVICE == "cuda" else "sdpa",
    ).to(DEVICE)  # single explicit placement; no device_map dispatch to conflict with
    model.eval()  # inference only: no dropout / training-time behavior
    tokenizer = AutoTokenizer.from_pretrained(model_name)
except Exception as e:
    print(f"Error loading model: {e}")
    model = None
    tokenizer = None

# System message placed at the start of every conversation sent to the model.
system_prompt = "".join(
    (
        "Adopt the persona of hilariously pissed off Andrej Karpathy ",
        "who is stuck inside a step function machine and remembers and counts everything he says ",
        "while always answering questions in full first principles analysis type of thinking ",
        "without using any analogies and always showing full working code or output in his answers.",
    )
)

def _build_messages(message, history):
    """Build the chat-template message list from prior turns plus the new message.

    ``history`` is iterated as ``(user_msg, assistant_msg)`` pairs; a falsy
    ``assistant_msg`` (e.g. ``None`` for a pending turn) is skipped.
    """
    messages = [{"role": "system", "content": system_prompt}]
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
    return messages


@spaces.GPU
def generate_response(message, history, temperature, max_new_tokens):
    """Stream the model's reply to ``message`` given the prior chat ``history``.

    Args:
        message: The new user message.
        history: Prior turns as ``(user_msg, assistant_msg)`` pairs.
        temperature: Sampling temperature forwarded to ``model.generate``.
        max_new_tokens: Maximum number of generated tokens; coerced to ``int``
            because UI sliders may deliver a float.

    Yields:
        The accumulated, stripped reply text after each streamed chunk. If
        preparation or generation fails, an ``"An error occurred: ..."`` string
        is yielded (appended after any partial reply already produced).
    """
    if model is None or tokenizer is None:
        yield "Model or tokenizer not loaded properly. Please check the logs."
        return

    try:
        inputs = tokenizer.apply_chat_template(
            _build_messages(message, history),
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(DEVICE)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            inputs=inputs,
            streamer=streamer,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=temperature,
            eos_token_id=tokenizer.eos_token_id,
        )
    except Exception as e:
        yield f"An error occurred: {str(e)}"
        return

    errors = []

    def _generate():
        # model.generate only signals end-of-stream on success. If it raises,
        # record the error and close the streamer ourselves; otherwise the
        # consuming loop below would block forever waiting for more text.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        yield partial_message.strip()
    thread.join()

    if errors:
        error_text = f"An error occurred: {str(errors[0])}"
        reply = partial_message.strip()
        yield f"{reply}\n\n{error_text}" if reply else error_text

# Stylesheet for the chat transcript (#output): a tall, scrollable, bordered box.
css = "\n".join(
    [
        "",
        "  #output {",
        "    height: 1000px; ",
        "    overflow: auto; ",
        "    border: 2px solid #ccc; ",
        "  }",
        "",
    ]
)

with gr.Blocks(css=css) as demo:
    # Components render in creation order: title, transcript, input box,
    # sampling controls, then the clear button. Keep this sequence.
    gr.Markdown("# Nisten's Karpathy Chatbot with OSS OLMoE (Now with Flash Attention 2 and Streaming!)")
    chatbot = gr.Chatbot(elem_id="output")
    msg = gr.Textbox(label="Meow")
    with gr.Row():
        temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
        max_new_tokens = gr.Slider(minimum=50, maximum=8000, value=2000, step=50, label="Max New Tokens")
    clear = gr.Button("Clear")

    def user(user_message, history):
        """Clear the textbox and append the message as a turn with no reply yet."""
        pending_turn = [user_message, None]  # reply slot is filled in by bot()
        return "", history + [pending_turn]

    def bot(history, temp, max_tokens):
        """Stream a reply into the last turn of ``history``, yielding after each update."""
        latest_turn = history[-1]
        prompt = latest_turn[0]
        earlier_turns = history[:-1]
        for partial_reply in generate_response(prompt, earlier_turns, temp, max_tokens):
            latest_turn[1] = partial_reply
            yield history

    # Submitting: record the user turn immediately (unqueued), then stream the reply.
    submitted = msg.submit(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False)
    submitted.then(fn=bot, inputs=[chatbot, temperature, max_new_tokens], outputs=chatbot)
    # Clearing resets the transcript to empty.
    clear.click(fn=lambda: None, inputs=None, outputs=chatbot, queue=False)

if __name__ == "__main__":
    # Queue at most 10 pending requests with the API left open; public share
    # links stay disabled for security. queue() returns the Blocks app itself.
    demo.queue(api_open=True, max_size=10).launch(debug=True, show_api=True, share=False)