"""Script to produce radial plots.""" |
|
|
|
from functools import partial |
|
import plotly.graph_objects as go |
|
import json |
|
import numpy as np |
|
from collections import defaultdict |
|
import pandas as pd |
|
from pydantic import BaseModel |
|
import gradio as gr |
|
import requests |
|
import random |
|
import logging |
|
import datetime as dt |
|
|
|
|
|
fmt = "%(asctime)s [%(levelname)s] <%(name)s> %(message)s" |
|
logging.basicConfig(level=logging.INFO, format=fmt) |
|
logger = logging.getLogger("radial_plot_generator") |
|
|
|
|
|
UPDATE_FREQUENCY_MINUTES = 30 |
|
|
|
|
|
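# Rough data flow of this script: `fetch_results` downloads the ScandEval results as
# JSONL and turns them into one dataframe per language (rows are models, columns are
# tasks). `main` builds a Gradio UI around those dataframes, and `produce_radial_plot`
# renders the selected models as traces on a Plotly polar chart.
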
class Task(BaseModel):
    """Class to hold task information."""

    name: str
    metric: str

    def __hash__(self):
        return hash(self.name)


class Language(BaseModel):
    """Class to hold language information."""

    code: str
    name: str

    def __hash__(self):
        return hash(self.code)


class Dataset(BaseModel):
    """Class to hold dataset information."""

    name: str
    language: Language
    task: Task

    def __hash__(self):
        return hash(self.name)

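# The `__hash__` overrides above allow `Task`, `Language` and `Dataset` instances to be
# used as dictionary keys and set members, which the rest of the script relies on (for
# example, the results dataframes are keyed by `Language` and their columns by `Task`).
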
TEXT_CLASSIFICATION = Task(name="text classification", metric="mcc")
INFORMATION_EXTRACTION = Task(name="information extraction", metric="micro_f1_no_misc")
GRAMMAR = Task(name="grammar", metric="mcc")
QUESTION_ANSWERING = Task(name="question answering", metric="em")
SUMMARISATION = Task(name="summarisation", metric="bertscore")
KNOWLEDGE = Task(name="knowledge", metric="mcc")
REASONING = Task(name="reasoning", metric="mcc")
ALL_TASKS = [obj for obj in globals().values() if isinstance(obj, Task)]

DANISH = Language(code="da", name="Danish")
NORWEGIAN = Language(code="no", name="Norwegian")
SWEDISH = Language(code="sv", name="Swedish")
ICELANDIC = Language(code="is", name="Icelandic")
FAROESE = Language(code="fo", name="Faroese")
GERMAN = Language(code="de", name="German")
DUTCH = Language(code="nl", name="Dutch")
ENGLISH = Language(code="en", name="English")
ALL_LANGUAGES = {
    obj.name: obj for obj in globals().values() if isinstance(obj, Language)
}

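# The `metric` fields above name the score that is read out of each ScandEval record in
# `fetch_results` (looked up as `test_<metric>` or `<metric>`). As an assumption on the
# naming: `mcc` is the Matthews correlation coefficient, `em` is exact match,
# `bertscore` is BERTScore, and `micro_f1_no_misc` is micro-averaged F1 without the
# MISC entity class.
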
DATASETS = [
    Dataset(name="swerec", language=SWEDISH, task=TEXT_CLASSIFICATION),
    Dataset(name="angry-tweets", language=DANISH, task=TEXT_CLASSIFICATION),
    Dataset(name="norec", language=NORWEGIAN, task=TEXT_CLASSIFICATION),
    Dataset(name="sb10k", language=GERMAN, task=TEXT_CLASSIFICATION),
    Dataset(name="dutch-social", language=DUTCH, task=TEXT_CLASSIFICATION),
    Dataset(name="sst5", language=ENGLISH, task=TEXT_CLASSIFICATION),
    Dataset(name="suc3", language=SWEDISH, task=INFORMATION_EXTRACTION),
    Dataset(name="dansk", language=DANISH, task=INFORMATION_EXTRACTION),
    Dataset(name="norne-nb", language=NORWEGIAN, task=INFORMATION_EXTRACTION),
    Dataset(name="norne-nn", language=NORWEGIAN, task=INFORMATION_EXTRACTION),
    Dataset(name="mim-gold-ner", language=ICELANDIC, task=INFORMATION_EXTRACTION),
    Dataset(name="fone", language=FAROESE, task=INFORMATION_EXTRACTION),
    Dataset(name="germeval", language=GERMAN, task=INFORMATION_EXTRACTION),
    Dataset(name="conll-nl", language=DUTCH, task=INFORMATION_EXTRACTION),
    Dataset(name="conll-en", language=ENGLISH, task=INFORMATION_EXTRACTION),
    Dataset(name="scala-sv", language=SWEDISH, task=GRAMMAR),
    Dataset(name="scala-da", language=DANISH, task=GRAMMAR),
    Dataset(name="scala-nb", language=NORWEGIAN, task=GRAMMAR),
    Dataset(name="scala-nn", language=NORWEGIAN, task=GRAMMAR),
    Dataset(name="scala-is", language=ICELANDIC, task=GRAMMAR),
    Dataset(name="scala-fo", language=FAROESE, task=GRAMMAR),
    Dataset(name="scala-de", language=GERMAN, task=GRAMMAR),
    Dataset(name="scala-nl", language=DUTCH, task=GRAMMAR),
    Dataset(name="scala-en", language=ENGLISH, task=GRAMMAR),
    Dataset(name="scandiqa-da", language=DANISH, task=QUESTION_ANSWERING),
    Dataset(name="norquad", language=NORWEGIAN, task=QUESTION_ANSWERING),
    Dataset(name="scandiqa-sv", language=SWEDISH, task=QUESTION_ANSWERING),
    Dataset(name="nqii", language=ICELANDIC, task=QUESTION_ANSWERING),
    Dataset(name="germanquad", language=GERMAN, task=QUESTION_ANSWERING),
    Dataset(name="squad", language=ENGLISH, task=QUESTION_ANSWERING),
    Dataset(name="squad-nl", language=DUTCH, task=QUESTION_ANSWERING),
    Dataset(name="nordjylland-news", language=DANISH, task=SUMMARISATION),
    Dataset(name="mlsum", language=GERMAN, task=SUMMARISATION),
    Dataset(name="rrn", language=ICELANDIC, task=SUMMARISATION),
    Dataset(name="no-sammendrag", language=NORWEGIAN, task=SUMMARISATION),
    Dataset(name="wiki-lingua-nl", language=DUTCH, task=SUMMARISATION),
    Dataset(name="swedn", language=SWEDISH, task=SUMMARISATION),
    Dataset(name="cnn-dailymail", language=ENGLISH, task=SUMMARISATION),
    Dataset(name="mmlu-da", language=DANISH, task=KNOWLEDGE),
    Dataset(name="mmlu-no", language=NORWEGIAN, task=KNOWLEDGE),
    Dataset(name="mmlu-sv", language=SWEDISH, task=KNOWLEDGE),
    Dataset(name="mmlu-is", language=ICELANDIC, task=KNOWLEDGE),
    Dataset(name="mmlu-de", language=GERMAN, task=KNOWLEDGE),
    Dataset(name="mmlu-nl", language=DUTCH, task=KNOWLEDGE),
    Dataset(name="mmlu", language=ENGLISH, task=KNOWLEDGE),
    Dataset(name="arc-da", language=DANISH, task=KNOWLEDGE),
    Dataset(name="arc-no", language=NORWEGIAN, task=KNOWLEDGE),
    Dataset(name="arc-sv", language=SWEDISH, task=KNOWLEDGE),
    Dataset(name="arc-is", language=ICELANDIC, task=KNOWLEDGE),
    Dataset(name="arc-de", language=GERMAN, task=KNOWLEDGE),
    Dataset(name="arc-nl", language=DUTCH, task=KNOWLEDGE),
    Dataset(name="arc", language=ENGLISH, task=KNOWLEDGE),
    Dataset(name="hellaswag-da", language=DANISH, task=REASONING),
    Dataset(name="hellaswag-no", language=NORWEGIAN, task=REASONING),
    Dataset(name="hellaswag-sv", language=SWEDISH, task=REASONING),
    Dataset(name="hellaswag-is", language=ICELANDIC, task=REASONING),
    Dataset(name="hellaswag-de", language=GERMAN, task=REASONING),
    Dataset(name="hellaswag-nl", language=DUTCH, task=REASONING),
    Dataset(name="hellaswag", language=ENGLISH, task=REASONING),
]

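# Note that `ALL_TASKS` and `ALL_LANGUAGES` above are collected via `globals()`, so any
# new `Task` or `Language` constant defined at module level is picked up automatically,
# and any dataset appended to `DATASETS` is included in `fetch_results` without further
# registration.
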
def main() -> None:
    """Build the radial plot demo and launch it."""
    # Fetch the results and record when we fetched them, so that the event handlers
    # below can refresh them once they become stale.
    global last_fetch
    results_dfs = fetch_results()
    last_fetch = dt.datetime.now()

    all_languages = [
        language.name for language in ALL_LANGUAGES.values()
    ]

    # The model dropdown starts out with the models benchmarked on Danish, matching
    # the default language selection below.
    danish_models = list({
        model_id
        for model_id in results_dfs[DANISH].index
    })

    with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
        gr.Markdown("# Radial Plot Generator")
        gr.Markdown(
            "This demo allows you to generate a radial plot comparing the performance "
            "of different language models on different tasks. It is based on the "
            "generative results from the [ScandEval benchmark](https://scandeval.com)."
        )
        with gr.Column():
            with gr.Row():
                language_names_dropdown = gr.Dropdown(
                    choices=all_languages,
                    multiselect=True,
                    label="Languages",
                    value=["Danish"],
                    interactive=True,
                    scale=2,
                )
                model_ids_dropdown = gr.Dropdown(
                    choices=danish_models,
                    multiselect=True,
                    label="Models",
                    value=["gpt-4-0613", "mistralai/Mistral-7B-v0.1"],
                    interactive=True,
                    scale=2,
                )
                use_win_ratio_checkbox = gr.Checkbox(
                    label="Compare models with win ratios (as opposed to raw scores)",
                    value=True,
                    interactive=True,
                    scale=1,
                )
            with gr.Row():
                plot = gr.Plot(
                    value=produce_radial_plot(
                        model_ids_dropdown.value,
                        language_names=language_names_dropdown.value,
                        use_win_ratio=use_win_ratio_checkbox.value,
                        results_dfs=results_dfs,
                    ),
                )
            with gr.Row():
                gr.Markdown(
                    "<center>Made with ❤️ by the <a href=\"https://alexandra.dk\">"
                    "Alexandra Institute</a>.</center>"
                )

        # Keep the model dropdown in sync with the selected languages.
        language_names_dropdown.change(
            fn=partial(update_model_ids_dropdown, results_dfs=results_dfs),
            inputs=[language_names_dropdown, model_ids_dropdown],
            outputs=model_ids_dropdown,
        )

        # Re-generate the plot whenever any of the three inputs change.
        language_names_dropdown.change(
            fn=partial(produce_radial_plot, results_dfs=results_dfs),
            inputs=[
                model_ids_dropdown, language_names_dropdown, use_win_ratio_checkbox
            ],
            outputs=plot,
        )
        model_ids_dropdown.change(
            fn=partial(produce_radial_plot, results_dfs=results_dfs),
            inputs=[
                model_ids_dropdown, language_names_dropdown, use_win_ratio_checkbox
            ],
            outputs=plot,
        )
        use_win_ratio_checkbox.change(
            fn=partial(produce_radial_plot, results_dfs=results_dfs),
            inputs=[
                model_ids_dropdown, language_names_dropdown, use_win_ratio_checkbox
            ],
            outputs=plot,
        )

    demo.launch()

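# When run directly (see the `__main__` guard at the bottom of the file), `main` blocks
# in `demo.launch()` until the Gradio server is stopped. While testing, passing
# `share=True` to `launch()` should also give you a temporary public link, although
# that is not what this script does by default.
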
def update_model_ids_dropdown(
    language_names: list[str],
    model_ids: list[str],
    results_dfs: dict[Language, pd.DataFrame] | None,
) -> dict:
    """Update the model ids dropdown when the selected languages change.

    Args:
        language_names:
            The names of the languages to include in the plot.
        model_ids:
            The ids of the models to include in the plot.
        results_dfs:
            The results dataframes for each language.

    Returns:
        The Gradio update to the model ids dropdown.
    """
    # Refresh the cached results if they have become stale.
    global last_fetch
    minutes_since_last_fetch = (dt.datetime.now() - last_fetch).total_seconds() / 60
    if minutes_since_last_fetch > UPDATE_FREQUENCY_MINUTES:
        results_dfs = fetch_results()
        last_fetch = dt.datetime.now()

    if results_dfs is None or len(language_names) == 0:
        if results_dfs is None:
            logger.info("No results fetched yet. Resetting model ids dropdown.")
        else:
            logger.info("No languages selected. Resetting model ids dropdown.")
        return gr.update(choices=[], value=[])

    # Only keep the tasks that have results for every selected language.
    tasks = [
        task
        for task in ALL_TASKS
        if all(
            task in df.columns
            for language, df in results_dfs.items()
            if language.name in language_names
        )
    ]

    filtered_results_dfs = {
        language: df[tasks]
        for language, df in results_dfs.items()
        if language.name in language_names
    }

    # Only keep the models that have been benchmarked on every selected language.
    unique_models = {
        model_id
        for df in filtered_results_dfs.values()
        for model_id in df.index
    }
    filtered_models = [
        model_id
        for model_id in unique_models
        if all(model_id in df.index for df in filtered_results_dfs.values())
    ]

    if len(filtered_models) == 0:
        logger.info(
            "No valid models for the selected languages. Resetting model ids dropdown."
        )
        return gr.update(choices=[], value=[])

    # Keep the previously selected models if they are still valid; otherwise pick one
    # of the valid models at random.
    valid_selected_models = [
        model_id for model_id in model_ids if model_id in filtered_models
    ]
    if not valid_selected_models:
        valid_selected_models = random.sample(filtered_models, k=1)

    logger.info(
        f"Updated model ids dropdown with {len(filtered_models):,} valid models for "
        f"the selected languages, with {valid_selected_models} selected."
    )

    return gr.update(choices=filtered_models, value=valid_selected_models)

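# For reference, the dictionary returned by `gr.update(choices=..., value=...)` is how
# Gradio applies partial updates to a component from an event handler: returning it
# from the `change` callback wired up in `main` swaps out both the available choices
# and the current selection of the model dropdown in one go.
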
def produce_radial_plot(
    model_ids: list[str],
    language_names: list[str],
    use_win_ratio: bool,
    results_dfs: dict[Language, pd.DataFrame] | None,
) -> go.Figure:
    """Produce a radial plot as a plotly figure.

    Args:
        model_ids:
            The ids of the models to include in the plot.
        language_names:
            The names of the languages to include in the plot.
        use_win_ratio:
            Whether to use win ratios (as opposed to raw scores).
        results_dfs:
            The results dataframes for each language.

    Returns:
        A plotly figure.
    """
    # Refresh the cached results if they have become stale.
    global last_fetch
    minutes_since_last_fetch = (dt.datetime.now() - last_fetch).total_seconds() / 60
    if minutes_since_last_fetch > UPDATE_FREQUENCY_MINUTES:
        results_dfs = fetch_results()
        last_fetch = dt.datetime.now()

    if results_dfs is None or len(language_names) == 0 or len(model_ids) == 0:
        if results_dfs is None:
            logger.info("No results fetched yet. Resetting plot.")
        elif len(language_names) == 0:
            logger.info("No languages selected. Resetting plot.")
        else:
            logger.info("No models selected. Resetting plot.")
        return go.Figure()

    logger.info(
        f"Producing radial plot for models {model_ids!r} on languages "
        f"{language_names!r}..."
    )

    languages = [ALL_LANGUAGES[language_name] for language_name in language_names]

    results_dfs_filtered = {
        language: df
        for language, df in results_dfs.items()
        if language.name in language_names
    }

    # Only keep the tasks that have results for every selected language.
    tasks = [
        task
        for task in ALL_TASKS
        if all(task in df.columns for df in results_dfs_filtered.values())
    ]

    # For each model and task, compute either the mean win ratio or the mean raw score
    # across the selected languages. The win ratio for a language is the percentage of
    # benchmarked models whose score the model matches or beats on that task.
    results: list[list[float]] = list()
    for model_id in model_ids:
        result_list = list()
        for task in tasks:
            win_ratios = list()
            scores = list()
            for language in languages:
                if model_id not in results_dfs_filtered[language].index:
                    continue
                score = results_dfs_filtered[language].loc[model_id][task]
                win_ratio = 100 * np.mean([
                    score >= other_score
                    for other_score in results_dfs_filtered[language][task].dropna()
                ])
                win_ratios.append(win_ratio)
                scores.append(score)
            if use_win_ratio:
                result_list.append(np.mean(win_ratios))
            else:
                result_list.append(np.mean(scores))
        results.append(result_list)

    fig = go.Figure()
    for model_id, result_list in zip(model_ids, results):

        # Seed the random number generator with the model id to get a deterministic,
        # model-specific trace colour.
        random.seed(model_id)
        r, g, b = tuple(random.randint(0, 255) for _ in range(3))

        fig.add_trace(go.Scatterpolar(
            r=result_list,
            theta=[task.name for task in tasks],
            fill='toself',
            name=model_id,
            line=dict(color=f'rgb({r}, {g}, {b})'),
        ))

    # Build a human-readable enumeration of the selected languages, e.g.
    # "Danish, Swedish and Norwegian".
    languages_str = ""
    if len(languages) > 1:
        languages_str = ", ".join([language.name for language in languages[:-1]])
        languages_str += " and "
    languages_str += languages[-1].name

    if use_win_ratio:
        title = f'Win Ratio on {languages_str} Language Tasks'
    else:
        title = f'LLM Score on {languages_str} Language Tasks'

    fig.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=[0, 100])),
        showlegend=True,
        title=title,
        width=800,
    )

    logger.info("Successfully produced radial plot.")

    return fig

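# A small worked example of the win ratio used above: suppose five models have been
# benchmarked on a task for some language, with scores 40, 55, 55, 60 and 70, and the
# selected model is one of those scoring 55. It matches or beats three of the five
# scores (40, 55 and 55, including its own), so its win ratio for that language is
# 100 * 3 / 5 = 60. With several languages selected, these per-language win ratios are
# then averaged.
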
def fetch_results() -> dict[Language, pd.DataFrame]:
    """Fetch the results from the ScandEval benchmark.

    Returns:
        A dictionary of languages -> results-dataframes, whose indices are the
        models and columns are the tasks.
    """
    logger.info("Fetching results from ScandEval benchmark...")

    response = requests.get(
        "https://www.scandeval.com/scandeval_benchmark_results.jsonl"
    )
    response.raise_for_status()

    # The results are stored as JSONL, with one benchmark record per line.
    records = [
        json.loads(dct_str)
        for dct_str in response.text.split("\n")
        if dct_str.strip()
    ]

    results_dfs = dict()
    for language in {dataset.language for dataset in DATASETS}:
        possible_dataset_names = {
            dataset.name for dataset in DATASETS if dataset.language == language
        }

        # Collect, for every model, a list of scores per task. A task can have several
        # datasets for a language, in which case the scores are averaged below.
        data_dict = defaultdict(dict)
        for record in records:
            model_name = record["model"]
            dataset_name = record["dataset"]
            if dataset_name in possible_dataset_names:
                dataset = next(
                    dataset for dataset in DATASETS if dataset.name == dataset_name
                )
                results_dict = record['results']['total']
                score = results_dict.get(
                    f"test_{dataset.task.metric}", results_dict.get(dataset.task.metric)
                )
                if dataset.task in data_dict[model_name]:
                    data_dict[model_name][dataset.task].append(score)
                else:
                    data_dict[model_name][dataset.task] = [score]

        # Average the scores per task and drop models that are missing a task. The
        # `x == x` comparison is the usual NaN check, since NaN compares unequal to
        # itself while a list of scores does not.
        results_df = pd.DataFrame(data_dict).T.map(
            lambda list_or_nan:
            np.mean(list_or_nan) if list_or_nan == list_or_nan else list_or_nan
        ).dropna()
        results_dfs[language] = results_df

    logger.info("Successfully fetched results from ScandEval benchmark.")

    return results_dfs

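# A results dataframe returned by `fetch_results` looks roughly like this (the numbers
# are illustrative, not real benchmark results):
#
#                                  knowledge  reasoning  summarisation  ...
# gpt-4-0613                            72.1       80.4           68.9  ...
# mistralai/Mistral-7B-v0.1             55.3       61.0           63.2  ...
#
# with `Task` objects as columns and every cell the mean score over that task's
# datasets for the language. Note also that `DataFrame.map` requires pandas 2.1 or
# newer; on older pandas versions the equivalent elementwise call is `applymap`.
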
if __name__ == "__main__":
    main()