# saattrupdan's picture
# feat: Change layout, fix task order, fix colours for models, fix range
# 76e4363
# raw
# history blame
# 17.8 kB
"""Script to produce radial plots."""
from functools import partial
import plotly.graph_objects as go
import json
import numpy as np
from collections import defaultdict
import pandas as pd
from pydantic import BaseModel
import gradio as gr
import requests
import random
import logging
import datetime as dt
# Log line layout used by the root handler configured just below.
fmt = "%(asctime)s [%(levelname)s] <%(name)s> %(message)s"
logging.basicConfig(format=fmt, level=logging.INFO)
logger = logging.getLogger(name="radial_plot_generator")

# Age, in minutes, after which the UI callbacks fetch the benchmark results
# again instead of reusing the ones they were given.
UPDATE_FREQUENCY_MINUTES: int = 30
class Task(BaseModel):
    """A benchmark task and the metric used to score it.

    Attributes:
        name:
            Human-readable task name; also used as the radial plot axis label.
        metric:
            Name of the metric looked up in each result record for this task.
    """

    name: str
    metric: str

    def __hash__(self) -> int:
        """Hash on the task name alone, so tasks can serve as dict keys."""
        task_name = self.name
        return hash(task_name)
class Language(BaseModel):
    """A language covered by the benchmark.

    Attributes:
        code:
            Short language code, e.g. "da".
        name:
            Display name, e.g. "Danish".
    """

    code: str
    name: str

    def __hash__(self) -> int:
        """Hash on the language code alone, so languages can serve as dict keys."""
        language_code = self.code
        return hash(language_code)
class Dataset(BaseModel):
    """A benchmark dataset, tied to one language and one task.

    Attributes:
        name:
            Dataset name as it appears in the benchmark result records.
        language:
            The language the dataset is in.
        task:
            The task the dataset evaluates.
    """

    name: str
    language: Language
    task: Task

    def __hash__(self) -> int:
        """Hash on the dataset name alone, so datasets can serve as dict keys."""
        dataset_name = self.name
        return hash(dataset_name)
TEXT_CLASSIFICATION = Task(name="text classification", metric="mcc")
INFORMATION_EXTRACTION = Task(name="information extraction", metric="micro_f1_no_misc")
GRAMMAR = Task(name="grammar", metric="mcc")
QUESTION_ANSWERING = Task(name="question answering", metric="em")
SUMMARISATION = Task(name="summarisation", metric="bertscore")
KNOWLEDGE = Task(name="knowledge", metric="mcc")
REASONING = Task(name="reasoning", metric="mcc")

# Every task, in the order the radial plot axes are drawn. Listed explicitly
# rather than scraped from ``globals()``, so the order is visible here and no
# unrelated ``Task`` object defined at module level can slip in.
ALL_TASKS = [
    TEXT_CLASSIFICATION,
    INFORMATION_EXTRACTION,
    GRAMMAR,
    QUESTION_ANSWERING,
    SUMMARISATION,
    KNOWLEDGE,
    REASONING,
]
DANISH = Language(code="da", name="Danish")
NORWEGIAN = Language(code="no", name="Norwegian")
SWEDISH = Language(code="sv", name="Swedish")
ICELANDIC = Language(code="is", name="Icelandic")
FAROESE = Language(code="fo", name="Faroese")
GERMAN = Language(code="de", name="German")
DUTCH = Language(code="nl", name="Dutch")
ENGLISH = Language(code="en", name="English")

