Spaces: Running on Zero
macadeliccc committed
Commit · b1f30ac
1 Parent(s): 4911f6e
test

app.py CHANGED
@@ -4,52 +4,49 @@ import torch
 from gradio import State
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
-
+# Select the device (GPU if available, else CPU)
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # Load the tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained("berkeley-nest/Starling-LM-7B-alpha").to(device)
 model = AutoModelForCausalLM.from_pretrained("berkeley-nest/Starling-LM-7B-alpha").to(device)
 
-
 @spaces.GPU
 def generate_response(user_input, chat_history):
-
-
-
-
-
-
-
-    inputs = {k: v.to(model.device) for k, v in inputs.items()}
+    try:
+        prompt = "GPT4 Correct User: " + user_input + "<|end_of_turn|>" + "GPT4 Correct Assistant: "
+        if chat_history:
+            prompt = chat_history[-1024:] + prompt  # Keep last 1024 tokens of history
+
+        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
+        inputs = {k: v.to(device) for k, v in inputs.items()}  # Move input tensors to the same device as the model
 
-
-
-
-
+        with torch.no_grad():
+            output = model.generate(**inputs, max_length=512, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)
+
+        response = tokenizer.decode(output[0], skip_special_tokens=True)
+        new_history = chat_history + prompt + response
+        return response, new_history[-1024:]  # Return last 1024 tokens of history
 
-
-
-    return response, new_history
+    except Exception as e:
+        return f"Error occurred: {e}", chat_history
 
 # Gradio Interface
 def clear_chat():
     return "", ""
 
-
 with gr.Blocks(gr.themes.Soft()) as app:
-
     with gr.Row():
         chatbot = gr.Chatbot()
-
+
     with gr.Row():
         user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
         send = gr.Button("Send")
         clear = gr.Button("Clear")
-
-
+
     chat_history = gr.State()  # Holds the chat history
 
     send.click(generate_response, inputs=[user_input, chat_history], outputs=[chatbot, chat_history])
     clear.click(clear_chat, outputs=[chatbot, chat_history])
 
-app.launch()
+app.launch()
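
One caveat in the context lines this commit keeps: AutoTokenizer.from_pretrained(...).to(device) raises an AttributeError, because Hugging Face tokenizers are plain Python objects, not torch modules, and have no .to() method; only the model needs to be moved. A minimal corrected loading sketch for the same checkpoint:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"

# The tokenizer runs on the CPU and has no .to() method.
tokenizer = AutoTokenizer.from_pretrained("berkeley-nest/Starling-LM-7B-alpha")
# Only the model is a torch module that can be moved to the GPU.
model = AutoModelForCausalLM.from_pretrained("berkeley-nest/Starling-LM-7B-alpha").to(device)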
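
Two details in the new generate_response are worth flagging as well. chat_history[-1024:] slices characters, not tokens, despite the comment; gr.Chatbot expects a list of (user, assistant) pairs rather than the bare string returned here; and gr.State() defaults to None, so chat_history + prompt raises a TypeError on the first message, which the broad except then swallows. A hedged alternative that keeps the history structured (the pair-based history and the max_new_tokens budget are illustrative choices, not part of this commit; max_new_tokens also avoids the original max_length=512 cap, which counts prompt tokens too):

def generate_response(user_input, history):
    # gr.State() starts out as None; normalize to the list gr.Chatbot expects.
    history = history or []

    # Rebuild the prompt from structured turns instead of slicing a string.
    prompt = ""
    for user_msg, bot_msg in history:
        prompt += f"GPT4 Correct User: {user_msg}<|end_of_turn|>GPT4 Correct Assistant: {bot_msg}<|end_of_turn|>"
    prompt += f"GPT4 Correct User: {user_input}<|end_of_turn|>GPT4 Correct Assistant: "

    # Truncation here is measured in tokens, handled by the tokenizer itself.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(device)
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=512, pad_token_id=tokenizer.eos_token_id)

    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer.decode(output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)

    history = history + [(user_input, response)]
    return history, history  # one copy feeds gr.Chatbot, the other gr.State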
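
Finally, the hand-assembled "GPT4 Correct User: ... GPT4 Correct Assistant:" string matches the OpenChat-style format Starling-LM-7B-alpha was trained on, but it is easy to drift out of sync with it (a missed <|end_of_turn|>, a stray space). If the checkpoint ships a chat template, which is an assumption worth verifying for this tokenizer, apply_chat_template renders the same prompt from structured messages:

messages = [{"role": "user", "content": "Hello, how are you?"}]

# tokenize=False returns the rendered string; add_generation_prompt=True
# appends the assistant prefix so the model replies as the assistant.
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)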