File size: 4,541 Bytes
13e8963 c28665f 13e8963 c28665f 13e8963 c28665f 13e8963 c28665f 13e8963 ca2e2c2 13e8963 c28665f 13e8963 c28665f 13e8963 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import gradio as gr # type: ignore
import plotly.express as px # type: ignore
from backend.data import load_cot_data
from backend.envs import API, REPO_ID, TOKEN
# Logo images (served from the logikon-ai/cot-eval GitHub repository) shown in
# a centred strip below the dashboard title; each logo links to its site.
logo1_url = (
    "https://raw.githubusercontent.com/logikon-ai/cot-eval/main/assets/"
    "AI2_Logo_Square.png"
)
logo2_url = (
    "https://raw.githubusercontent.com/logikon-ai/cot-eval/main/assets/"
    "logo_logikon_notext_withborder.png"
)
LOGOS = (
    '<div style="display: flex; justify-content: center;">'
    '<a href="https://allenai.org/">'
    f'<img src="{logo1_url}" alt="AI2" '
    'style="width: 30vw; min-width: 20px; max-width: 60px;">'
    "</a> "
    '<a href="https://logikon.ai">'
    f'<img src="{logo2_url}" alt="Logikon AI" '
    'style="width: 30vw; min-width: 20px; max-width: 60px; margin-left: 10px;">'
    "</a></div>"
)
# Page header: centred title followed by the logo strip.
TITLE = '<h1 align="center" id="space-title"> Open CoT Dashboard</h1> ' + LOGOS
# Markdown blurb rendered under the header.
INTRODUCTION_TEXT = """
Baseline accuracies and marginal accuracy gains for specific models and CoT regimes from the [Open CoT Leaderboard](https://huggingface.co/spaces/logikon/open_cot_leaderboard).
"""
def restart_space():
    """Request a restart of the Space ``REPO_ID`` via ``API.restart_space``.

    ``API``, ``REPO_ID`` and ``TOKEN`` are imported from ``backend.envs``.
    """
    API.restart_space(token=TOKEN, repo_id=REPO_ID)
# Load the evaluation tables once at start-up. On failure, request a Space
# restart (presumably so loading is retried in a fresh process) and then
# re-raise: without the re-raise the script would carry on and fail later
# with a NameError on df_cot_err / df_cot_regimes, hiding the real cause.
try:
    df_cot_err, df_cot_regimes = load_cot_data()
except Exception:
    restart_space()
    raise
def plot_evals_init(model_id, plotly_mode, request: gr.Request):
    """Build the initial scatter plot when the page loads.

    If the page URL carries a ``model`` query parameter naming a model that
    occurs in ``df_cot_err``, it overrides the dropdown's ``model_id``;
    otherwise ``model_id`` is used unchanged.

    Returns:
        The ``(figure, model_id)`` pair produced by ``plot_evals``.
    """
    chosen = model_id
    query = request.query_params if request else {}
    if "model" in query:
        candidate = query["model"]
        if candidate in df_cot_err.model.to_list():
            chosen = candidate
    return plot_evals(chosen, plotly_mode)
def plot_evals(model_id, plotly_mode):
    """Plot base accuracy against marginal accuracy gain, one facet per task.

    Args:
        model_id: model to highlight. Rows of ``df_cot_err`` whose ``model``
            equals it fall in the "selected" colour category; all others in "-".
        plotly_mode: ``"dark"`` selects the ``plotly_dark`` template; any other
            value selects the default ``plotly`` template.

    Returns:
        ``(fig, model_id)``: the plotly figure and the unchanged ``model_id``.
    """
    df = df_cot_err.copy()
    # Tag each row for colouring. Sorting df by ["selected", "model"] currently
    # has no effect with px.scatter, so rows keep their original order.
    df["selected"] = [
        "selected" if name == model_id else "-" for name in df_cot_err.model
    ]

    if plotly_mode == "dark":
        template = "plotly_dark"
    else:
        template = "plotly"

    fig = px.scatter(
        df,
        x="base accuracy",
        y="marginal acc. gain",
        error_y="acc_gain-err",
        color="selected",
        symbol="model",
        facet_col="task",
        facet_col_wrap=3,
        # Colours are assigned in category order: "selected" -> Orange,
        # "-" -> Gray.
        category_orders={"selected": ["selected", "-"]},
        color_discrete_sequence=["Orange", "Gray"],
        hover_data=["model", "cot accuracy"],
        template=template,
        width=1200,
        height=700,
    )
    fig.update_layout(title={"automargin": True})
    return fig, model_id