# Display name -> language, in the order offered in the language dropdown.
# Listed explicitly rather than scraped from ``globals()``, so no unrelated
# ``Language`` object defined at module level can slip in.
ALL_LANGUAGES = {
    language.name: language
    for language in (
        DANISH,
        NORWEGIAN,
        SWEDISH,
        ICELANDIC,
        FAROESE,
        GERMAN,
        DUTCH,
        ENGLISH,
    )
}
# Every benchmark dataset with its language and task. fetch_results uses this
# table to decide which result records belong to which language and task.
DATASETS = [
    # Text classification
    Dataset(name="swerec", language=SWEDISH, task=TEXT_CLASSIFICATION),
    Dataset(name="angry-tweets", language=DANISH, task=TEXT_CLASSIFICATION),
    Dataset(name="norec", language=NORWEGIAN, task=TEXT_CLASSIFICATION),
    Dataset(name="sb10k", language=GERMAN, task=TEXT_CLASSIFICATION),
    Dataset(name="dutch-social", language=DUTCH, task=TEXT_CLASSIFICATION),
    Dataset(name="sst5", language=ENGLISH, task=TEXT_CLASSIFICATION),
    # Information extraction
    Dataset(name="suc3", language=SWEDISH, task=INFORMATION_EXTRACTION),
    Dataset(name="dansk", language=DANISH, task=INFORMATION_EXTRACTION),
    Dataset(name="norne-nb", language=NORWEGIAN, task=INFORMATION_EXTRACTION),
    Dataset(name="norne-nn", language=NORWEGIAN, task=INFORMATION_EXTRACTION),
    Dataset(name="mim-gold-ner", language=ICELANDIC, task=INFORMATION_EXTRACTION),
    Dataset(name="fone", language=FAROESE, task=INFORMATION_EXTRACTION),
    Dataset(name="germeval", language=GERMAN, task=INFORMATION_EXTRACTION),
    Dataset(name="conll-nl", language=DUTCH, task=INFORMATION_EXTRACTION),
    Dataset(name="conll-en", language=ENGLISH, task=INFORMATION_EXTRACTION),
    # Grammar
    Dataset(name="scala-sv", language=SWEDISH, task=GRAMMAR),
    Dataset(name="scala-da", language=DANISH, task=GRAMMAR),
    Dataset(name="scala-nb", language=NORWEGIAN, task=GRAMMAR),
    Dataset(name="scala-nn", language=NORWEGIAN, task=GRAMMAR),
    Dataset(name="scala-is", language=ICELANDIC, task=GRAMMAR),
    Dataset(name="scala-fo", language=FAROESE, task=GRAMMAR),
    Dataset(name="scala-de", language=GERMAN, task=GRAMMAR),
    Dataset(name="scala-nl", language=DUTCH, task=GRAMMAR),
    Dataset(name="scala-en", language=ENGLISH, task=GRAMMAR),
    # Question answering
    Dataset(name="scandiqa-da", language=DANISH, task=QUESTION_ANSWERING),
    Dataset(name="norquad", language=NORWEGIAN, task=QUESTION_ANSWERING),
    Dataset(name="scandiqa-sv", language=SWEDISH, task=QUESTION_ANSWERING),
    Dataset(name="nqii", language=ICELANDIC, task=QUESTION_ANSWERING),
    Dataset(name="germanquad", language=GERMAN, task=QUESTION_ANSWERING),
    Dataset(name="squad", language=ENGLISH, task=QUESTION_ANSWERING),
    Dataset(name="squad-nl", language=DUTCH, task=QUESTION_ANSWERING),
    # Summarisation
    Dataset(name="nordjylland-news", language=DANISH, task=SUMMARISATION),
    Dataset(name="mlsum", language=GERMAN, task=SUMMARISATION),
    Dataset(name="rrn", language=ICELANDIC, task=SUMMARISATION),
    Dataset(name="no-sammendrag", language=NORWEGIAN, task=SUMMARISATION),
    Dataset(name="wiki-lingua-nl", language=DUTCH, task=SUMMARISATION),
    Dataset(name="swedn", language=SWEDISH, task=SUMMARISATION),
    Dataset(name="cnn-dailymail", language=ENGLISH, task=SUMMARISATION),
    # Knowledge (MMLU variants, then ARC variants)
    Dataset(name="mmlu-da", language=DANISH, task=KNOWLEDGE),
    Dataset(name="mmlu-no", language=NORWEGIAN, task=KNOWLEDGE),
    Dataset(name="mmlu-sv", language=SWEDISH, task=KNOWLEDGE),
    Dataset(name="mmlu-is", language=ICELANDIC, task=KNOWLEDGE),
    Dataset(name="mmlu-de", language=GERMAN, task=KNOWLEDGE),
    Dataset(name="mmlu-nl", language=DUTCH, task=KNOWLEDGE),
    Dataset(name="mmlu", language=ENGLISH, task=KNOWLEDGE),
    Dataset(name="arc-da", language=DANISH, task=KNOWLEDGE),
    Dataset(name="arc-no", language=NORWEGIAN, task=KNOWLEDGE),
    Dataset(name="arc-sv", language=SWEDISH, task=KNOWLEDGE),
    Dataset(name="arc-is", language=ICELANDIC, task=KNOWLEDGE),
    Dataset(name="arc-de", language=GERMAN, task=KNOWLEDGE),
    Dataset(name="arc-nl", language=DUTCH, task=KNOWLEDGE),
    Dataset(name="arc", language=ENGLISH, task=KNOWLEDGE),
    # Reasoning
    Dataset(name="hellaswag-da", language=DANISH, task=REASONING),
    Dataset(name="hellaswag-no", language=NORWEGIAN, task=REASONING),
    Dataset(name="hellaswag-sv", language=SWEDISH, task=REASONING),
    Dataset(name="hellaswag-is", language=ICELANDIC, task=REASONING),
    Dataset(name="hellaswag-de", language=GERMAN, task=REASONING),
    Dataset(name="hellaswag-nl", language=DUTCH, task=REASONING),
    Dataset(name="hellaswag", language=ENGLISH, task=REASONING),
]
def main() -> None:
    """Fetch the benchmark results and launch the radial plot Gradio app.

    The results are fetched once here and shared with every UI callback through
    ``functools.partial``. The fetch time is stored in the module-level
    ``last_fetch``, which the callbacks compare against
    ``UPDATE_FREQUENCY_MINUTES`` to decide when to fetch again.
    """
    global last_fetch
    results_dfs = fetch_results()
    last_fetch = dt.datetime.now()

    all_languages = [language.name for language in ALL_LANGUAGES.values()]

    # Sorted so the dropdown lists models in a stable order; iterating a set
    # gives an arbitrary order that can differ between runs.
    danish_models = sorted(set(results_dfs[DANISH].index))

    # Only pre-select default models that actually have Danish results, so the
    # initial dropdown value is always among its choices.
    default_model_ids = [
        model_id
        for model_id in ("gpt-4-0613", "mistralai/Mistral-7B-v0.1")
        if model_id in danish_models
    ]

    with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
        gr.Markdown("# Radial Plot Generator")
        gr.Markdown(
            "This demo allows you to generate a radial plot comparing the performance "
            "of different language models on different tasks. It is based on the "
            "generative results from the [ScandEval benchmark](https://scandeval.com)."
        )
        with gr.Column():
            with gr.Row():
                language_names_dropdown = gr.Dropdown(
                    choices=all_languages,
                    multiselect=True,
                    label="Languages",
                    value=["Danish"],
                    interactive=True,
                    scale=2,
                )
                model_ids_dropdown = gr.Dropdown(
                    choices=danish_models,
                    multiselect=True,
                    label="Models",
                    value=default_model_ids,
                    interactive=True,
                    scale=2,
                )
                use_win_ratio_checkbox = gr.Checkbox(
                    label="Compare models with win ratios (as opposed to raw scores)",
                    value=True,
                    interactive=True,
                    scale=1,
                )
            with gr.Row():
                plot = gr.Plot(
                    value=produce_radial_plot(
                        model_ids_dropdown.value,
                        language_names=language_names_dropdown.value,
                        use_win_ratio=use_win_ratio_checkbox.value,
                        results_dfs=results_dfs,
                    ),
                )
            with gr.Row():
                gr.Markdown(
                    "<center>Made with ❤️ by the <a href=\"https://alexandra.dk\">"
                    "Alexandra Institute</a>.</center>"
                )

        # Narrow the model choices whenever the selected languages change.
        language_names_dropdown.change(
            fn=partial(update_model_ids_dropdown, results_dfs=results_dfs),
            inputs=[language_names_dropdown, model_ids_dropdown],
            outputs=model_ids_dropdown,
        )

        # Update plot when anything changes. The inputs are always passed in
        # produce_radial_plot's positional order: models, languages, win ratio.
        plot_inputs = [
            model_ids_dropdown,
            language_names_dropdown,
            use_win_ratio_checkbox,
        ]
        for component in (
            language_names_dropdown,
            model_ids_dropdown,
            use_win_ratio_checkbox,
        ):
            component.change(
                fn=partial(produce_radial_plot, results_dfs=results_dfs),
                inputs=plot_inputs,
                outputs=plot,
            )

    demo.launch()
