"""Gradio dashboard for the Open CoT Leaderboard.

Shows, per model and task, baseline accuracies and marginal accuracy gains of
chain-of-thought (CoT) regimes as a faceted scatter plot, plus a styled table
with the per-regime details of one selected model.
"""

import time

import gradio as gr  # type: ignore
import plotly.express as px  # type: ignore

from backend.data import load_cot_data
from backend.envs import API, REPO_ID, TOKEN

logo1_url = "https://raw.githubusercontent.com/logikon-ai/cot-eval/main/assets/AI2_Logo_Square.png"
logo2_url = "https://raw.githubusercontent.com/logikon-ai/cot-eval/main/assets/logo_logikon_notext_withborder.png"

# NOTE(review): the copy of this file under review had its HTML markup stripped,
# leaving unterminated single-quoted strings (a SyntaxError). The markup below is
# a reconstruction that keeps the surviving visible text ("AI2", "Logikon AI",
# "Open CoT Dashboard"); verify it against the upstream version of the Space.
LOGOS = (
    '<div style="display: flex; justify-content: center;">'
    f'<img src="{logo1_url}" alt="AI2" '
    'style="width: 30vw; min-width: 20px; max-width: 60px;">'
    f'<img src="{logo2_url}" alt="Logikon AI" '
    'style="width: 30vw; min-width: 20px; max-width: 60px; margin-left: 10px;">'
    "</div>"
)
TITLE = f'<h1 align="center" id="space-title">Open CoT Dashboard</h1>\n{LOGOS}'

INTRODUCTION_TEXT = """
Baseline accuracies and marginal accuracy gains for specific models and CoT regimes from the [Open CoT Leaderboard](https://huggingface.co/spaces/logikon/open_cot_leaderboard).
"""

# Model preselected in the dropdown (falls back to the first available model
# if it is missing from the loaded data).
DEFAULT_MODEL = "allenai/tulu-2-70b"


def restart_space():
    """Ask the Hub API to restart this Space (REPO_ID) using TOKEN."""
    API.restart_space(repo_id=REPO_ID, token=TOKEN)


try:
    df_cot_err, df_cot_regimes = load_cot_data()
except Exception as err:
    print(err)
    # sleep for 10 seconds before restarting the space
    time.sleep(10)
    restart_space()
    # The UI below cannot be built without the data; re-raise here rather than
    # failing later with a confusing NameError on df_cot_err / df_cot_regimes.
    raise


def _model_from_request(request, models, model_id):
    """Return the `model` URL query parameter if it names a known model.

    Args:
        request: the Gradio request (may be None).
        models: pandas Series of known model ids.
        model_id: fallback model id.

    Returns:
        The query-parameter model if present and contained in ``models``,
        otherwise ``model_id``.
    """
    if request and "model" in request.query_params:
        model_param = request.query_params["model"]
        # .to_list(): `x in series` would test the index, not the values.
        if model_param in models.to_list():
            return model_param
    return model_id


def plot_evals_init(model_id, regex_model_filter, plotly_mode, request: gr.Request):
    """Initial plot on page load; honours a `?model=...` query parameter."""
    model_id = _model_from_request(request, df_cot_err.model, model_id)
    return plot_evals(model_id, regex_model_filter, plotly_mode)


