Zul001 commited on
Commit
7d3fe61
·
verified ·
1 Parent(s): 0e5b461

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -59
app.py CHANGED
@@ -1,13 +1,17 @@
 
1
  import gradio as gr
2
  import tensorflow.keras as keras
3
  import time
4
  import keras_nlp
5
  import os
6
 
 
7
  model_path = "Zul001/HydroSense_Gemma_Finetuned_Model"
8
  gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(f"hf://{model_path}")
9
 
 
10
  reset_triggered = False
 
11
  custom_css = """
12
  @import url('https://fonts.googleapis.com/css2?family=Edu+AU+VIC+WA+NT+Dots:[email protected]&family=Give+You+Glory&family=Sofia&family=Sunshiney&family=Vujahday+Script&display=swap');
13
  .gradio-container, .gradio-container * {
@@ -27,64 +31,27 @@ function refresh() {
27
  }
28
  """
29
 
30
- previous_sessions = []
31
-
32
- class ChatState:
33
- def __init__(self, model, tokenizer, system=""):
34
- self.model = model
35
- self.tokenizer = tokenizer
36
- self.system = system
37
- self.history = []
38
-
39
- def add_to_history(self, role, message):
40
- self.history.append({"role": role, "content": message})
41
-
42
- def get_full_prompt(self):
43
- prompt = ""
44
- if self.system:
45
- prompt += f"System: {self.system}\n\n"
46
- for message in self.history:
47
- prompt += f"{message['role'].capitalize()}: {message['content']}\n"
48
- prompt += "Model: "
49
- return prompt
50
-
51
- def send_message(self, message):
52
- self.add_to_history("user", message)
53
- prompt = self.get_full_prompt()
54
-
55
- inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
56
-
57
- with torch.no_grad():
58
- outputs = self.model.generate(
59
- **inputs,
60
- max_new_tokens=512,
61
- num_return_sequences=1,
62
- do_sample=True,
63
- temperature=0.7
64
- )
65
-
66
- response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
67
- response = response.replace(prompt, "").strip()
68
-
69
- self.add_to_history("model", response)
70
- return response
71
-
72
- # Initialize the ChatState
73
- chat_state = ChatState(model, tokenizer, system="You are a helpful AI assistant.")
74
 
 
75
 
76
  def post_process_output(prompt, result):
 
77
  answer = result.strip()
78
  if answer.startswith(prompt):
79
  answer = answer[len(prompt):].strip()
80
 
 
81
  answer = answer.lstrip(':')
 
 
82
  answer = answer.capitalize()
83
 
 
84
  if not answer.endswith('.'):
85
  answer += '.'
86
 
87
  return f"{answer}"
 
88
 
89
  def add_session(prompt):
90
  global previous_sessions
@@ -95,33 +62,60 @@ def add_session(prompt):
95
 
96
  return "\n".join(previous_sessions) # Return only the session logs as a string
97
 
98
- def clear_sessions():
99
- global previous_sessions
100
- previous_sessions.clear()
101
- return "\n".join(previous_sessions)
102
 
103
- def clear_fields():
104
- global reset_triggered
105
- reset_triggered = True
106
- return "", "" # Return empty strings to clear the prompt and output fields
107
-
108
- # Initialize the ChatState
109
- chat_state = ChatState(gemma_lm)
110
 
111
  def inference(prompt):
112
  global reset_triggered
113
  if reset_triggered:
 
114
  return "", ""
115
 
116
- result = chat_state.send_message(prompt)
 
117
 
118
- # Apply a bit of delay for a realistic response time
119
- time.sleep(1)
 
120
 
121
- sessions = add_session(prompt)
 
 
 
122
  return result, sessions
123
 
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  with gr.Blocks(theme='gradio/soft', css=custom_css) as demo:
126
 
127
  gr.Markdown("<center><h1>HydroSense LLM Demo</h1></center>")
@@ -141,6 +135,9 @@ with gr.Blocks(theme='gradio/soft', css=custom_css) as demo:
141
  generate_btn = gr.Button("Generate Answer", variant="primary", size="sm")
142
  reset_btn = gr.Button("Clear Content", variant="secondary", size="sm", elem_id="primary")
143
 
 
 
 
144
  generate_btn.click(
145
  fn=inference,
146
  inputs=[prompt],
@@ -159,12 +156,15 @@ with gr.Blocks(theme='gradio/soft', css=custom_css) as demo:
159
  outputs=[prompt, output]
160
  )
161
 
 
 