def update_model_ids_dropdown(
    language_names: list[str],
    model_ids: list[str],
    results_dfs: dict[Language, pd.DataFrame] | None,
) -> dict:
    """When the language names are updated, update the model ids dropdown.

    The dropdown choices become the models that have results for every selected
    language, sorted alphabetically. Selected models that are still valid stay
    selected; if none are, one valid model is picked at random.

    Args:
        language_names:
            The names of the languages to include in the plot.
        model_ids:
            The ids of the models currently selected in the dropdown.
        results_dfs:
            The results dataframes for each language. If the last fetch is
            older than ``UPDATE_FREQUENCY_MINUTES``, the results are fetched
            again. When this is a dict, it is updated in place, so every callback
            sharing it sees the fresh results.

    Returns:
        The Gradio update to the model ids dropdown.
    """
    global last_fetch
    minutes_since_last_fetch = (dt.datetime.now() - last_fetch).total_seconds() / 60
    if minutes_since_last_fetch > UPDATE_FREQUENCY_MINUTES:
        fresh_results_dfs = fetch_results()
        last_fetch = dt.datetime.now()
        if results_dfs is None:
            results_dfs = fresh_results_dfs
        else:
            # Only rebinding the local name would discard the fresh results
            # after this call, even though ``last_fetch`` now marks them current.
            results_dfs.update(fresh_results_dfs)

    # ``not`` also covers a ``None`` selection, which ``len`` would crash on.
    if results_dfs is None or not language_names:
        if results_dfs is None:
            logger.info("No results fetched yet. Resetting model ids dropdown.")
        else:
            logger.info("No languages selected. Resetting model ids dropdown.")
        return gr.update(choices=[], value=[])

    # Models present for every selected language, sorted for a stable order.
    model_sets = [
        set(df.index)
        for language, df in results_dfs.items()
        if language.name in language_names
    ]
    filtered_models = sorted(set.intersection(*model_sets)) if model_sets else []

    if not filtered_models:
        logger.info(
            "No valid models for the selected languages. Resetting model ids dropdown."
        )
        return gr.update(choices=[], value=[])

    valid_models = set(filtered_models)
    valid_selected_models = [
        model_id for model_id in (model_ids or []) if model_id in valid_models
    ]
    if not valid_selected_models:
        valid_selected_models = random.sample(filtered_models, k=1)
    logger.info(
        f"Updated model ids dropdown with {len(filtered_models):,} valid models for "
        f"the selected languages, with {valid_selected_models} selected."
    )
    return gr.update(choices=filtered_models, value=valid_selected_models)
