macadeliccc committed
Commit b1f30ac · 1 Parent(s): 4911f6e

Files changed (1):
  1. app.py +20 -23

app.py CHANGED
@@ -4,52 +4,49 @@ import torch
 from gradio import State
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+# Select the device (GPU if available, else CPU)
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # Load the tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained("berkeley-nest/Starling-LM-7B-alpha").to(device)
 model = AutoModelForCausalLM.from_pretrained("berkeley-nest/Starling-LM-7B-alpha").to(device)
 
-
 @spaces.GPU
 def generate_response(user_input, chat_history):
-
-    prompt = "GPT4 Correct User: " + user_input + "GPT4 Correct Assistant: "
-    if chat_history:
-        prompt = chat_history + prompt
-    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
-
-    # Ensure all tensors are moved to the model's device
-    inputs = {k: v.to(model.device) for k, v in inputs.items()}
+    try:
+        prompt = "GPT4 Correct User: " + user_input + "<|end_of_turn|>" + "GPT4 Correct Assistant: "
+        if chat_history:
+            prompt = chat_history[-1024:] + prompt  # Keep last 1024 tokens of history
+
+        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
+        inputs = {k: v.to(device) for k, v in inputs.items()}  # Move input tensors to the same device as the model
 
-    with torch.no_grad():
-        # Generate the model's output
-        output = model.generate(**inputs, max_length=512, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)
-    response = tokenizer.decode(output[0], skip_special_tokens=True)
+        with torch.no_grad():
+            output = model.generate(**inputs, max_length=512, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)
+
+        response = tokenizer.decode(output[0], skip_special_tokens=True)
+        new_history = chat_history + prompt + response
+        return response, new_history[-1024:]  # Return last 1024 tokens of history
 
-    # Update chat history
-    new_history = prompt + response
-    return response, new_history
+    except Exception as e:
+        return f"Error occurred: {e}", chat_history
 
 # Gradio Interface
 def clear_chat():
     return "", ""
 
-
 with gr.Blocks(gr.themes.Soft()) as app:
-
     with gr.Row():
         chatbot = gr.Chatbot()
-
+
     with gr.Row():
         user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
         send = gr.Button("Send")
         clear = gr.Button("Clear")
-
-
+
     chat_history = gr.State()  # Holds the chat history
 
     send.click(generate_response, inputs=[user_input, chat_history], outputs=[chatbot, chat_history])
     clear.click(clear_chat, outputs=[chatbot, chat_history])
 
-app.launch()
+app.launch()
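
Reviewer note: both sides of this diff call .to(device) on the tokenizer, but transformers tokenizers are plain Python objects rather than torch modules, so that line would raise AttributeError at startup. A minimal sketch of how the loading block would presumably need to look (same checkpoint; only the model is moved to the device):

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    # Select the device (GPU if available, else CPU)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Tokenizers run on the CPU and have no .to(); only the model needs device placement.
    tokenizer = AutoTokenizer.from_pretrained("berkeley-nest/Starling-LM-7B-alpha")
    model = AutoModelForCausalLM.from_pretrained("berkeley-nest/Starling-LM-7B-alpha").to(device)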
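A second note on the history trimming: chat_history[-1024:] slices Python string characters, not tokens, despite the "Keep last 1024 tokens" comments. A hypothetical helper (truncate_to_last_tokens is not part of this commit) that would trim by token count instead might look like:

    def truncate_to_last_tokens(text: str, tokenizer, max_tokens: int = 1024) -> str:
        # Encode without special tokens, keep the trailing max_tokens ids, decode back to text.
        ids = tokenizer.encode(text, add_special_tokens=False)
        return tokenizer.decode(ids[-max_tokens:])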
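Finally, on generation length: in transformers, max_length=512 in generate() bounds the prompt plus the continuation, while the tokenizer call above admits prompts of up to 1024 tokens, so a long history could leave no room to generate. A sketch of the usual alternative, assuming the same inputs dict as in the diff:

    # max_new_tokens bounds only the generated continuation, independent of prompt length.
    output = model.generate(**inputs, max_new_tokens=512, pad_token_id=tokenizer.eos_token_id)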