File size: 6,545 Bytes
330cb8a
1eb130f
e94bd1c
1eb130f
330cb8a
a93f397
4117bed
 
 
 
 
 
 
 
 
 
 
 
 
a93f397
4117bed
a93f397
4117bed
 
 
 
 
 
 
a93f397
4117bed
a93f397
4117bed
 
 
 
 
 
 
 
 
 
 
 
 
a93f397
 
 
 
4117bed
 
 
 
 
 
a93f397
4117bed
 
 
 
a93f397
4117bed
 
a93f397
 
 
 
 
1eb130f
4117bed
330cb8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7b49eb
330cb8a
e6f809f
 
330cb8a
 
 
 
 
 
 
 
 
 
 
212cd39
 
 
 
d7b49eb
 
 
 
 
 
 
 
 
 
 
 
 
 
e6f809f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import copy

import pandas as pd
import streamlit as st

from my_model.tabs.run_inference import run_inference


class UIManager:
    """Sidebar navigation and page rendering for the Streamlit application.

    ``self.tabs`` maps each sidebar label to the zero-argument callable that
    renders the corresponding page. Pages are shown in insertion order.
    """

    # Relative path of the report offered on the "Dissertation Report" page,
    # and the file name suggested to the browser for the download.
    REPORT_PATH = "Files/Dissertation Report.pdf"
    REPORT_DOWNLOAD_NAME = "Dissertation Report.pdf"

    def __init__(self):
        """Register the built-in pages."""
        self.tabs = {
            "Home": self.display_home,
            "Dataset Analysis": self.display_dataset_analysis,
            "Finetuning and Evaluation Results": self.display_finetuning_evaluation,
            "Run Inference": self.display_run_inference,
            "Dissertation Report": self.display_dissertation_report,
            "Code": self.display_code,
            "More Pages will follow .. ": self.display_placeholder,
        }

    def add_tab(self, tab_name, display_function):
        """Register (or replace) a page.

        Args:
            tab_name: Label shown in the sidebar.
            display_function: Zero-argument callable that renders the page.
        """
        self.tabs[tab_name] = display_function

    def display_sidebar(self):
        """Render the navigation sidebar and return the selected tab label."""
        st.sidebar.title("Navigation")
        selection = st.sidebar.radio("Go to", list(self.tabs))
        st.sidebar.write("More Pages will follow .. ")
        return selection

    def display_selected_page(self, selection):
        """Render the page registered under ``selection``.

        Unknown labels are ignored and nothing is rendered.
        """
        render_page = self.tabs.get(selection)
        if render_page is not None:
            render_page()

    def display_home(self):
        """Render the landing page."""
        st.title("MultiModal Learning for Knowledge-Based Visual Question Answering")
        st.write("""This application is an interactive element of the project and prepared by Mohammed Alhaj as part of the dissertation for Masters degree in Artificial Intelligence at the University of Bath. 
                    Further details will be updated later""")

    def display_dataset_analysis(self):
        """Render the (currently placeholder) dataset analysis page."""
        st.title("OK-VQA Dataset Analysis")
        st.write("This is a Place Holder until the contents are uploaded.")

    def display_finetuning_evaluation(self):
        """Render the (currently placeholder) finetuning/evaluation page."""
        st.title("Finetuning and Evaluation Results")
        st.write("This is a Place Holder until the contents are uploaded.")

    def display_run_inference(self):
        """Delegate rendering of the inference page to ``run_inference``."""
        run_inference()

    @staticmethod
    def _read_report_bytes(path):
        """Return the contents of the file at ``path`` as bytes.

        Returns ``None`` when the file cannot be opened or read, so the caller
        can show a message instead of crashing the page.
        """
        try:
            # The context manager closes the handle immediately after reading.
            with open(path, "rb") as report_file:
                return report_file.read()
        except OSError:
            return None

    def display_dissertation_report(self):
        """Render the report page with a download button for the PDF.

        The file is read fully into memory so that no file handle is left
        open across Streamlit reruns. If it cannot be read, an error message
        is shown in place of the button.
        """
        st.title("Dissertation Report")
        st.write("Click the link below to view the PDF.")

        report_bytes = self._read_report_bytes(self.REPORT_PATH)
        if report_bytes is None:
            st.error(f"The report file could not be read: {self.REPORT_PATH}")
            return

        st.download_button(
            label="Download PDF",
            data=report_bytes,
            file_name=self.REPORT_DOWNLOAD_NAME,
            mime="application/pdf",
        )

    def display_code(self):
        """Render the page linking to the project's source code."""
        st.title("Code")
        st.markdown("You can view the code for this project on the Hugging Face Space file page.")
        st.markdown("[View Code](https://huggingface.co/spaces/m7mdal7aj/Mohammed_Alhaj_PlayGround/tree/main)", unsafe_allow_html=True)

    def display_placeholder(self):
        """Render the generic placeholder page."""
        st.title("Stay Tuned")
        st.write("This is a Place Holder until the contents are uploaded.")