def produce_radial_plot(
    model_ids: list[str],
    language_names: list[str],
    use_win_ratio: bool,
    results_dfs: dict[Language, pd.DataFrame] | None,
) -> go.Figure:
    """Produce a radial plot as a plotly figure.

    Each model becomes one filled trace. Its value on a task axis is the mean,
    over the selected languages in which the model has results, of either its
    win ratio or its raw score. The win ratio is the percentage of models whose
    score it matches or beats. Only tasks present for every selected language
    are drawn, and the radial axis is fixed to 0-100.

    Args:
        model_ids:
            The ids of the models to include in the plot.
        language_names:
            The names of the languages to include in the plot.
        use_win_ratio:
            Whether to use win ratios (as opposed to raw scores).
        results_dfs:
            The results dataframes for each language. If the last fetch is
            older than ``UPDATE_FREQUENCY_MINUTES``, the results are fetched
            again. When this is a dict, it is updated in place, so every callback
            sharing it sees the fresh results.

    Returns:
        A plotly figure; an empty figure if there are no results, languages or
        models to plot.
    """
    global last_fetch
    minutes_since_last_fetch = (dt.datetime.now() - last_fetch).total_seconds() / 60
    if minutes_since_last_fetch > UPDATE_FREQUENCY_MINUTES:
        fresh_results_dfs = fetch_results()
        last_fetch = dt.datetime.now()
        if results_dfs is None:
            results_dfs = fresh_results_dfs
        else:
            # Only rebinding the local name would discard the fresh results
            # after this call, even though ``last_fetch`` now marks them current.
            results_dfs.update(fresh_results_dfs)

    # ``not`` also covers ``None`` selections, which ``len`` would crash on.
    if results_dfs is None or not language_names or not model_ids:
        if results_dfs is None:
            logger.info("No results fetched yet. Resetting plot.")
        elif not language_names:
            logger.info("No languages selected. Resetting plot.")
        else:
            logger.info("No models selected. Resetting plot.")
        return go.Figure()

    logger.info(
        f"Producing radial plot for models {model_ids!r} on languages "
        f"{language_names!r}..."
    )

    languages = [ALL_LANGUAGES[language_name] for language_name in language_names]
    results_dfs_filtered = {
        language: df
        for language, df in results_dfs.items()
        if language.name in language_names
    }
    tasks = [
        task
        for task in ALL_TASKS
        if all(task in df.columns for df in results_dfs_filtered.values())
    ]

    # Add all the evaluation results for each model
    results: list[list[float]] = []
    for model_id in model_ids:
        result_list: list[float] = []
        for task in tasks:
            win_ratios: list[float] = []
            scores: list[float] = []
            for language in languages:
                df = results_dfs_filtered[language]
                if model_id not in df.index:
                    continue
                score = df.loc[model_id, task]
                win_ratios.append(
                    100 * np.mean([score >= other for other in df[task].dropna()])
                )
                scores.append(score)
            values = win_ratios if use_win_ratio else scores
            # A model with no results in any selected language gets NaN. This is
            # what np.mean([]) returns, but without its RuntimeWarning.
            result_list.append(float(np.mean(values)) if values else float("nan"))
        results.append(result_list)

    # Add the results to a plotly figure
    fig = go.Figure()
    for model_id, result_list in zip(model_ids, results):
        # A private RNG seeded with the model id gives each model a stable RGB
        # colour without reseeding the shared module-level ``random`` generator.
        rng = random.Random(model_id)
        r, g, b = (rng.randint(0, 255) for _ in range(3))
        fig.add_trace(go.Scatterpolar(
            r=result_list,
            theta=[task.name for task in tasks],
            fill='toself',
            name=model_id,
            line=dict(color=f'rgb({r}, {g}, {b})'),
        ))

    # "A", "A and B", "A, B and C", ...
    if len(languages) > 1:
        languages_str = (
            ", ".join(language.name for language in languages[:-1])
            + " and "
            + languages[-1].name
        )
    else:
        languages_str = languages[-1].name

    metric_label = "Win Ratio" if use_win_ratio else "LLM Score"
    title = f"{metric_label} on {languages_str} Language Tasks"

    # Builds the radial plot from the results
    fig.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=[0, 100])),
        showlegend=True,
        title=title,
        width=800,
    )
    logger.info("Successfully produced radial plot.")
    return fig
