m7mdal7aj commited on
Commit
fc1b9c5
·
verified ·
1 Parent(s): f20e4a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -106
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import streamlit as st
2
  import torch
3
  import bitsandbytes
@@ -10,151 +12,159 @@ from my_model.object_detection import detect_and_draw_objects
10
  from my_model.captioner.image_captioning import get_caption
11
  from my_model.gen_utilities import free_gpu_resources
12
  from my_model.KBVQA import KBVQA, prepare_kbvqa_model
13
- import my_model.utilities.st_config as st_config
14
 
15
 
16
 
 
17
 
 
 
18
 
19
- def analyze_image(image, model, show_processed_image=False):
20
- img = copy.deepcopy(image)
21
- caption = model.get_caption(img)
22
- image_with_boxes, detected_objects_str = model.detect_objects(img)
23
- if show_processed_image:
24
- st.image(image_with_boxes, use_column_width=True)
25
- return caption, detected_objects_str
26
 
 
 
27
 
 
 
 
 
28
 
29
- class QuestionAnswering:
30
- @staticmethod
31
- def answer_question(image, question, caption, detected_objects_str, model):
32
- answer = model.generate_answer(question, caption, detected_objects_str)
33
- st.image(image, use_column_width=True)
34
- st.write(caption)
35
- st.write("----------------")
36
- st.write(detected_objects_str)
37
- return answer
38
 
39
 
40
- def load_kbvqa_model(detection_model):
41
- """Load KBVQA Model based on the selected detection model."""
42
- if st.session_state.get('kbvqa') is not None:
43
- st.write("Model already loaded.")
44
- else:
45
- st.session_state['kbvqa'] = prepare_kbvqa_model(detection_model)
46
- if st.session_state['kbvqa']:
47
- st.write("Model is ready for inference.")
48
- return True
49
- return False
50
-
51
- def set_model_confidence(detection_model):
52
- """Set the confidence level for the detection model."""
53
- default_confidence = 0.2 if detection_model == "yolov5" else 0.4
54
- confidence_level = st.slider(
55
- "Select Detection Confidence Level",
56
- min_value=0.1,
57
- max_value=0.9,
58
- value=default_confidence,
59
- step=0.1
60
- )
61
- st.session_state['kbvqa'].detection_confidence = confidence_level
62
 
63
- def image_qa_app(kbvqa_model):
64
- """Streamlit app interface for image QA."""
65
- sample_images = st_config.SAMPLE_IMAGES
 
 
 
 
 
 
 
66
 
 
 
67
  cols = st.columns(len(sample_images))
68
  for idx, sample_image_path in enumerate(sample_images):
69
  with cols[idx]:
70
  image = Image.open(sample_image_path)
 
71
  if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx}'):
72
- st.session_state['current_image'] = sample_image_path
73
  st.session_state['qa_history'] = []
74
  st.session_state['analysis_done'] = False
75
  st.session_state['answer_in_progress'] = False
76
 
 
77
  uploaded_image = st.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
78
- st.image(uploaded_image, use_column_width=True)
79
  if uploaded_image is not None:
80
- st.session_state['current_image'] = uploaded_image
 
81
  st.session_state['qa_history'] = []
82
  st.session_state['analysis_done'] = False
83
  st.session_state['answer_in_progress'] = False
84
 
85
- # Display the image if it's in the session state
86
- if 'current_image' in st.session_state and st.session_state['current_image'] is not None:
87
- if isinstance(st.session_state['current_image'], str):
88
- # If it's a file path from sample images
89
- image_to_display = Image.open(st.session_state['current_image'])
90
- else:
91
- # If it's an uploaded file
92
- image_to_display = Image.open(st.session_state['current_image'])
93
- st.image(image_to_display, use_column_width=True)
94
- else:
95
- st.write("No image selected or uploaded.")
96
-
97
-
98
-
99
-
100
- if st.session_state.get('current_image') and not st.session_state.get('analysis_done', False):
101
  if st.button('Analyze Image'):
102
- caption, detected_objects_str = analyze_image(st.session_state['current_image'], kbvqa_model)
103
- st.session_state['caption'] = caption
104
- st.session_state['detected_objects_str'] = detected_objects_str
105
  st.session_state['analysis_done'] = True
 
 
 
 
 
106
 
107
- if st.session_state.get('analysis_done', False):
 
108
  question = st.text_input("Ask a question about this image:")
109
  if st.button('Get Answer'):
110
- answer = QuestionAnswering.answer_question(
111
- st.session_state['current_image'],
112
- question,
113
- st.session_state.get('caption', ''),
114
- st.session_state.get('detected_objects_str', ''),
115
- kbvqa_model
116
- )
117
  st.session_state['qa_history'].append((question, answer))
118
 
119
- for q, a in st.session_state.get('qa_history', []):
120
- st.text(f"Q: {q}\nA: {a}\n")
 
 
 
 
 