162
  add_button.click(
163
  fn=clear_fields, # Only call the clear_fields function
164
  inputs=None, # No inputs needed
165
  outputs=[prompt, output] # Clear the prompt and output fields
166
  )
167
 
 
168
  clear_session.click(
169
  fn=clear_sessions,
170
  inputs=None,
 
1
+ #importing libraries
2
  import gradio as gr
3
  import tensorflow.keras as keras
4
  import time
5
  import keras_nlp
6
  import os
7
 
8
+
9
  model_path = "Zul001/HydroSense_Gemma_Finetuned_Model"
10
  gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(f"hf://{model_path}")
11
 
12
+
13
  reset_triggered = False
14
+
15
  custom_css = """
16
  @import url('https://fonts.googleapis.com/css2?family=Edu+AU+VIC+WA+NT+Dots:[email protected]&family=Give+You+Glory&family=Sofia&family=Sunshiney&family=Vujahday+Script&display=swap');
17
  .gradio-container, .gradio-container * {
 
31
  }
32
  """
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ previous_sessions = []
36
 
37
  def post_process_output(prompt, result):
38
+ # Remove the prompt if it's repeated at the beginning of the answer
39
  answer = result.strip()
40
  if answer.startswith(prompt):
41
  answer = answer[len(prompt):].strip()
42
 
43
+ # Remove any leading colons or whitespace
44
  answer = answer.lstrip(':')
45
+
46
+ # Ensure the answer starts with a capital letter
47
  answer = answer.capitalize()
48
 
49
+ # Ensure the answer ends with a period if it doesn't already
50
  if not answer.endswith('.'):
51
  answer += '.'
52
 
53
  return f"{answer}"
54
+
55
 
56
  def add_session(prompt):
57
  global previous_sessions
 
62
 
63
  return "\n".join(previous_sessions) # Return only the session logs as a string
64
 
 
 
 
 
65
 
 
 
 
 
 
 
 
66
 
67
  def inference(prompt):
68
  global reset_triggered
69
  if reset_triggered:
70
+ #do nothing
71
  return "", ""
72
 
73
+ prompt_text = prompt
74
+ generated_text = gemma_lm.generate(prompt_text)
75
 
76
+ #Apply post-processing
77
+ formatted_output = post_process_output(prompt_text, generated_text)
78
+ print(formatted_output)
79
 
80
+ #adding a bit of delay
81
+ time.sleep(1)
82
+ result = formatted_output
83
+ sessions = add_session(prompt_text)
84
  return result, sessions
85
 
86
 
87
+ # def inference(prompt):
88
+
89
+ # time.sleep(1)
90
+ # result = "Your Result"
91
+ # # sessions = add_session(prompt)
92
+ # return result
93
+
94
+
95
+ # def remember(prompt, result):
96
+ # global memory
97
+ # # Store the session as a dictionary
98
+ # session = {'prompt': prompt, 'result': result}
99
+ # memory.append(session)
100
+
101
+ # # Update previous_sessions for display
102
+ # session_display = [f"Q: {s['prompt']} \nA: {s['result']}" for s in memory]
103
+
104
+ # return "\n\n".join(session_display) # Return formatted sessions as a string
105
+
106
+
107
+
108
+ def clear_sessions():
109
+ global previous_sessions
110
+ previous_sessions.clear()
111
+ return "\n".join(previous_sessions)
112
+
113
+ def clear_fields():
114
+ global reset_triggered
115
+ reset_triggered = True
116
+ return "", "" # Return empty strings to clear the prompt and output fields
117
+
118
+
119
  with gr.Blocks(theme='gradio/soft', css=custom_css) as demo:
120
 
121
  gr.Markdown("<center><h1>HydroSense LLM Demo</h1></center>")
 
135
  generate_btn = gr.Button("Generate Answer", variant="primary", size="sm")
136
  reset_btn = gr.Button("Clear Content", variant="secondary", size="sm", elem_id="primary")
137
 
138
+
139
+
140
+
141
  generate_btn.click(
142
  fn=inference,
143
  inputs=[prompt],
 
156
  outputs=[prompt, output]
157
  )
158
 
159
+
160
+ # Button to clear the prompt and output fields
161
  add_button.click(
162
  fn=clear_fields, # Only call the clear_fields function
163
  inputs=None, # No inputs needed
164
  outputs=[prompt, output] # Clear the prompt and output fields
165
  )
166
 
167
+
168
  clear_session.click(
169
  fn=clear_sessions,
170
  inputs=None,