def fetch_results(timeout: float = 30.0) -> dict[Language, pd.DataFrame]:
    """Fetch the results from the ScandEval benchmark.

    Args:
        timeout:
            Seconds to wait for the benchmark server before giving up. Without a
            timeout, ``requests`` can block indefinitely and freeze the UI
            callback that triggered the fetch.

    Returns:
        A dictionary of languages -> results-dataframes, whose indices are the
        models and columns are the tasks. Each cell is the model's mean score
        over that language's datasets for the task. Models lacking a score for
        any of the language's tasks are dropped. Every language appearing in
        ``DATASETS`` gets an entry, possibly an empty dataframe.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    logger.info("Fetching results from ScandEval benchmark...")
    response = requests.get(
        "https://www.scandeval.com/scandeval_benchmark_results.jsonl",
        timeout=timeout,
    )
    response.raise_for_status()
    # ``splitlines`` + ``strip`` also skips blank lines made of "\r" (CRLF).
    records = [
        json.loads(line) for line in response.text.splitlines() if line.strip()
    ]

    # O(1) dataset lookup per record instead of scanning DATASETS each time.
    dataset_by_name = {dataset.name: dataset for dataset in DATASETS}

    # language -> model id -> task -> scores (one per dataset of that task).
    # ``dict.fromkeys`` dedupes the languages while keeping first-seen order.
    scores_by_language: dict[Language, defaultdict] = {
        language: defaultdict(dict)
        for language in dict.fromkeys(dataset.language for dataset in DATASETS)
    }
    for record in records:
        dataset = dataset_by_name.get(record["dataset"])
        if dataset is None:
            continue
        results_dict = record['results']['total']
        score = results_dict.get(
            f"test_{dataset.task.metric}", results_dict.get(dataset.task.metric)
        )
        if score is None:
            # np.mean cannot average None, so a record without the metric
            # would otherwise crash the whole fetch.
            logger.debug(
                f"Skipping record for {record['model']!r} on {dataset.name!r}: "
                f"no {dataset.task.metric!r} score."
            )
            continue
        model_scores = scores_by_language[dataset.language][record["model"]]
        model_scores.setdefault(dataset.task, []).append(score)

    # Build a dictionary of languages -> results-dataframes, whose indices are the
    # models and columns are the tasks. Cells hold score lists (or NaN where a
    # model has no score for a task) until averaged.
    results_dfs = dict()
    for language, data_dict in scores_by_language.items():
        results_dfs[language] = (
            pd.DataFrame(data_dict)
            .T.map(lambda cell: np.mean(cell) if isinstance(cell, list) else cell)
            .dropna()
        )

    logger.info("Successfully fetched results from ScandEval benchmark.")
    return results_dfs
if __name__ == "__main__":
    # Launch the app only when executed as a script, not on import.
    main()