class StateManager:
    """Helper around ``st.session_state`` for the inference workflow.

    Session-state keys managed here:
      * ``'images_data'``: dict keyed by image key; each value is a dict with
        ``'image'``, ``'caption'``, ``'detected_objects_str'``,
        ``'qa_history'`` (list of ``(question, answer)`` tuples) and
        ``'analysis_done'``.
      * ``'model_settings'``: dict with ``'detection_model'`` and
        ``'confidence_level'``.
      * ``'kbvqa'``: the loaded model object, or ``None``.
      * ``'selected_method'``: the currently selected method, or ``None``.
    """

    def __init__(self):
        """Ensure all managed session-state keys exist."""
        self.initialize_state()

    def initialize_state(self):
        """Create any missing session-state keys; existing keys are left untouched."""
        # Built on every call so the mutable defaults are never shared.
        defaults = {
            'images_data': {},
            'model_settings': {'detection_model': None, 'confidence_level': None},
            'kbvqa': None,
            'selected_method': None,
        }
        for key, initial_value in defaults.items():
            if key not in st.session_state:
                st.session_state[key] = initial_value

    def update_model_settings(self, detection_model=None, confidence_level=None, selected_method=None):
        """Store the given settings in the session state.

        Arguments left as ``None`` are skipped, so the stored value for that
        setting is kept as it was.
        """
        settings = st.session_state['model_settings']
        if detection_model is not None:
            settings['detection_model'] = detection_model
        if confidence_level is not None:
            settings['confidence_level'] = confidence_level
        if selected_method is not None:
            st.session_state['selected_method'] = selected_method

    def check_settings_changed(self, current_selected_method, current_detection_model, current_confidence_level):
        """Return ``True`` if any given value differs from the one stored in the session state."""
        stored = st.session_state['model_settings']
        return (stored['detection_model'] != current_detection_model or
                stored['confidence_level'] != current_confidence_level or
                st.session_state['selected_method'] != current_selected_method)

    def display_model_settings(self):
        """Show the stored model settings as a one-row table."""
        st.write("### Current Model Settings:")
        st.table(pd.DataFrame(st.session_state['model_settings'], index=[0]))

    def display_session_state(self):
        """Show every session-state entry as a Key/Value table (values stringified)."""
        st.write("### Current Session State:")
        rows = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
        st.table(pd.DataFrame(rows))

    def get_model(self):
        """Retrieve the KBVQA model from the session state (``None`` if absent)."""
        return st.session_state.get('kbvqa', None)

    def is_model_loaded(self):
        """Return ``True`` when a non-``None`` model is stored under ``'kbvqa'``."""
        return st.session_state.get('kbvqa', None) is not None

    def reload_detection_model(self, detection_model, confidence_level):
        """Reload only the detection model of an already loaded KBVQA model.

        If a model is loaded, ``prepare_kbvqa_model`` is called with
        ``only_reload_detection_model=True``, the model's
        ``detection_confidence`` is set to ``confidence_level`` and both
        values are recorded in the model settings. ``free_gpu_resources`` is
        called before and after. Any exception is reported with ``st.error``
        instead of being raised.
        """
        # NOTE(review): free_gpu_resources and prepare_kbvqa_model are not
        # imported in the visible part of this module. Confirm they are in
        # scope, otherwise the resulting NameError ends up in st.error below.
        try:
            free_gpu_resources()
            if self.is_model_loaded():
                prepare_kbvqa_model(detection_model, only_reload_detection_model=True)
                st.session_state['kbvqa'].detection_confidence = confidence_level
                self.update_model_settings(detection_model, confidence_level)
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error reloading detection model: {e}")

    # Per-image bookkeeping -------------------------------------------------

    def process_new_image(self, image_key, image, kbvqa):
        """Register ``image`` under ``image_key`` unless that key already exists.

        ``kbvqa`` is not used here; the parameter is kept so existing callers
        keep working.
        """
        images_data = st.session_state['images_data']
        if image_key not in images_data:
            images_data[image_key] = {
                'image': image,
                'caption': '',
                'detected_objects_str': '',
                'qa_history': [],
                'analysis_done': False,
            }

    def analyze_image(self, image, kbvqa):
        """Caption ``image`` and run object detection on it.

        The model methods receive a deep copy, so the caller's ``image``
        object is not modified by them.

        Returns:
            A tuple ``(caption, detected_objects_str, image_with_boxes)``.
        """
        working_copy = copy.deepcopy(image)
        caption = kbvqa.get_caption(working_copy)
        image_with_boxes, detected_objects_str = kbvqa.detect_objects(working_copy)
        return caption, detected_objects_str, image_with_boxes

    def add_to_qa_history(self, image_key, question, answer):
        """Append ``(question, answer)`` to the image's history; unknown keys are ignored."""
        images_data = st.session_state['images_data']
        if image_key in images_data:
            images_data[image_key]['qa_history'].append((question, answer))

    def get_images_data(self):
        """Return the dict of all per-image records stored in the session state."""
        return st.session_state['images_data']

    def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
        """Store analysis results for a registered image; unknown keys are ignored."""
        images_data = st.session_state['images_data']
        if image_key in images_data:
            images_data[image_key].update({
                'caption': caption,
                'detected_objects_str': detected_objects_str,
                'analysis_done': analysis_done,
            })