def styled_model_table_init(model_id, request: gr.Request):
    """Build the initial per-model results table when the page loads.

    If the page URL carries a ``model`` query parameter naming a model that
    occurs in ``df_cot_regimes``, it overrides the dropdown's ``model_id``;
    otherwise ``model_id`` is used unchanged.

    Returns:
        The styled table produced by ``styled_model_table``.
    """
    chosen = model_id
    query = request.query_params if request else {}
    if "model" in query:
        candidate = query["model"]
        if candidate in df_cot_regimes.model.to_list():
            chosen = candidate
    return styled_model_table(chosen)
def styled_model_table(model_id):
    """Return a styled table of all CoT-regime results for one model.

    Selects the rows of ``df_cot_regimes`` whose ``model`` equals ``model_id``
    and keeps the task, CoT-chain, decoding and accuracy columns. For display,
    ``temperature`` is renamed to ``temp``, the ``ReflectBeforeRun`` chain is
    shortened to ``Reflect``, and rows are sorted by task, then chain.

    Args:
        model_id: name of the model whose rows are shown.

    Returns:
        A pandas ``Styler`` with the index hidden, numbers shown to one
        decimal, and colour gradients on the accuracy columns.
    """
    def make_pretty(styler):
        """Apply the dashboard's table formatting to *styler* and return it."""
        styler.hide(axis="index")
        # NOTE(review): precision=1 applies to every float column, not only the
        # accuracies. For example, a top_p of 0.95 would be shown as "0.9".
        # Consider passing subset=[...] if that matters.
        styler.format(precision=1)
        # Absolute accuracies use a fixed 20..100 range, so a given value gets
        # the same colour whichever model is shown.
        styler.background_gradient(
            axis=None,
            subset=["acc_base", "acc_cot"],
            vmin=20, vmax=100, cmap="YlGnBu",
        )
        # Accuracy gain uses a diverging map over a fixed, symmetric -20..20 range.
        styler.background_gradient(
            axis=None,
            subset=["acc_gain"],
            vmin=-20, vmax=20, cmap="coolwarm",
        )
        # Bold the "task" column.
        styler.set_table_styles(
            {"task": [{"selector": "", "props": [("font-weight", "bold")]}]},
            overwrite=False,
        )
        return styler

    columns = [
        "task", "cot_chain", "best_of", "temperature", "top_k", "top_p",
        "acc_base", "acc_cot", "acc_gain",
    ]
    df_cot_model = (
        df_cot_regimes[df_cot_regimes.model.eq(model_id)][columns]
        .rename(columns={"temperature": "temp"})
        # Shorten the chain name for display (only within column "cot_chain").
        .replace({"cot_chain": "ReflectBeforeRun"}, "Reflect")
        .sort_values(["task", "cot_chain"])
        .reset_index(drop=True)
    )
    return df_cot_model.style.pipe(make_pretty)
# Dashboard layout: header, intro text, a controls row (model picker, plot
# theme, "Update" button), then the per-model table and the scatter plot.
with gr.Blocks() as demo:
    gr.HTML(TITLE)
    gr.Markdown(INTRODUCTION_TEXT)

    with gr.Row():
        model_choices = list(df_cot_err.model.unique())
        model_list = gr.Dropdown(
            model_choices,
            value="allenai/tulu-2-70b",
            label="Model",
            scale=2,
        )
        plotly_mode = gr.Radio(
            ["dark", "light"], value="light", label="Plot theme", scale=1
        )
        submit = gr.Button("Update", scale=1)

    table = gr.DataFrame()
    plot = gr.Plot(label="evals")

    # "Update" redraws both outputs for the currently chosen model.
    submit.click(
        fn=plot_evals,
        inputs=[model_list, plotly_mode],
        outputs=[plot, model_list],
    )
    submit.click(fn=styled_model_table, inputs=model_list, outputs=table)

    # On page load the *_init variants also honour a ?model=... URL parameter.
    demo.load(
        fn=plot_evals_init,
        inputs=[model_list, plotly_mode],
        outputs=[plot, model_list],
    )
    demo.load(fn=styled_model_table_init, inputs=model_list, outputs=table)

demo.launch()