Update app.py
app.py
CHANGED
@@ -6,10 +6,88 @@ import plotly.express as px
 from collections import defaultdict
 from datetime import datetime
 import os
+from datasets import load_dataset
+import sqlite3

+def initialize_database():
+    conn = sqlite3.connect('afrimmlu_results.db')
+    cursor = conn.cursor()
+
+    cursor.execute('''
+        CREATE TABLE IF NOT EXISTS summary_results (
+            id INTEGER PRIMARY KEY AUTOINCREMENT,
+            language TEXT,
+            subject TEXT,
+            accuracy REAL,
+            timestamp TEXT
+        )
+    ''')
+
+    cursor.execute('''
+        CREATE TABLE IF NOT EXISTS detailed_results (
+            id INTEGER PRIMARY KEY AUTOINCREMENT,
+            language TEXT,
+            timestamp TEXT,
+            subject TEXT,
+            question TEXT,
+            model_answer TEXT,
+            correct_answer TEXT,
+            is_correct INTEGER,
+            total_tokens INTEGER
+        )
+    ''')
+
+    conn.commit()
+    conn.close()
+
+def save_results_to_database(language, summary_results, detailed_results):
+    conn = sqlite3.connect('afrimmlu_results.db')
+    cursor = conn.cursor()
+    timestamp = datetime.now().isoformat()
+
+    # Save summary results
+    for subject, accuracy in summary_results.items():
+        cursor.execute('''
+            INSERT INTO summary_results (language, subject, accuracy, timestamp)
+            VALUES (?, ?, ?, ?)
+        ''', (language, subject, accuracy, timestamp))

+    # Save detailed results
+    for result in detailed_results:
+        cursor.execute('''
+            INSERT INTO detailed_results (
+                language, timestamp, subject, question, model_answer,
+                correct_answer, is_correct, total_tokens
+            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+        ''', (
+            language,
+            result['timestamp'],
+            result['subject'],
+            result['question'],
+            result['model_answer'],
+            result['correct_answer'],
+            int(result['is_correct']),
+            result['total_tokens']
+        ))

+    conn.commit()
+    conn.close()
+
+def load_afrimmlu_data(language_code="swa"):
+    """
+    Load AfriMMLU dataset for a specific language.
+    """
+    try:
+        dataset = load_dataset(
+            'masakhane/afrimmlu',
+            language_code,
+            token=os.environ['HF_TOKEN'],
+        )
+        test_data = dataset['test'].to_list()
+        return test_data
+    except Exception as e:
+        print(f"Error loading dataset: {str(e)}")
+        return None
-

 def preprocess_dataset(test_data):
     """
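load_afrimmlu_data pulls a single language configuration of masakhane/afrimmlu (authenticating with the HF_TOKEN environment variable) and returns its test split as a list of dicts, while initialize_database creates the two SQLite tables the results are written into. A minimal sketch for inspecting what the loader returns, assuming the functions above and a valid HF_TOKEN ("swa" is just an example code):

test_data = load_afrimmlu_data("swa")
if test_data:
    # number of test examples and the per-example fields
    print(f"{len(test_data)} test examples")
    print(sorted(test_data[0].keys()))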
@@ -32,7 +110,7 @@ def preprocess_dataset(test_data):
         preprocessed_data.append(example)
     return preprocessed_data

-def evaluate_afrimmlu(test_data, model_name="deepseek-chat"):
+def evaluate_afrimmlu(test_data, model_name="deepseek/deepseek-chat", language="swa"):
     """
     Evaluate the model on the AfriMMLU dataset.
     """
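evaluate_afrimmlu now takes a namespaced model id and a language code, so each run can be tagged per language in the database. A minimal sketch of calling it outside the Gradio UI, assuming the helpers defined above ("yor" is just an example):

test_data = load_afrimmlu_data("yor")
if test_data is not None:
    results = evaluate_afrimmlu(
        preprocess_dataset(test_data),
        model_name="deepseek/deepseek-chat",
        language="yor",
    )
    # overall accuracy is part of the returned dict
    print(results["accuracy"])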
@@ -79,7 +157,6 @@ def evaluate_afrimmlu(test_data, model_name="deepseek-chat"):
             total += 1
             subject_results[subject]["total"] += 1

-            # Store detailed results
             results.append({
                 'timestamp': datetime.now().isoformat(),
                 'subject': subject,
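Each appended record has to carry every key that save_results_to_database later reads. The hunk ends before the rest of the dict literal; going by the INSERT statement above, the full record presumably looks like this (the value-side names are assumed locals from the surrounding evaluation loop):

results.append({
    'timestamp': datetime.now().isoformat(),
    'subject': subject,
    'question': question,              # assumed local variables from the
    'model_answer': model_answer,      # question/answer handling above
    'correct_answer': correct_answer,
    'is_correct': is_correct,
    'total_tokens': total_tokens,
})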
@@ -94,22 +171,14 @@ def evaluate_afrimmlu(test_data, model_name="deepseek-chat"):
             print(f"Error processing question: {str(e)}")
             continue

-    # Calculate accuracies
     accuracy = (correct / total * 100) if total > 0 else 0
     subject_accuracy = {
         subject: (stats["correct"] / stats["total"] * 100) if stats["total"] > 0 else 0
         for subject, stats in subject_results.items()
     }

-    #
-
-    df.to_csv('detailed_results.csv', index=False)
-
-    # Export summary to CSV
-    summary_data = [{'subject': subject, 'accuracy': acc}
-                    for subject, acc in subject_accuracy.items()]
-    summary_data.append({'subject': 'Overall', 'accuracy': accuracy})
-    pd.DataFrame(summary_data).to_csv('summary_results.csv', index=False)
+    # Save results to database
+    save_results_to_database(language, {**subject_accuracy, 'Overall': accuracy}, results)

     return {
         "accuracy": accuracy,
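The per-run CSV exports (detailed_results.csv and summary_results.csv) are gone; summary and per-question results now land in afrimmlu_results.db via save_results_to_database. A minimal sketch for pulling them back into DataFrames, assuming the tables created by initialize_database above:

import sqlite3
import pandas as pd

conn = sqlite3.connect('afrimmlu_results.db')
# per-subject and overall accuracies across runs
summary_df = pd.read_sql_query(
    "SELECT language, subject, accuracy, timestamp FROM summary_results", conn)
# per-question records for one language
detailed_df = pd.read_sql_query(
    "SELECT * FROM detailed_results WHERE language = ?", conn, params=("swa",))
conn.close()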
@@ -143,41 +212,34 @@ def create_visualization(results_dict):

     return summary_df, fig

-def evaluate_and_display(test_file, model_name):
-    # Load and preprocess data
-    test_data = pd.read_json(test_file.name)
-    preprocessed_data = preprocess_dataset(test_data.to_dict('records'))
-
-    # Run evaluation
-    results = evaluate_afrimmlu(preprocessed_data, model_name)
-
-    # Create visualizations
-    summary_df, plot = create_visualization(results)
-
-    # Load detailed results with error handling
-    try:
-        detailed_df = pd.read_csv('detailed_results.csv')
-    except (FileNotFoundError, pd.errors.EmptyDataError):
-        detailed_df = pd.DataFrame(results["detailed_results"])
-
-    return summary_df, plot, detailed_df
-
-
 def create_gradio_interface():
     """
     Create and configure the Gradio interface.
     """
+    language_options = {
+        "swa": "Swahili",
+        "yor": "Yoruba",
+        "wol": "Wolof",
+        "lin": "Lingala",
+        "ewe": "Ewe",
+        "ibo": "Igbo"
+    }
+
+    # Initialize database
+    initialize_database()
+
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
         gr.Markdown("""
         # AfriMMLU Evaluation Dashboard
-
+        Select a language and model to evaluate performance on the AfriMMLU benchmark.
         """)

         with gr.Row():
             with gr.Column(scale=1):
-
-
-
+                language_input = gr.Dropdown(
+                    choices=list(language_options.keys()),
+                    label="Select Language",
+                    value="swa"
                 )
                 model_input = gr.Dropdown(
                     choices=["deepseek/deepseek-chat"],
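language_options maps language codes to display names, but the dropdown lists only the raw codes ("swa", "yor", ...). If friendlier labels are wanted, one hypothetical variant (not part of this change) shows the names and maps the selection back to a code before loading data:

language_input = gr.Dropdown(
    choices=list(language_options.values()),   # show "Swahili", "Yoruba", ...
    label="Select Language",
    value="Swahili",
)

# reverse map for turning the selected name back into a dataset code
code_for_name = {name: code for code, name in language_options.items()}
# inside evaluate_language: language_code = code_for_name[language_name]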
@@ -204,14 +266,29 @@ def create_gradio_interface():
             wrap=True
         )

+        def evaluate_language(language_code, model_name):
+            test_data = load_afrimmlu_data(language_code)
+            if test_data is None:
+                return None, None, None
+
+            preprocessed_data = preprocess_dataset(test_data)
+            results = evaluate_afrimmlu(preprocessed_data, model_name, language_code)
+            summary_df, plot = create_visualization(results)
+            detailed_df = pd.DataFrame(results["detailed_results"])
+
+            return summary_df, plot, detailed_df
+
         evaluate_btn.click(
-            fn=
-            inputs=[
+            fn=evaluate_language,
+            inputs=[language_input, model_input],
             outputs=[summary_table, summary_plot, detailed_results]
         )

     return demo

 if __name__ == "__main__":
+    os.environ['DEEPSEEK_API_KEY']
+    os.environ['HF_TOKEN']
+
     demo = create_gradio_interface()
-    demo.launch(share=True)
+    demo.launch(share=True)
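The bare os.environ['DEEPSEEK_API_KEY'] and os.environ['HF_TOKEN'] lookups discard their values; their only effect is to raise a KeyError at startup if either variable is unset. A sketch of a more explicit fail-fast check using the same two variable names (the error message is illustrative):

for var in ("DEEPSEEK_API_KEY", "HF_TOKEN"):
    if not os.environ.get(var):
        raise RuntimeError(f"Missing required environment variable: {var}")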