121
 
122
  def run_inference():
123
- """Main function to run inference based on the selected method."""
124
  st.title("Run Inference")
125
 
126
  method = st.selectbox(
127
  "Choose a method:",
128
  ["Fine-Tuned Model", "In-Context Learning (n-shots)"],
129
- index=0
130
  )
131
 
132
- if method == "Fine-Tuned Model":
133
- detection_model = st.selectbox(
134
- "Choose a model for object detection:",
135
- ["yolov5", "detic"],
136
- index=0
137
- )
 
 
 
 
 
 
 
 
 
 
 
138
 
139
- if 'kbvqa' not in st.session_state or st.session_state['detection_model'] != detection_model:
140
- st.session_state['detection_model'] = detection_model
141
- if 'model' not in st.session_state:
142
- if st.button('Load Model'):
143
 
144
- if load_kbvqa_model(detection_model):
145
- set_model_confidence(detection_model)
146
- image_qa_app(st.session_state['kbvqa'])
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
 
 
149
 
 
 
150
  def main():
151
  st.sidebar.title("Navigation")
152
- selection = st.sidebar.radio("Go to", ["Home", "Dataset Analysis", "Evaluation Results", "Run Inference", "Dissertation Report"])
153
 
154
  if selection == "Home":
155
- st.title("MultiModal Learning for Knowledge-Based Visual Question Answering")
156
  st.write("Home page content goes here...")
157
-
158
  elif selection == "Dissertation Report":
159
  st.title("Dissertation Report")
160
  st.write("Click the link below to view the PDF.")
@@ -166,29 +176,22 @@ def main():
166
  mime="application/octet-stream"
167
  )
168
 
 
169
  elif selection == "Evaluation Results":
170
  st.title("Evaluation Results")
171
  st.write("This is a Place Holder until the contents are uploaded.")
172
 
 
173
  elif selection == "Dataset Analysis":
174
  st.title("OK-VQA Dataset Analysis")
175
  st.write("This is a Place Holder until the contents are uploaded.")
176
 
 
177
  elif selection == "Run Inference":
178
  run_inference()
179
-
180
-
181
-
182
-
183
-
184
-
185
 
186
  if __name__ == "__main__":
187
- main()
188
-
189
-
190
-
191
-
192
-
193
-
194
-
 
1
+
2
+
3
  import streamlit as st
4
  import torch
5
  import bitsandbytes
 
12
  from my_model.captioner.image_captioning import get_caption
13
  from my_model.gen_utilities import free_gpu_resources
14
  from my_model.KBVQA import KBVQA, prepare_kbvqa_model
 
15
 
16
 
17
 
18
+ def answer_question(image, question, model):
19
 
20
+ answer = model.generate_answer(question, image)
21
+ return answer
22
 
23
+ def get_caption(image):
24
+ return "Generated caption for the image"
 
 
 
 
 
25
 
26
+ def free_gpu_resources():
27
+ pass
28
 
29
+ # Sample images (assuming these are paths to your sample images)
30
+ sample_images = ["Files/sample1.jpg", "Files/sample2.jpg", "Files/sample3.jpg",
31
+ "Files/sample4.jpg", "Files/sample5.jpg", "Files/sample6.jpg",
32
+ "Files/sample7.jpg"]
33
 
 
 
 
 
 
 
 
 
 
34
 
35
 
36
+ def analyze_image(image, model):
37
+ # Placeholder for your analysis function
38
+ # This function should prepare captions, detect objects, etc.
39
+ # For example:
40
+ # caption = model.get_caption(image)
41
+ # detected_objects = model.detect_objects(image)
42
+ # return caption, detected_objects
43
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
+ def image_qa_app(kbvqa):
46
+ # Initialize session state for storing the current image and its Q&A history
47
+ if 'current_image' not in st.session_state:
48
+ st.session_state['current_image'] = None
49
+ if 'qa_history' not in st.session_state:
50
+ st.session_state['qa_history'] = []
51
+ if 'analysis_done' not in st.session_state:
52
+ st.session_state['analysis_done'] = False
53
+ if 'answer_in_progress' not in st.session_state:
54
+ st.session_state['answer_in_progress'] = False
55
 
56
+ # Display sample images as clickable thumbnails
57
+ st.write("Choose from sample images:")
58
  cols = st.columns(len(sample_images))
59
  for idx, sample_image_path in enumerate(sample_images):
60
  with cols[idx]:
61
  image = Image.open(sample_image_path)
62
+ st.image(image, use_column_width=True)
63
  if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx}'):
64
+ st.session_state['current_image'] = image
65
  st.session_state['qa_history'] = []
66
  st.session_state['analysis_done'] = False
67
  st.session_state['answer_in_progress'] = False
68
 
