m7mdal7aj commited on
Commit
a650af8
·
verified ·
1 Parent(s): 8a2cc2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -113
app.py CHANGED
@@ -10,151 +10,140 @@ from my_model.object_detection import detect_and_draw_objects
10
  from my_model.captioner.image_captioning import get_caption
11
  from my_model.utilities import free_gpu_resources
12
  from my_model.KBVQA import KBVQA, prepare_kbvqa_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
 
 
 
 
14
 
15
-
16
- def answer_question(image, question, caption, detected_objects_str, model):
17
-
18
- answer = model.generate_answer(question, caption, detected_objects_str)
19
- st.image(image)
20
- st.write(caption)
21
- st.write("----------------")
22
- st.write(detected_objects_str)
23
- return answer
24
-
25
- def get_caption(image):
26
- return "Generated caption for the image"
27
-
28
- def free_gpu_resources():
29
- pass
30
-
31
- # Sample images (assuming these are paths to your sample images)
32
- sample_images = ["Files/sample1.jpg", "Files/sample2.jpg", "Files/sample3.jpg",
33
- "Files/sample4.jpg", "Files/sample5.jpg", "Files/sample6.jpg",
34
- "Files/sample7.jpg"]
35
-
36
-
37
-
38
- def analyze_image(image, model, show_processed_image=False):
39
- img = copy.deepcopy(image)
40
- caption = model.get_caption(img)
41
- image_with_boxes, detected_objects_str = model.detect_objects(img)
42
- if show_processed_image:
43
- st.image(image_with_boxes)
44
- return caption, detected_objects_str
45
-
46
- def image_qa_app(kbvqa):
47
- # Initialize session state for storing the current image and its Q&A history
48
- if 'current_image' not in st.session_state:
49
- st.session_state['current_image'] = None
50
- if 'qa_history' not in st.session_state:
51
- st.session_state['qa_history'] = []
52
- if 'analysis_done' not in st.session_state:
53
- st.session_state['analysis_done'] = False
54
- if 'answer_in_progress' not in st.session_state:
55
- st.session_state['answer_in_progress'] = False
56
-
57
- # Display sample images as clickable thumbnails
58
- st.write("Choose from sample images:")
59
- cols = st.columns(len(sample_images))
60
- for idx, sample_image_path in enumerate(sample_images):
61
- with cols[idx]:
62
- image = Image.open(sample_image_path)
63
- st.image(image, use_column_width=True)
64
- if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx}'):
65
- st.session_state['current_image'] = image
66
- st.session_state['qa_history'] = []
67
- st.session_state['analysis_done'] = False
68
- st.session_state['answer_in_progress'] = False
69
-
70
- # Image uploader
71
  uploaded_image = st.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
72
  if uploaded_image is not None:
73
- image = Image.open(uploaded_image)
74
- st.session_state['current_image'] = image
75
  st.session_state['qa_history'] = []
76
  st.session_state['analysis_done'] = False
77
  st.session_state['answer_in_progress'] = False
78
 
79
  if st.session_state.get('current_image') and not st.session_state.get('analysis_done', False):
80
  if st.button('Analyze Image'):
81
- caption, detected_objects_str = analyze_image(st.session_state['current_image'], kbvqa)
82
  st.session_state['caption'] = caption
83
  st.session_state['detected_objects_str'] = detected_objects_str
84
  st.session_state['analysis_done'] = True
85
 
86
- # Get Answer button
87
  if st.session_state.get('analysis_done', False):
88
  question = st.text_input("Ask a question about this image:")
89
  if st.button('Get Answer'):
90
- answer = answer_question(st.session_state['current_image'], question, st.session_state.get('caption', ''), st.session_state.get('detected_objects_str', ''), kbvqa)
 
 
 
 
 
 
91
  st.session_state['qa_history'].append((question, answer))
92
 
93
- # Display all Q&A
94
  for q, a in st.session_state.get('qa_history', []):
95
  st.text(f"Q: {q}\nA: {a}\n")
96
 
97
- # Reset the answer_in_progress flag after displaying the answer
98
- if st.session_state['answer_in_progress']:
99
- st.session_state['answer_in_progress'] = False
100
-
101
  def run_inference():
 
102
  st.title("Run Inference")
103
 
104
  method = st.selectbox(
105
  "Choose a method:",
106
  ["Fine-Tuned Model", "In-Context Learning (n-shots)"],
107
- index=0 # Default to the first option
108
  )
109
 
110
- detection_model = st.selectbox(
111
- "Choose a model for object detection:",
112
- ["yolov5", "detic"],
113
- index=0 # Default to the first option
114
- )
115
-
116
- # Initialize session state for the model
117
  if method == "Fine-Tuned Model":
