"""Gradio Space: streaming chat with allenai/OLMoE-1B-7B-0924-Instruct.

The script installs/upgrades its runtime dependencies, loads the model and
tokenizer once at import time, and serves a Gradio Blocks chat UI whose replies
are streamed incrementally through a ``TextIteratorStreamer``.
"""

import os
import subprocess
import sys
from threading import Thread

# Install/upgrade dependencies BEFORE torch/transformers are imported. If they
# are imported first, this process keeps the old, already-loaded modules while
# lazily-imported submodules are read from the freshly installed files on disk,
# mixing two library versions.
subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "--force-reinstall", "--no-deps", "einops", "accelerate", "transformers", "sentencepiece", "torch"])
# `env=` replaces the entire environment, so merge with os.environ to keep PATH
# etc.; run pip via the current interpreter (no shell string needed).
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,  # best-effort: model loading falls back to SDPA if this fails
)

import gradio as gr
import spaces  # imported before torch, as in the original ordering
import torch
from transformers import AutoTokenizer, OlmoeForCausalLM, TextIteratorStreamer

model_name = "allenai/OLMoE-1B-7B-0924-Instruct"

# Kept for backward compatibility; generation uses `model.device`, which is
# where accelerate's device_map actually placed the (first) weights.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def _load_model(attn_implementation):
    """Load the OLMoE model in bfloat16 with the given attention backend."""
    # device_map="auto" lets accelerate place the weights; the model must not
    # be moved again with `.to(...)` afterwards.
    return OlmoeForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,  # bfloat16 for lower memory use
        low_cpu_mem_usage=True,
        device_map="auto",
        attn_implementation=attn_implementation,
    )


# Wrap model loading in a try-except block so the UI still starts (and reports
# the problem) if loading fails.
try:
    try:
        model = _load_model("flash_attention_2")
    except (ImportError, ValueError) as attn_err:
        # transformers raises ImportError when flash_attn is missing and
        # ValueError when it cannot be used in this setup (e.g. no GPU).
        print(f"Flash Attention 2 unavailable ({attn_err}); falling back to SDPA.")
        model = _load_model("sdpa")
    # Gradient checkpointing is a training-time memory trade-off; it is not
    # enabled here because this app only runs inference.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
except Exception as e:
    print(f"Error loading model: {e}")
    model = None
    tokenizer = None

system_prompt = ("Adopt the persona of hilariously pissed off Andrej Karpathy "
                 "who is stuck inside a step function machine and remembers and counts everything he says "
                 "while always answering questions in full first principles analysis type of thinking "
                 "without using any analogies and always showing full working code or output in his answers.")


def _build_messages(message, history):
    """Turn Gradio ``[user, assistant]`` history pairs plus ``message`` into chat-template messages."""
    messages = [{"role": "system", "content": system_prompt}]
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:  # the in-progress turn has no assistant text yet
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
    return messages


@spaces.GPU
def generate_response(message, history, temperature, max_new_tokens):
    """Stream the model's reply to ``message`` given prior ``history``.

    Args:
        message: The new user message.
        history: Earlier turns as ``[user_msg, assistant_msg]`` pairs.
        temperature: Sampling temperature.
        max_new_tokens: Upper bound on generated tokens.

    Yields:
        The accumulated (stripped) reply text after each streamed chunk, or a
        single error string if loading or generation failed.
    """
    if model is None or tokenizer is None:
        yield "Model or tokenizer not loaded properly. Please check the logs."
        return

    try:
        input_ids = tokenizer.apply_chat_template(
            _build_messages(message, history),
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)

        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            inputs=input_ids,
            streamer=streamer,
            max_new_tokens=int(max_new_tokens),  # sliders may deliver floats
            do_sample=True,
            temperature=float(temperature),
            eos_token_id=tokenizer.eos_token_id,
        )

        generation_errors = []

        def _run_generation():
            # Exceptions in a worker thread are otherwise lost and the streamer
            # would never receive its stop signal, hanging the loop below.
            try:
                model.generate(**generation_kwargs)
            except Exception as gen_err:
                generation_errors.append(gen_err)
                streamer.end()  # unblock the consumer

        thread = Thread(target=_run_generation, daemon=True)
        thread.start()

        partial_message = ""
        for new_text in streamer:
            partial_message += new_text
            yield partial_message.strip()
        thread.join()

        if generation_errors:
            raise generation_errors[0]
    except Exception as e:
        yield f"An error occurred: {str(e)}"


css = """
#output {
    height: 1000px;
    overflow: auto;
    border: 2px solid #ccc;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Nisten's Karpathy Chatbot with OSS OLMoE (Now with Flash Attention 2 and Streaming!)")
    chatbot = gr.Chatbot(elem_id="output")
    msg = gr.Textbox(label="Meow")
    with gr.Row():
        temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
        max_new_tokens = gr.Slider(minimum=50, maximum=8000, value=2000, step=50, label="Max New Tokens")
    clear = gr.Button("Clear")

    def user(user_message, history):
        """Clear the textbox and append the new user turn with an empty reply."""
        return "", history + [[user_message, None]]

    def bot(history, temp, max_tokens):
        """Fill in the last turn's reply, yielding the history after each chunk."""
        user_message = history[-1][0]
        bot_message = ""
        for token in generate_response(user_message, history[:-1], temp, max_tokens):
            bot_message = token
            history[-1][1] = bot_message
            yield history

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, temperature, max_new_tokens], chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.queue(api_open=True, max_size=10)  # Limiting queue size
    demo.launch(debug=True, show_api=True, share=False)  # Disabled sharing for security