# Hugging Face Space — page status at capture time: "Runtime error"
import os

import gradio as gr
import torch
from gradio_client import Client
from transformers import OlmoeForCausalLM, AutoTokenizer

# Flag ZeroGPU mode for the Hugging Face Spaces runtime.
os.environ["ZEROGPU"] = "1"

# Prefer CUDA when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the OLMoE model and its tokenizer once at startup.
model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924").to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")
# Persona instructions prepended to every user message (the base model has
# no dedicated system-prompt channel).
system_prompt = ("Adopt the persona of hilariously pissed off Andrej Karpathy "
                 "who is stuck inside a step function machine and remembers and counts everything he says "
                 "while always answering questions in full first principles analysis type of thinking "
                 "without using any analogies and always showing full working code or output in his answers.")
# Chat-callback used by gr.ChatInterface.
def generate_text(prompt, history=None):
    """Generate a persona-flavored completion for *prompt*.

    Args:
        prompt: The user's message text.
        history: Chat history supplied by ``gr.ChatInterface`` (which calls
            ``fn(message, history)``); accepted for compatibility, unused.

    Returns:
        The decoded model continuation as a string.
    """
    # Fold the persona instructions into the input so the model actually
    # sees them; ChatInterface has no system_prompt parameter.
    full_prompt = f"{system_prompt}\n\n{prompt}"
    inputs = tokenizer(full_prompt, return_tensors="pt")
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    # max_new_tokens bounds only the generated continuation; the original
    # max_length=64 counted prompt tokens too and could leave no room to
    # generate once the prompt exceeded 64 tokens.
    out = model.generate(**inputs, max_new_tokens=64)
    return tokenizer.decode(out[0], skip_special_tokens=True)
# Create the per-session gradio_client, bound to the visitor's GPU quota.
def set_client_for_session(request: gr.Request):
    """Build a ``gradio_client.Client`` for this visitor's session.

    The ``x-ip-token`` header is injected by the Hugging Face Spaces proxy
    and identifies the visitor for ZeroGPU quota accounting. When absent
    (e.g. a local run), the client is created without quota headers rather
    than raising ``KeyError``.
    """
    x_ip_token = request.headers.get("x-ip-token")
    headers = {"X-IP-Token": x_ip_token} if x_ip_token else None
    return Client("gradio/text-to-image", headers=headers)
# Assemble the UI: a chat interface plus per-session client state.
with gr.Blocks() as demo:
    # Per-session holder for the gradio_client.Client created on page load.
    client = gr.State()
    # NOTE: gr.ChatInterface accepts no `system_prompt` keyword — passing
    # one raises TypeError at startup. Persona steering must happen inside
    # the fn itself.
    iface = gr.ChatInterface(fn=generate_text)
    # On page load, create the visitor-scoped client and stash it in state.
    demo.load(set_client_for_session, None, client)

# Launch the root Blocks app (iface is a child of demo, so launching demo
# serves the chat UI; calling iface.launch() inside the context is wrong).
demo.launch()