File size: 2,007 Bytes
5956319
74995d7
5956319
a505b42
f316cfc
9c1d271
b1f30ac
 
ebcc5ea
f316cfc
a9a8422
4911f6e
a9a8422
a1908d6
f316cfc
 
b1f30ac
a9a8422
b1f30ac
 
 
 
 
f316cfc
b1f30ac
 
 
 
 
 
f316cfc
b1f30ac
 
f316cfc
 
 
 
 
 
 
f08d3be
b1f30ac
f08d3be
 
 
 
b1f30ac
f316cfc
 
 
 
0e4adfe
a9a8422
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import spaces
import gradio as gr
import torch
from gradio import State
from transformers import AutoTokenizer, AutoModelForCausalLM

# Pick the compute device: the default CUDA device when torch can see one,
# otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the Starling-LM-7B-alpha tokenizer and causal-LM weights from the
# Hugging Face Hub, move the weights onto the chosen device, and switch the
# model to evaluation (inference) mode.
tokenizer = AutoTokenizer.from_pretrained("berkeley-nest/Starling-LM-7B-alpha")
model = AutoModelForCausalLM.from_pretrained("berkeley-nest/Starling-LM-7B-alpha")
model = model.to(device)
model.eval()

# Starling-LM-7B-alpha chat-template markers: each turn is closed by the
# <|end_of_turn|> special token and the prompt ends with the assistant cue so
# the model continues speaking as the assistant.
_USER_TAG = "GPT4 Correct User: "
_ASSISTANT_TAG = "GPT4 Correct Assistant:"
_END_OF_TURN = "<|end_of_turn|>"

# Token budgets: the prompt (history + new message) is kept within
# _MAX_PROMPT_TOKENS; at most _MAX_NEW_TOKENS are generated per reply.
_MAX_PROMPT_TOKENS = 1024
_MAX_NEW_TOKENS = 512


def _normalize_history(chat_history):
    """Return *chat_history* as a list of ``(user, assistant)`` tuples.

    Anything that is not a list/tuple of pairs (``None`` from an unset
    ``gr.State``, ``""`` after the chat is cleared) means "no history".
    """
    if not isinstance(chat_history, (list, tuple)):
        return []
    return [(user, bot) for user, bot in chat_history]


def _build_prompt(pairs, user_input):
    """Render prior turns plus the new user message in Starling's format."""
    parts = [
        f"{_USER_TAG}{user}{_END_OF_TURN}{_ASSISTANT_TAG} {bot}{_END_OF_TURN}"
        for user, bot in pairs
    ]
    parts.append(f"{_USER_TAG}{user_input}{_END_OF_TURN}{_ASSISTANT_TAG}")
    return "".join(parts)


def _encode_prompt(pairs, user_input):
    """Tokenize the prompt, dropping the oldest turns until it fits the budget.

    If even the newest message alone is too long, its leading tokens are cut
    (left truncation) so the trailing assistant cue is always kept.
    """
    kept = list(pairs)
    while True:
        prompt = _build_prompt(kept, user_input)
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        if input_ids.shape[-1] <= _MAX_PROMPT_TOKENS or not kept:
            break
        kept.pop(0)  # forget the oldest exchange first
    return input_ids[:, -_MAX_PROMPT_TOKENS:]


@spaces.GPU
def generate_response(user_input, chat_history):
    """Generate the assistant's reply to *user_input*.

    Args:
        user_input: The new message typed by the user.
        chat_history: Prior ``(user, assistant)`` pairs kept in ``gr.State``;
            ``None`` or ``""`` (initial / cleared state) means no history.

    Returns:
        ``(chatbot_value, new_history)``. On success both are the updated list
        of ``(user, assistant)`` pairs, which ``gr.Chatbot`` renders directly.
        On failure the chatbot shows the error as the reply, while the stored
        history keeps only the earlier turns.
    """
    pairs = _normalize_history(chat_history)
    if not (user_input or "").strip():
        # Nothing to answer; leave the conversation as it is.
        return pairs, pairs

    try:
        input_ids = _encode_prompt(pairs, user_input).to(device)
        with torch.no_grad():
            output = model.generate(
                input_ids=input_ids,
                # Single unpadded sequence: every position is real input.
                attention_mask=torch.ones_like(input_ids),
                # max_new_tokens (unlike max_length) does not count the prompt.
                max_new_tokens=_MAX_NEW_TOKENS,
                pad_token_id=tokenizer.eos_token_id,
            )
        # generate() returns prompt + continuation; decode only the new tokens.
        new_tokens = output[0, input_ids.shape[-1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    except Exception as e:  # UI boundary: report instead of crashing the app
        import logging

        logging.getLogger(__name__).exception("Response generation failed")
        return pairs + [(user_input, f"Error occurred: {e}")], pairs

    new_history = pairs + [(user_input, response)]
    return new_history, new_history

# Gradio Interface
def clear_chat():
    """Reset the conversation shown in the UI and the stored history.

    Returns:
        ``([], "")``. The first value is an empty message list for the
        ``gr.Chatbot``, whose value type is a list of message pairs. The
        second is an empty history for the ``gr.State``, which
        ``generate_response`` treats as "no prior turns".
    """
    return [], ""

# Build the chat UI: message display, input row, and the hidden history state.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    with gr.Row():
        chatbot = gr.Chatbot()

    with gr.Row():
        user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
        send = gr.Button("Send")
        clear = gr.Button("Clear")

    # Conversation history carried between calls. Start it at "" (the same
    # empty value clear_chat resets it to) instead of leaving gr.State unset,
    # which would hand None to generate_response on the very first message.
    chat_history = gr.State("")

    send.click(generate_response, inputs=[user_input, chat_history], outputs=[chatbot, chat_history])
    clear.click(clear_chat, outputs=[chatbot, chat_history])

app.launch()