"""AfriMMLU evaluation dashboard.

Loads an AfriMMLU test split, asks an LLM (through litellm) to answer each
multiple-choice question, stores summary and per-question results in a local
SQLite database, and exposes both the evaluation and an SQL query box through
a Gradio interface.
"""
import ast
import os
import re
import sqlite3
from collections import defaultdict
from contextlib import closing
from datetime import datetime

import gradio as gr
import litellm
import pandas as pd
import plotly.express as px
from datasets import load_dataset

# Single source of truth for the SQLite file used by every database helper.
DB_PATH = 'afrimmlu_results.db'

# Environment variables that must be present before the app is launched.
REQUIRED_ENV_VARS = ('DEEPSEEK_API_KEY', 'HF_TOKEN')

# Matches an answer letter that stands on its own ("B", "(B)", "B.") rather
# than a letter embedded in a word such as the "A" in "ANSWER".
_ANSWER_PATTERN = re.compile(r'\b([A-D])\b')


def initialize_database():
    """Create the ``summary_results`` and ``detailed_results`` tables if missing."""
    # closing() guarantees the connection is released even if a statement fails.
    with closing(sqlite3.connect(DB_PATH)) as conn:
        cursor = conn.cursor()
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS summary_results (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                language TEXT,
                subject TEXT,
                accuracy REAL,
                timestamp TEXT
            )
        ''')
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS detailed_results (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                language TEXT,
                timestamp TEXT,
                subject TEXT,
                question TEXT,
                model_answer TEXT,
                correct_answer TEXT,
                is_correct INTEGER,
                total_tokens INTEGER
            )
        ''')
        conn.commit()


def save_results_to_database(language, summary_results, detailed_results):
    """Persist one evaluation run.

    Args:
        language: Language code stored with every inserted row.
        summary_results: Mapping of subject name to accuracy.
        detailed_results: Iterable of per-question dicts with the keys
            ``timestamp``, ``subject``, ``question``, ``model_answer``,
            ``correct_answer``, ``is_correct`` and ``total_tokens``.

    Summary rows share a single timestamp taken when this function is called;
    detailed rows keep the timestamp recorded in each result dict.
    """
    timestamp = datetime.now().isoformat()

    summary_rows = [
        (language, subject, accuracy, timestamp)
        for subject, accuracy in summary_results.items()
    ]
    detailed_rows = [
        (
            language,
            result['timestamp'],
            result['subject'],
            result['question'],
            result['model_answer'],
            result['correct_answer'],
            int(result['is_correct']),
            result['total_tokens'],
        )
        for result in detailed_results
    ]

    with closing(sqlite3.connect(DB_PATH)) as conn:
        # "with conn" commits on success and rolls back if an insert fails,
        # so a run is stored completely or not at all.
        with conn:
            conn.executemany('''
                INSERT INTO summary_results (language, subject, accuracy, timestamp)
                VALUES (?, ?, ?, ?)
            ''', summary_rows)
            conn.executemany('''
                INSERT INTO detailed_results (
                    language, timestamp, subject, question,
                    model_answer, correct_answer, is_correct, total_tokens
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            ''', detailed_rows)


def load_afrimmlu_data(language_code="swa"):
    """
    Load the AfriMMLU test split for a specific language.

    Args:
        language_code: Dataset configuration name passed to ``load_dataset``.

    Returns:
        The ``test`` split as a list, or ``None`` if loading failed for any
        reason (including a missing ``HF_TOKEN`` environment variable); the
        error is printed.
    """
    try:
        dataset = load_dataset(
            'masakhane/afrimmlu',
            language_code,
            token=os.environ['HF_TOKEN'],
        )
        return dataset['test'].to_list()
    except Exception as e:
        # Boundary handler: the UI callback treats None as "nothing to show".
        print(f"Error loading dataset: {str(e)}")
        return None


def preprocess_dataset(test_data):
    """
    Convert each example's 'choices' field from a string to a Python literal.

    String values are optionally unwrapped from one pair of enclosing quotes,
    have escaped single quotes unescaped, and are parsed with
    ``ast.literal_eval``. Examples whose choices cannot be parsed are reported
    and dropped; examples whose choices are not strings pass through as-is.

    Args:
        test_data: Iterable of example dicts containing a ``choices`` key.

    Returns:
        A new list of examples. Parsed examples are shallow copies, so the
        dicts in ``test_data`` are not modified.
    """
    preprocessed_data = []
    for example in test_data:
        raw_choices = example['choices']
        if isinstance(raw_choices, str):
            choices_str = raw_choices
            if choices_str.startswith("'") and choices_str.endswith("'"):
                choices_str = choices_str[1:-1]
            elif choices_str.startswith('"') and choices_str.endswith('"'):
                choices_str = choices_str[1:-1]
            choices_str = choices_str.replace("\\'", "'")
            try:
                parsed_choices = ast.literal_eval(choices_str)
            except (ValueError, SyntaxError):
                print(f"Error parsing choices: {choices_str}")
                continue
            # Copy instead of mutating the caller's example in place.
            example = {**example, 'choices': parsed_choices}
        preprocessed_data.append(example)
    return preprocessed_data