118
- if 'kbvqa' not in st.session_state:
119
- st.session_state['kbvqa'] = None
120
-
121
- # Button to load KBVQA models
122
- if st.button('Load KBVQA Model'):
123
- if st.session_state['kbvqa'] is not None:
124
- st.write("Model already loaded.")
125
- else:
126
- # Call the function to load models and show progress
127
- st.session_state['kbvqa'] = prepare_kbvqa_model(detection_model)
128
-
129
- if st.session_state['kbvqa']:
130
- st.write("Model is ready for inference.")
131
- # Set default confidence based on the selected model
132
- default_confidence = 0.2 if detection_model == "yolov5" else 0.4
133
- # Slider for confidence level
134
- confidence_level = st.slider(
135
- "Select Detection Confidence Level",
136
- min_value=0.1,
137
- max_value=0.9,
138
- value=default_confidence,
139
- step=0.1
140
- )
141
- st.session_state['kbvqa'].detection_confidence = confidence_level
142
 
143
- if st.session_state['kbvqa']:
144
- image_qa_app(st.session_state['kbvqa'])
145
 
146
- else:
147
- st.write(f'{method} model is not ready for inference yet')
148
-
149
- # Main function
150
  def main():
151
  st.sidebar.title("Navigation")
152
  selection = st.sidebar.radio("Go to", ["Home", "Dataset Analysis", "Evaluation Results", "Run Inference", "Dissertation Report"])
153
 
154
  if selection == "Home":
155
- st.title("MultiModal Learning for Knowledg-Based Visual Question Answering")
156
  st.write("Home page content goes here...")
157
-
158
  elif selection == "Dissertation Report":
159
  st.title("Dissertation Report")
160
  st.write("Click the link below to view the PDF.")
@@ -166,22 +155,29 @@ def main():
166
  mime="application/octet-stream"
167
  )
168
 
169
-
170
  elif selection == "Evaluation Results":
171
  st.title("Evaluation Results")
172
  st.write("This is a Place Holder until the contents are uploaded.")
173
 
174
-
175
  elif selection == "Dataset Analysis":
176
  st.title("OK-VQA Dataset Analysis")
177
  st.write("This is a Place Holder until the contents are uploaded.")
178
 
179
-
180
  elif selection == "Run Inference":
181
  run_inference()
182
-
183
- elif selection == "Object Detection":
184
- run_object_detection()
 
 
 
185
 
186
  if __name__ == "__main__":
187
- main()
 
 
 
 
 
 
 
 
10
  from my_model.captioner.image_captioning import get_caption
11
  from my_model.utilities import free_gpu_resources
12
  from my_model.KBVQA import KBVQA, prepare_kbvqa_model
13
+ import my_model.utilities.st_config as st_config
14
+
15
+
16
+
17
+
18
+ class ImageHandler:
19
+ @staticmethod
20
+ def analyze_image(image, model, show_processed_image=False):
21
+ img = copy.deepcopy(image)
22
+ caption = model.get_caption(img)
23
+ image_with_boxes, detected_objects_str = model.detect_objects(img)
24
+ if show_processed_image:
25
+ st.image(image_with_boxes)
26
+ return caption, detected_objects_str
27
+
28
+ @staticmethod
29
+ def free_gpu_resources():
30
+ # Implementation for freeing GPU resources
31
+ free_gpu_resources()
32
+
33
+ class QuestionAnswering:
34
+ @staticmethod
35
+ def answer_question(image, question, caption, detected_objects_str, model):
36
+ answer = model.generate_answer(question, caption, detected_objects_str)
37
+ st.image(image)
38
+ st.write(caption)
39
+ st.write("----------------")
40
+ st.write(detected_objects_str)
41
+ return answer
42
+
43
+ class UIComponents:
44
+ @staticmethod
45
+ def display_image_selection(sample_images):
46
+ cols = st.columns(len(sample_images))
47
+ for idx, sample_image_path in enumerate(sample_images):
48
+ with cols[idx]:
49
+ image = Image.open(sample_image_path)
50
+ st.image(image, use_column_width=True)
51
+ if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx}'):
52
+ st.session_state['current_image'] = image
53
+ st.session_state['qa_history'] = []
54
+ st.session_state['analysis_done'] = False
55
+ st.session_state['answer_in_progress'] = False
56
+
57
+ def load_kbvqa_model(detection_model):
58
+ """Load KBVQA Model based on the selected detection model."""
59
+ if st.session_state.get('kbvqa') is not None:
60
+ st.write("Model already loaded.")
61
+ else:
62
+ st.session_state['kbvqa'] = prepare_kbvqa_model(detection_model)
63
+ if st.session_state['kbvqa']:
64
+ st.write("Model is ready for inference.")
65
+ return True
66
+ return False
67
+
68
+ def set_model_confidence(detection_model):
69
+ """Set the confidence level for the detection model."""
70
+ default_confidence = 0.2 if detection_model == "yolov5" else 0.4
71
+ confidence_level = st.slider(
72
+ "Select Detection Confidence Level",
73
+ min_value=0.1,
74
+ max_value=0.9,
75
+ value=default_confidence,
76
+ step=0.1
77
+ )
78
+ st.session_state['kbvqa'].detection_confidence = confidence_level
79
 
