Zul001 committed
Commit 0e5b461 · verified · 1 Parent(s): 92a04df

Update app.py

Files changed (1):
  app.py  +36 -28
app.py CHANGED
@@ -29,38 +29,49 @@ function refresh() {
 
 previous_sessions = []
 
-class ChatState():
-    __START_TURN_USER__ = "<start_of_turn>user\n"
-    __START_TURN_MODEL__ = "<start_of_turn>model\n"
-    __END_TURN__ = "<end_of_turn>\n"
-
-    def __init__(self, model, system=""):
+class ChatState:
+    def __init__(self, model, tokenizer, system=""):
         self.model = model
+        self.tokenizer = tokenizer
         self.system = system
         self.history = []
 
-    def add_to_history_as_user(self, message):
-        self.history.append(self.__START_TURN_USER__ + message + self.__END_TURN__)
-
-    def add_to_history_as_model(self, message):
-        self.history.append(self.__START_TURN_MODEL__ + message)
-
-    def get_history(self):
-        return "".join([*self.history])
+    def add_to_history(self, role, message):
+        self.history.append({"role": role, "content": message})
 
     def get_full_prompt(self):
-        prompt = self.get_history() + self.__START_TURN_MODEL__
-        if len(self.system) > 0:
-            prompt = self.system + "\n" + prompt
+        prompt = ""
+        if self.system:
+            prompt += f"System: {self.system}\n\n"
+        for message in self.history:
+            prompt += f"{message['role'].capitalize()}: {message['content']}\n"
+        prompt += "Model: "
         return prompt
 
     def send_message(self, message):
-        self.add_to_history_as_user(message)
+        self.add_to_history("user", message)
         prompt = self.get_full_prompt()
-        response = self.model.generate(prompt, max_length=2048)
-        result = response.replace(prompt, "")  # Extract only the new response
-        self.add_to_history_as_model(result)
-        return result
+
+        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
+
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=512,
+                num_return_sequences=1,
+                do_sample=True,
+                temperature=0.7
+            )
+
+        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+        response = response.replace(prompt, "").strip()
+
+        self.add_to_history("model", response)
+        return response
+
+# Initialize the ChatState
+chat_state = ChatState(model, tokenizer, system="You are a helpful AI assistant.")
+
 
 def post_process_output(prompt, result):
     answer = result.strip()
@@ -102,18 +113,15 @@ def inference(prompt):
     if reset_triggered:
         return "", ""
 
-    chat_state.send_message(prompt)  # Process the user's message
-
-    # Post-process the output from the model
-    formatted_output = post_process_output(chat_state.get_full_prompt(), chat_state.get_history())
+    result = chat_state.send_message(prompt)
 
     # Apply a bit of delay for a realistic response time
     time.sleep(1)
 
-    result = formatted_output
-    sessions = add_session(chat_state.get_history())
+    sessions = add_session(prompt)
     return result, sessions
 
+
 with gr.Blocks(theme='gradio/soft', css=custom_css) as demo:
 
     gr.Markdown("<center><h1>HydroSense LLM Demo</h1></center>")
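
For reference, a minimal sketch of how the rewritten ChatState could be exercised on its own, assuming model and tokenizer are a Hugging Face causal LM and its tokenizer loaded earlier in app.py; the diff does not show the actual checkpoint, so the model ID below is a placeholder:

    # Standalone sketch; model_id is hypothetical, not the checkpoint app.py actually loads.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "google/gemma-2b-it"  # placeholder for illustration
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    chat_state = ChatState(model, tokenizer, system="You are a helpful AI assistant.")

    # On the first turn, get_full_prompt() renders
    #   "System: You are a helpful AI assistant.\n\nUser: <message>\nModel: "
    # and send_message() strips that prompt back off the decoded output.
    reply = chat_state.send_message("What does this demo do?")
    print(reply)

The design shift here is that turn bookkeeping moves from Gemma-style <start_of_turn> marker strings to a list of role/content dicts, and generation goes through the tokenizer and model.generate with max_new_tokens, so the response budget no longer shrinks as the prompt grows the way a fixed max_length does.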