def _extract_answer_letter(model_output):
    """Return the first standalone A-D letter in the model output, else None."""
    match = _ANSWER_PATTERN.search(model_output.upper())
    return match.group(1) if match else None


def _build_prompt(question, choices):
    """Format a four-option multiple-choice prompt; needs at least 4 choices."""
    return (
        f"Answer the following multiple-choice question. "
        f"Return only the letter corresponding to the correct answer (A, B, C, or D).\n"
        f"Question: {question}\n"
        f"Options:\n"
        f"A. {choices[0]}\n"
        f"B. {choices[1]}\n"
        f"C. {choices[2]}\n"
        f"D. {choices[3]}\n"
        f"Answer:"
    )


def evaluate_afrimmlu(test_data, model_name="deepseek/deepseek-chat", language="swa"):
    """
    Evaluate the model on the AfriMMLU dataset.

    Each example is turned into a prompt, sent to ``litellm.completion``, and
    the first standalone letter A-D in the reply is compared with the
    example's answer. Malformed examples and examples whose request fails are
    reported and skipped; they do not count towards any total.

    Args:
        test_data: Iterable of dicts with the keys ``question``, ``choices``
            (indexable, at least four entries), ``answer`` (a string) and
            ``subject``.
        model_name: Model identifier passed to ``litellm.completion``.
        language: Language code stored with the results in the database.

    Returns:
        Dict with ``accuracy`` (overall percentage), ``subject_accuracy``
        (percentage per subject) and ``detailed_results`` (one dict per
        evaluated question).

    Side effects:
        Saves the per-subject accuracies plus an ``'Overall'`` entry and the
        detailed results through ``save_results_to_database``.
    """
    results = []
    correct = 0
    total = 0
    subject_results = defaultdict(lambda: {"correct": 0, "total": 0})

    for example in test_data:
        # Validate before calling the API so one bad example neither aborts
        # the run nor costs a request.
        try:
            question = example['question']
            subject = example['subject']
            correct_answer = example['answer'].upper()
            prompt = _build_prompt(question, example['choices'])
        except (KeyError, IndexError, TypeError, AttributeError) as e:
            print(f"Skipping malformed example: {e!r}")
            continue

        try:
            response = litellm.completion(
                model=model_name,
                messages=[{"role": "user", "content": prompt}]
            )
            model_output = response.choices[0].message.content.strip().upper()
            model_answer = _extract_answer_letter(model_output)
            is_correct = model_answer == correct_answer
            record = {
                'timestamp': datetime.now().isoformat(),
                'subject': subject,
                'question': question,
                'model_answer': model_answer,
                'correct_answer': correct_answer,
                'is_correct': is_correct,
                'total_tokens': response.usage.total_tokens
            }
        except Exception as e:
            # Per-question boundary: a failed request must not stop the run.
            print(f"Error processing question: {str(e)}")
            continue

        # Counters are updated only once the record is complete, so the
        # totals always agree with the detailed results.
        results.append(record)
        total += 1
        subject_results[subject]["total"] += 1
        if is_correct:
            correct += 1
            subject_results[subject]["correct"] += 1

    accuracy = (correct / total * 100) if total > 0 else 0
    subject_accuracy = {
        subject: (stats["correct"] / stats["total"] * 100) if stats["total"] > 0 else 0
        for subject, stats in subject_results.items()
    }

    # Save results to database
    save_results_to_database(language, {**subject_accuracy, 'Overall': accuracy}, results)

    return {
        "accuracy": accuracy,
        "subject_accuracy": subject_accuracy,
        "detailed_results": results
    }


def create_visualization(results_dict):
    """
    Create a summary table and bar chart from evaluation results.

    Args:
        results_dict: Dict with ``subject_accuracy`` (subject -> percentage)
            and ``accuracy`` (overall percentage).

    Returns:
        Tuple ``(summary_df, fig)``: a DataFrame with the columns ``Subject``
        and ``Accuracy (%)`` (one row per subject plus an ``Overall`` row) and
        the Plotly bar figure built from it.
    """
    summary_data = [
        {'Subject': subject, 'Accuracy (%)': accuracy}
        for subject, accuracy in results_dict['subject_accuracy'].items()
    ]
    summary_data.append({'Subject': 'Overall', 'Accuracy (%)': results_dict['accuracy']})
    summary_df = pd.DataFrame(summary_data)

    fig = px.bar(
        summary_df,
        x='Subject',
        y='Accuracy (%)',
        title='AfriMMLU Evaluation Results',
        labels={'Subject': 'Subject', 'Accuracy (%)': 'Accuracy (%)'}
    )
    fig.update_layout(
        xaxis_tickangle=-45,
        showlegend=False,
        height=600
    )
    return summary_df, fig


