nisten committed
Commit be3574c · verified · 1 Parent(s): 0ff1cd2

Update app.py

Files changed (1): app.py +5 -5
app.py CHANGED
@@ -1,19 +1,19 @@
 import gradio as gr
 import spaces
-from transformers import OlmoeForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 import subprocess
 import sys
 
-# Force install the specific transformers version from the GitHub PR
-subprocess.check_call([sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-deps", "git+https://github.com/Muennighoff/transformers.git@olmoe"])
+# Force install the latest transformers version and flash attention
+subprocess.check_call([sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-deps", "transformers", "flash-attn"])
 
 model_name = "allenai/OLMoE-1B-7B-0924"
 
 # Wrap model loading in a try-except block to handle potential errors
 try:
     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-    model = OlmoeForCausalLM.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.float16).to(DEVICE)
+    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, torch_dtype="auto", _attn_implementation="flash_attention_2").to(DEVICE)
     tokenizer = AutoTokenizer.from_pretrained(model_name)
 except Exception as e:
     print(f"Error loading model: {e}")
@@ -58,7 +58,7 @@ css = """
 with gr.Blocks(css=css) as demo:
     gr.Markdown("# Nisten's Karpathy Chatbot with OSS olMoE")
     chatbot = gr.Chatbot(elem_id="output")
-    msg = gr.Textbox(label="Your message")
+    msg = gr.Textbox(label="Your prompt")
     with gr.Row():
         temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
         max_new_tokens = gr.Slider(minimum=50, maximum=4000, value=1000, step=50, label="Max New Tokens")
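
A caveat on the new install step: flash-attn compiles CUDA extensions and imports torch in its build script, so the upstream project documents installing it with --no-build-isolation; the single --force-reinstall --no-deps call in the diff can fail on images without a CUDA toolchain. A more defensive sketch of the same step (the try/except fallback is an assumption, not part of app.py):

import subprocess
import sys

# Upgrade transformers first, then attempt flash-attn separately so a failed
# CUDA build does not take the whole install step down with it.
subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "transformers"])
try:
    # --no-build-isolation lets flash-attn's setup.py see the already-installed torch.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"])
except subprocess.CalledProcessError:
    print("flash-attn build failed; the model will fall back to default attention")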
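On the loading line itself: _attn_implementation is a private, underscore-prefixed kwarg; recent transformers releases (>= 4.36) expose the same switch as the public attn_implementation argument to from_pretrained. A minimal sketch of the equivalent public-API load:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "allenai/OLMoE-1B-7B-0924"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# torch_dtype="auto" keeps the checkpoint's native dtype; flash_attention_2
# requires a CUDA device and a working flash-attn install.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype="auto",
    attn_implementation="flash_attention_2",
).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(model_name)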
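The second hunk only renames the Textbox label; the wiring that feeds it into the model sits outside this diff. For orientation, a hypothetical version of that wiring, assuming a respond callback of this shape (the function name and decoding details are illustrative, not taken from app.py):

def respond(message, history, temperature, max_new_tokens):
    # Tokenize the prompt, generate, then strip the prompt tokens from the output.
    inputs = tokenizer(message, return_tensors="pt").to(DEVICE)
    output = model.generate(
        **inputs,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=int(max_new_tokens),
    )
    reply = tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    return history + [(message, reply)]

msg.submit(respond, [msg, chatbot, temperature, max_new_tokens], chatbot)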