Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,19 +1,19 @@
|
|
1 |
import gradio as gr
|
2 |
import spaces
|
3 |
-
from transformers import
|
4 |
import torch
|
5 |
import subprocess
|
6 |
import sys
|
7 |
|
8 |
-
# Force install the
|
9 |
-
subprocess.check_call([sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-deps", "
|
10 |
|
11 |
model_name = "allenai/OLMoE-1B-7B-0924"
|
12 |
|
13 |
# Wrap model loading in a try-except block to handle potential errors
|
14 |
try:
|
15 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
16 |
-
model =
|
17 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
18 |
except Exception as e:
|
19 |
print(f"Error loading model: {e}")
|
@@ -58,7 +58,7 @@ css = """
|
|
58 |
with gr.Blocks(css=css) as demo:
|
59 |
gr.Markdown("# Nisten's Karpathy Chatbot with OSS olMoE")
|
60 |
chatbot = gr.Chatbot(elem_id="output")
|
61 |
-
msg = gr.Textbox(label="Your
|
62 |
with gr.Row():
|
63 |
temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
|
64 |
max_new_tokens = gr.Slider(minimum=50, maximum=4000, value=1000, step=50, label="Max New Tokens")
|
|
|
1 |
import gradio as gr
|
2 |
import spaces
|
3 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
4 |
import torch
|
5 |
import subprocess
|
6 |
import sys
|
7 |
|
8 |
+
# Force install the latest transformers version and flash attention
|
9 |
+
subprocess.check_call([sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-deps", "transformers", "flash-attn"])
|
10 |
|
11 |
model_name = "allenai/OLMoE-1B-7B-0924"
|
12 |
|
13 |
# Wrap model loading in a try-except block to handle potential errors
|
14 |
try:
|
15 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
16 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, torch_dtype="auto", _attn_implementation="flash_attention_2").to(DEVICE)
|
17 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
18 |
except Exception as e:
|
19 |
print(f"Error loading model: {e}")
|
|
|
58 |
with gr.Blocks(css=css) as demo:
|
59 |
gr.Markdown("# Nisten's Karpathy Chatbot with OSS olMoE")
|
60 |
chatbot = gr.Chatbot(elem_id="output")
|
61 |
+
msg = gr.Textbox(label="Your prompt")
|
62 |
with gr.Row():
|
63 |
temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
|
64 |
max_new_tokens = gr.Slider(minimum=50, maximum=4000, value=1000, step=50, label="Max New Tokens")
|