macadeliccc committed
Commit 4911f6e · 1 Parent(s): f7da2ba
Files changed (1): app.py (+5 -9)
app.py CHANGED
```diff
@@ -4,20 +4,16 @@ import torch
 from gradio import State
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 # Load the tokenizer and model
-tokenizer = AutoTokenizer.from_pretrained("berkeley-nest/Starling-LM-7B-alpha")
-model = AutoModelForCausalLM.from_pretrained("berkeley-nest/Starling-LM-7B-alpha")
+tokenizer = AutoTokenizer.from_pretrained("berkeley-nest/Starling-LM-7B-alpha").to(device)
+model = AutoModelForCausalLM.from_pretrained("berkeley-nest/Starling-LM-7B-alpha").to(device)
 
-# Ensure the model is in evaluation mode
-model.eval()
-
-# Move model to GPU if available
-if torch.cuda.is_available():
-    model = model.to("cuda").half()
 
 @spaces.GPU
 def generate_response(user_input, chat_history):
+
     prompt = "GPT4 Correct User: " + user_input + "GPT4 Correct Assistant: "
     if chat_history:
         prompt = chat_history + prompt
@@ -28,7 +24,7 @@ def generate_response(user_input, chat_history):
 
     with torch.no_grad():
         # Generate the model's output
-        output = model.generate(**inputs, max_length=1024, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)
+        output = model.generate(**inputs, max_length=512, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)
     response = tokenizer.decode(output[0], skip_special_tokens=True)
 
     # Update chat history
```
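Note that the added tokenizer line will fail at startup: Hugging Face tokenizers are plain Python/Rust objects, not `nn.Module`s, so `AutoTokenizer.from_pretrained(...).to(device)` raises `AttributeError`. A minimal sketch of a device-safe setup, not part of this commit (the `model.eval()` call is carried over from the pre-commit version):

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The tokenizer stays on the CPU; it has no .to() method to call.
tokenizer = AutoTokenizer.from_pretrained("berkeley-nest/Starling-LM-7B-alpha")

# Only the model is an nn.Module and can be moved between devices.
model = AutoModelForCausalLM.from_pretrained("berkeley-nest/Starling-LM-7B-alpha").to(device)
model.eval()  # inference mode, as in the pre-commit code
```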
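The hunks never show where `inputs` is built, but whatever that line is, its tensors must end up on the same device as the model before `model.generate` runs. A hedged sketch of the call site, assuming `inputs` comes straight from the tokenizer (the variable names mirror the diff; the tokenizer call itself is not visible in this commit):

```python
# BatchEncoding.to() moves the input tensors to the model's device;
# this is valid even though the tokenizer object itself cannot be moved.
inputs = tokenizer(prompt, return_tensors="pt").to(device)

with torch.no_grad():
    output = model.generate(
        **inputs,
        max_length=512,  # this commit lowers the cap from 1024
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
    )
response = tokenizer.decode(output[0], skip_special_tokens=True)
```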
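One further observation on the untouched prompt line: Starling-LM-7B-alpha follows the OpenChat prompt format, which separates turns with `<|end_of_turn|>`. The concatenation in the diff omits that token and any separator before "GPT4 Correct Assistant: ", so the two turns run together. An assumed-correct template, worth verifying against the model card, would look like:

```python
# Hypothetical fix: insert the OpenChat end-of-turn separator between roles.
prompt = "GPT4 Correct User: " + user_input + "<|end_of_turn|>GPT4 Correct Assistant:"
```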