Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,007 Bytes
5956319 74995d7 5956319 a505b42 f316cfc 9c1d271 b1f30ac ebcc5ea f316cfc a9a8422 4911f6e a9a8422 a1908d6 f316cfc b1f30ac a9a8422 b1f30ac f316cfc b1f30ac f316cfc b1f30ac f316cfc f08d3be b1f30ac f08d3be b1f30ac f316cfc 0e4adfe a9a8422 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import spaces
import gradio as gr
import torch
from gradio import State
from transformers import AutoTokenizer, AutoModelForCausalLM
# Choose the compute device: prefer CUDA when a GPU is present, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the chat model and its tokenizer once at import time.
_MODEL_ID = "berkeley-nest/Starling-LM-7B-alpha"
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(_MODEL_ID).to(device)

# Inference only: disable dropout and other training-time behavior.
model.eval()
@spaces.GPU
def generate_response(user_input, chat_history):
    """Generate a reply from the Starling model and extend the chat history.

    Parameters
    ----------
    user_input : str
        The user's latest message.
    chat_history : str or None
        Accumulated prompt/response text from prior turns. ``gr.State()``
        supplies ``None`` on the first call.

    Returns
    -------
    tuple[str, str]
        ``(response, updated_history)`` where the history is truncated to its
        last 1024 characters. On failure, an error message and the original
        history are returned instead.
    """
    try:
        # gr.State() starts as None; normalize so string concatenation
        # below cannot raise TypeError on the very first message.
        chat_history = chat_history or ""
        # Starling-LM chat template: the user turn must be closed with
        # <|end_of_turn|> before the assistant turn begins.
        turn = "GPT4 Correct User: " + user_input + "<|end_of_turn|>GPT4 Correct Assistant: "
        # NOTE: this keeps the last 1024 *characters* of history, not tokens.
        prompt = (chat_history[-1024:] + turn) if chat_history else turn
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
        inputs = {k: v.to(device) for k, v in inputs.items()}  # match the model's device
        with torch.no_grad():
            # max_new_tokens bounds only the continuation; the original
            # max_length=512 counted the prompt too and could be smaller
            # than the (up to 1024-token) input, making generate() fail.
            output = model.generate(
                **inputs,
                max_new_tokens=512,
                num_return_sequences=1,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        new_tokens = output[0][inputs["input_ids"].shape[1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)
        # Append only this turn; the original added `prompt` (which already
        # embedded the history) on top of `chat_history`, duplicating it.
        new_history = chat_history + turn + response
        return response, new_history[-1024:]
    except Exception as e:
        return f"Error occurred: {e}", chat_history
# Gradio Interface
def clear_chat():
    """Return empty strings to wipe both the chatbot display and the state."""
    return ("", "")
with gr.Blocks(gr.themes.Soft()) as app:
    # Conversation display.
    with gr.Row():
        chatbot = gr.Chatbot()
    # Message input plus the two action buttons.
    with gr.Row():
        user_input = gr.Textbox(
            label="Your Message",
            placeholder="Type your message here...",
        )
        send = gr.Button("Send")
        clear = gr.Button("Clear")
    # Server-side value holding the running conversation history.
    chat_history = gr.State()

    # Wire the buttons to their callbacks.
    send.click(
        generate_response,
        inputs=[user_input, chat_history],
        outputs=[chatbot, chat_history],
    )
    clear.click(clear_chat, outputs=[chatbot, chat_history])

app.launch()
|