80
+ def image_qa_app(kbvqa_model):
81
+ """Streamlit app interface for image QA."""
82
+ sample_images = st_config.SAMPLE_IMAGES
83
+ UIComponents.display_image_selection(sample_images)
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  uploaded_image = st.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
86
  if uploaded_image is not None:
87
+ st.session_state['current_image'] = Image.open(uploaded_image)
 
88
  st.session_state['qa_history'] = []
89
  st.session_state['analysis_done'] = False
90
  st.session_state['answer_in_progress'] = False
91
 
92
  if st.session_state.get('current_image') and not st.session_state.get('analysis_done', False):
93
  if st.button('Analyze Image'):
94
+ caption, detected_objects_str = ImageHandler.analyze_image(st.session_state['current_image'], kbvqa_model)
95
  st.session_state['caption'] = caption
96
  st.session_state['detected_objects_str'] = detected_objects_str
97
  st.session_state['analysis_done'] = True
98
 
 
99
  if st.session_state.get('analysis_done', False):
100
  question = st.text_input("Ask a question about this image:")
101
  if st.button('Get Answer'):
102
+ answer = QuestionAnswering.answer_question(
103
+ st.session_state['current_image'],
104
+ question,
105
+ st.session_state.get('caption', ''),
106
+ st.session_state.get('detected_objects_str', ''),
107
+ kbvqa_model
108
+ )
109
  st.session_state['qa_history'].append((question, answer))
110
 
 
111
  for q, a in st.session_state.get('qa_history', []):
112
  st.text(f"Q: {q}\nA: {a}\n")
113
 
 
 
 
 
114
  def run_inference():
115
+ """Main function to run inference based on the selected method."""
116
  st.title("Run Inference")
117
 
118
  method = st.selectbox(
119
  "Choose a method:",
120
  ["Fine-Tuned Model", "In-Context Learning (n-shots)"],
121
+ index=0
122
  )
123
 
 
 
 
 
 
 
 
124
  if method == "Fine-Tuned Model":
125
+ detection_model = st.selectbox(
126
+ "Choose a model for object detection:",
127
+ ["yolov5", "detic"],
128
+ index=0
129
+ )
130
+
131
+ if 'kbvqa' not in st.session_state or st.session_state['detection_model'] != detection_model:
132
+ st.session_state['detection_model'] = detection_model
133
+ if load_kbvqa_model(detection_model):
134
+ set_model_confidence(detection_model)
135
+ image_qa_app(st.session_state['kbvqa'])
136
+
 
 
 
 
 
 
 
 
 
 
 
 
137
 
 
 
138
 
 
 
 
 
139
  def main():
140
  st.sidebar.title("Navigation")
141
  selection = st.sidebar.radio("Go to", ["Home", "Dataset Analysis", "Evaluation Results", "Run Inference", "Dissertation Report"])
142
 
143
  if selection == "Home":
144
+ st.title("MultiModal Learning for Knowledge-Based Visual Question Answering")
145
  st.write("Home page content goes here...")
146
+
147
  elif selection == "Dissertation Report":
148
  st.title("Dissertation Report")
149
  st.write("Click the link below to view the PDF.")
 
155
  mime="application/octet-stream"
156
  )
157
 
 
158
  elif selection == "Evaluation Results":
159
  st.title("Evaluation Results")
160
  st.write("This is a Place Holder until the contents are uploaded.")
161
 
 
162
  elif selection == "Dataset Analysis":
163
  st.title("OK-VQA Dataset Analysis")
164
  st.write("This is a Place Holder until the contents are uploaded.")
165
 
 
166
  elif selection == "Run Inference":
167
  run_inference()
168
+
169
+
170
+
171
+
172
+
173
+
174
 
175
  if __name__ == "__main__":
176
+ main()
177
+
178
+
179
+
180
+
181
+
182
+
183
+