def query_database(query):
    """
    Run a user-supplied SQL query against the results database.

    The query text comes straight from a text box in the UI, so the database
    is opened read-only: statements that would modify or drop data fail
    instead of taking effect.

    Args:
        query: SQL text to execute.

    Returns:
        DataFrame with the query result, or a one-column ``Error`` DataFrame
        holding the error message if connecting or querying fails.
    """
    try:
        with closing(sqlite3.connect(f'file:{DB_PATH}?mode=ro', uri=True)) as conn:
            return pd.read_sql_query(query, conn)
    except Exception as e:
        return pd.DataFrame({'Error': [str(e)]})


def create_gradio_interface():
    """
    Build the Gradio dashboard.

    Initializes the database, then creates a "Model Evaluation" tab (language
    and model selection, summary table, plot, detailed results) and a
    "Database Analysis" tab (example queries, SQL box, query results).

    Returns:
        The configured ``gr.Blocks`` application (not yet launched).
    """
    # Language code -> display name; only the codes are offered in the dropdown.
    language_options = {
        "swa": "Swahili",
        "yor": "Yoruba",
        "wol": "Wolof",
        "lin": "Lingala",
        "ewe": "Ewe",
        "ibo": "Igbo"
    }

    initialize_database()

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# AfriMMLU Evaluation Dashboard")

        with gr.Tabs():
            # Evaluation Tab
            with gr.Tab("Model Evaluation"):
                with gr.Row():
                    with gr.Column(scale=1):
                        language_input = gr.Dropdown(
                            choices=list(language_options.keys()),
                            label="Select Language",
                            value="swa"
                        )
                        model_input = gr.Dropdown(
                            choices=["deepseek/deepseek-chat"],
                            label="Select Model",
                            value="deepseek/deepseek-chat"
                        )
                        evaluate_btn = gr.Button("Evaluate", variant="primary")

                with gr.Row():
                    summary_table = gr.Dataframe(
                        headers=["Subject", "Accuracy (%)"],
                        label="Summary Results"
                    )

                with gr.Row():
                    summary_plot = gr.Plot(label="Performance by Subject")

                with gr.Row():
                    detailed_results = gr.Dataframe(
                        label="Detailed Results",
                        wrap=True
                    )

            # Query Tab
            with gr.Tab("Database Analysis"):
                with gr.Row():
                    with gr.Column():
                        example_queries = gr.Dropdown(
                            choices=[
                                "SELECT language, AVG(accuracy) as avg_accuracy FROM summary_results WHERE subject='Overall' GROUP BY language",
                                "SELECT subject, AVG(accuracy) as avg_accuracy FROM summary_results GROUP BY subject",
                                "SELECT language, subject, accuracy, timestamp FROM summary_results ORDER BY timestamp DESC LIMIT 10",
                                "SELECT language, COUNT(*) as total_questions, SUM(is_correct) as correct_answers FROM detailed_results GROUP BY language",
                                "SELECT subject, COUNT(*) as total_evaluations FROM summary_results GROUP BY subject"
                            ],
                            label="Example Queries",
                            value="SELECT language, AVG(accuracy) as avg_accuracy FROM summary_results WHERE subject='Overall' GROUP BY language"
                        )
                        query_input = gr.Textbox(
                            label="SQL Query",
                            placeholder="Enter your SQL query here",
                            lines=3
                        )
                        query_button = gr.Button("Run Query", variant="primary")
                        gr.Markdown("""
                        ### Available Tables:
                        1. summary_results (id, language, subject, accuracy, timestamp)
                        2. detailed_results (id, language, timestamp, subject, question, model_answer, correct_answer, is_correct, total_tokens)
                        """)

                with gr.Row():
                    query_output = gr.Dataframe(
                        label="Query Results",
                        wrap=True
                    )

        def evaluate_language(language_code, model_name):
            """Load, preprocess and evaluate one language; return the three UI outputs."""
            test_data = load_afrimmlu_data(language_code)
            if test_data is None:
                # Loading failed (already reported); leave all outputs empty.
                return None, None, None

            preprocessed_data = preprocess_dataset(test_data)
            results = evaluate_afrimmlu(preprocessed_data, model_name, language_code)
            summary_df, plot = create_visualization(results)
            detailed_df = pd.DataFrame(results["detailed_results"])
            return summary_df, plot, detailed_df

        # Evaluation tab callback
        evaluate_btn.click(
            fn=evaluate_language,
            inputs=[language_input, model_input],
            outputs=[summary_table, summary_plot, detailed_results]
        )

        # Query tab callbacks
        # Selecting an example copies it into the editable SQL box.
        example_queries.change(
            fn=lambda x: x,
            inputs=[example_queries],
            outputs=[query_input]
        )
        query_button.click(
            fn=query_database,
            inputs=[query_input],
            outputs=[query_output]
        )

    return demo


if __name__ == "__main__":
    # Fail fast with a readable message instead of a bare KeyError.
    missing_vars = [name for name in REQUIRED_ENV_VARS if name not in os.environ]
    if missing_vars:
        raise SystemExit(
            f"Missing required environment variables: {', '.join(missing_vars)}"
        )

    demo = create_gradio_interface()
    demo.launch(share=True)