69
+ # Image uploader
70
  uploaded_image = st.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
 
71
  if uploaded_image is not None:
72
+ image = Image.open(uploaded_image)
73
+ st.session_state['current_image'] = image
74
  st.session_state['qa_history'] = []
75
  st.session_state['analysis_done'] = False
76
  st.session_state['answer_in_progress'] = False
77
 
78
+ # Analyze Image button
79
+ if st.session_state.get('current_image') and not st.session_state['analysis_done']:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  if st.button('Analyze Image'):
81
+ # Perform analysis on the image
82
+ analyze_image(st.session_state['current_image'], kbvqa)
 
83
  st.session_state['analysis_done'] = True
84
+ st.session_state['processed_image'] = copy.deepcopy(st.session_state['current_image'])
85
+
86
+ # Display the current image (unaltered)
87
+ if st.session_state.get('current_image'):
88
+ st.image(st.session_state['current_image'], caption='Uploaded Image.', use_column_width=True)
89
 
90
+ # Get Answer button
91
+ if st.session_state['analysis_done'] and not st.session_state['answer_in_progress']:
92
  question = st.text_input("Ask a question about this image:")
93
  if st.button('Get Answer'):
94
+ st.session_state['answer_in_progress'] = True
95
+ answer = answer_question(st.session_state['processed_image'], question, model=kbvqa)
 
 
 
 
 
96
  st.session_state['qa_history'].append((question, answer))
97
 
98
+ # Display all Q&A
99
+ for q, a in st.session_state['qa_history']:
100
+ st.text(f"Q: {q}\nA: {a}\n")
101
+
102
+ # Reset the answer_in_progress flag after displaying the answer
103
+ if st.session_state['answer_in_progress']:
104
+ st.session_state['answer_in_progress'] = False
105
 
106
  def run_inference():
 
107
  st.title("Run Inference")
108
 
109
  method = st.selectbox(
110
  "Choose a method:",
111
  ["Fine-Tuned Model", "In-Context Learning (n-shots)"],
112
+ index=0 # Default to the first option
113
  )
114
 
115
+ detection_model = st.selectbox(
116
+ "Choose a model for object detection:",
117
+ ["yolov5", "detic"],
118
+ index=0 # Default to the first option
119
+ )
120
+
121
+ # Set default confidence based on the selected model
122
+ default_confidence = 0.2 if detection_model == "yolov5" else 0.4
123
+
124
+ # Slider for confidence level
125
+ confidence_level = st.slider(
126
+ "Select Detection Confidence Level",
127
+ min_value=0.1,
128
+ max_value=0.9,
129
+ value=default_confidence,
130
+ step=0.1
131
+ )
132
 
 
 
 
 
133
 
134
+
135
+ # Initialize session state for the model
 
136
 
137
+ if method == "Fine-Tuned Model":
138
+ if 'kbvqa' not in st.session_state:
139
+ st.session_state['kbvqa'] = None
140
+
141
+ # Button to load KBVQA models
142
+ if st.button('Load KBVQA Model'):
143
+ if st.session_state['kbvqa'] is not None:
144
+ st.write("Model already loaded.")
145
+ else:
146
+ # Call the function to load models and show progress
147
+ st.session_state['kbvqa'] = prepare_kbvqa_model(detection_model)
148
+
149
+ if st.session_state['kbvqa']:
150
+ st.write("Model is ready for inference.")
151
+
152
+ if st.session_state['kbvqa']:
153
+ image_qa_app(st.session_state['kbvqa'])
154
 
155
+ else:
156
+ st.write('Model is not ready for inference yet')
157
 
158
+
159
+ # Main function
160
  def main():
161
  st.sidebar.title("Navigation")
162
+ selection = st.sidebar.radio("Go to", ["Home", "Dataset Analysis", "Evaluation Results", "Run Inference", "Dissertation Report", "Object Detection"])
163
 
164
  if selection == "Home":
165
+ st.title("MultiModal Learning for Knowledg-Based Visual Question Answering")
166
  st.write("Home page content goes here...")
167
+
168
  elif selection == "Dissertation Report":
169
  st.title("Dissertation Report")
170
  st.write("Click the link below to view the PDF.")
 
176
  mime="application/octet-stream"
177
  )
178
 
179
+
180
  elif selection == "Evaluation Results":
181
  st.title("Evaluation Results")
182
  st.write("This is a Place Holder until the contents are uploaded.")
183
 
184
+
185
  elif selection == "Dataset Analysis":
186
  st.title("OK-VQA Dataset Analysis")
187
  st.write("This is a Place Holder until the contents are uploaded.")
188
 
189
+
190
  elif selection == "Run Inference":
191
  run_inference()
192
+
193
+ elif selection == "Object Detection":
194
+ run_object_detection()
 
 
 
195
 
196
  if __name__ == "__main__":
197
+ main()