def plot_evals(model_id, regex_model_filter, plotly_mode):
    """Build the faceted scatter plot of base accuracy vs. marginal accuracy gain.

    Args:
        model_id: model to highlight (always shown, even if filtered out).
        regex_model_filter: regex restricting which models are plotted; an
            invalid pattern shows a warning and falls back to all models.
        plotly_mode: "dark" selects the plotly_dark template, anything else
            the default plotly template.

    Returns:
        Tuple of (plotly figure, model_id).
    """
    df = df_cot_err.copy()
    df["selected"] = df_cot_err.model.apply(lambda x: "selected" if x == model_id else "-")
    try:
        df_filter = df.model.str.contains(regex_model_filter, regex=True, na=False)
    except Exception as err:
        gr.Warning("Failed to apply regex filter", duration=4)
        # Format the exception; `"..." + err` raised TypeError inside this handler.
        print(f"Failed to apply regex filter: {err}")
        df_filter = df.model.str.contains(".*", na=False)
    # Keep the selected model visible regardless of the filter.
    df = df[df_filter | df.selected.eq("selected")]
    # NOTE: sorting by ["selected", "model"] currently has no effect on px.scatter.
    template = "plotly_dark" if plotly_mode == "dark" else "plotly"
    fig = px.scatter(
        df,
        x="base accuracy",
        y="marginal acc. gain",
        color="selected",
        symbol="model",
        facet_col="task",
        facet_col_wrap=3,
        category_orders={"selected": ["selected", "-"]},
        color_discrete_sequence=["Orange", "Gray"],
        template=template,
        error_y="acc_gain-err",
        hover_data=["model", "cot accuracy"],
        width=1200,
        height=700,
    )
    fig.update_layout(
        title={"automargin": True},
    )
    return fig, model_id


def styled_model_table_init(model_id, request: gr.Request):
    """Initial table on page load; honours a `?model=...` query parameter."""
    model_id = _model_from_request(request, df_cot_regimes.model, model_id)
    return styled_model_table(model_id)


def styled_model_table(model_id):
    """Return a pandas Styler with the CoT-regime results of ``model_id``."""

    def make_pretty(styler):
        # Styler methods mutate the styler in place (and return it).
        styler.hide(axis="index")
        styler.format(precision=1)
        styler.background_gradient(
            axis=None, subset=["acc_base", "acc_cot"], vmin=20, vmax=100, cmap="YlGnBu"
        )
        # Diverging colormap centred on zero gain.
        styler.background_gradient(
            axis=None, subset=["acc_gain"], vmin=-20, vmax=20, cmap="coolwarm"
        )
        styler.set_table_styles(
            {"task": [{"selector": "", "props": [("font-weight", "bold")]}]},
            overwrite=False,
        )
        return styler

    df_cot_model = df_cot_regimes[df_cot_regimes.model.eq(model_id)][
        ["task", "cot_chain", "best_of", "temperature", "top_k", "top_p", "acc_base", "acc_cot", "acc_gain"]
    ]
    df_cot_model = (
        df_cot_model
        .rename(columns={"temperature": "temp"})
        .replace({"cot_chain": "ReflectBeforeRun"}, "Reflect")
        .sort_values(["task", "cot_chain"])
        .reset_index(drop=True)
    )
    return df_cot_model.style.pipe(make_pretty)


demo = gr.Blocks()

with demo:
    gr.HTML(TITLE)
    gr.Markdown(INTRODUCTION_TEXT)

    model_choices = list(df_cot_err.model.unique())
    # Avoid preselecting a value that is not among the dropdown choices.
    if DEFAULT_MODEL in model_choices or not model_choices:
        default_model = DEFAULT_MODEL
    else:
        default_model = model_choices[0]

    with gr.Row():
        selected_model = gr.Dropdown(
            model_choices, value=default_model, label="Model",
            info="with performance details below", scale=2,
        )
        regex_model_filter = gr.Textbox(".*", label="Regex", info="to filter models shown in plots", scale=2)
        plotly_mode = gr.Radio(["dark", "light"], value="light", label="Theme", info="of plots", scale=1)
        submit = gr.Button("Update", scale=1)

    table = gr.DataFrame()
    plot = gr.Plot(label="evals")

    submit.click(plot_evals, [selected_model, regex_model_filter, plotly_mode], [plot, selected_model])
    submit.click(styled_model_table, selected_model, table)

    demo.load(plot_evals_init, [selected_model, regex_model_filter, plotly_mode], [plot, selected_model])
    demo.load(styled_model_table_init, selected_model, table)

demo.launch()