m7mdal7aj commited on
Commit
e6f809f
·
verified ·
1 Parent(s): d7b49eb

Update my_model/utilities/st_utils.py

Browse files
Files changed (1) hide show
  1. my_model/utilities/st_utils.py +33 -14
my_model/utilities/st_utils.py CHANGED
@@ -92,8 +92,8 @@ class StateManager:
92
 
93
  def check_settings_changed(self, current_selected_method, current_detection_model, current_confidence_level):
94
  return (st.session_state['model_settings']['detection_model'] != current_detection_model or
95
- st.session_state['model_settings']['confidence_level'] != current_confidence_level
96
- st.session_state['model_settings']['selected_method'] != current_selected_method)
97
 
98
  def display_model_settings(self):
99
  st.write("### Current Model Settings:")
@@ -106,12 +106,9 @@ class StateManager:
106
  st.table(df)
107
 
108
  def is_model_loaded(self):
109
- """Check if the model is loaded in the session state."""
110
  return 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None
111
 
112
-
113
  def reload_detection_model(self, detection_model, confidence_level):
114
- """Reload only the detection model with new settings."""
115
  try:
116
  free_gpu_resources()
117
  if self.is_model_loaded():
@@ -122,12 +119,34 @@ class StateManager:
122
  except Exception as e:
123
  st.error(f"Error reloading detection model: {e}")
124
 
125
-
126
-
127
-
128
-
129
-
130
-
131
-
132
-
133
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  def check_settings_changed(self, current_selected_method, current_detection_model, current_confidence_level):
94
  return (st.session_state['model_settings']['detection_model'] != current_detection_model or
95
+ st.session_state['model_settings']['confidence_level'] != current_confidence_level or
96
+ st.session_state['selected_method'] != current_selected_method)
97
 
98
  def display_model_settings(self):
99
  st.write("### Current Model Settings:")
 
106
  st.table(df)
107
 
108
  def is_model_loaded(self):
 
109
  return 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None
110
 
 
111
  def reload_detection_model(self, detection_model, confidence_level):
 
112
  try:
113
  free_gpu_resources()
114
  if self.is_model_loaded():
 
119
  except Exception as e:
120
  st.error(f"Error reloading detection model: {e}")
121
 
122
+ # New methods to be added
123
+ def process_new_image(self, image_key, image, kbvqa):
124
+ if image_key not in st.session_state['images_data']:
125
+ st.session_state['images_data'][image_key] = {
126
+ 'image': image,
127
+ 'caption': '',
128
+ 'detected_objects_str': '',
129
+ 'qa_history': [],
130
+ 'analysis_done': False
131
+ }
132
+
133
+ def analyze_image(self, image, kbvqa):
134
+ img = copy.deepcopy(image)
135
+ caption = kbvqa.get_caption(img)
136
+ image_with_boxes, detected_objects_str = kbvqa.detect_objects(img)
137
+ return caption, detected_objects_str, image_with_boxes
138
+
139
+ def add_to_qa_history(self, image_key, question, answer):
140
+ if image_key in st.session_state['images_data']:
141
+ st.session_state['images_data'][image_key]['qa_history'].append((question, answer))
142
+
143
+ def get_images_data(self):
144
+ return st.session_state['images_data']
145
+
146
+ def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
147
+ if image_key in st.session_state['images_data']:
148
+ st.session_state['images_data'][image_key].update({
149
+ 'caption': caption,
150
+ 'detected_objects_str': detected_objects_str,
151
+ 'analysis_done': analysis_done
152
+ })