Spaces:
Sleeping
Sleeping
import gradio as gr | |
from gradio_leaderboard import Leaderboard, ColumnFilter, SelectColumns | |
import pandas as pd | |
from huggingface_hub import HfApi, create_repo | |
from datasets import Dataset, load_dataset | |
import os | |
import html | |
from src.about import ( | |
CITATION_BUTTON_LABEL, | |
CITATION_BUTTON_TEXT, | |
EVALUATION_QUEUE_TEXT, | |
INTRODUCTION_TEXT, | |
LLM_BENCHMARKS_TEXT, | |
TITLE, | |
) | |
from src.display.css_html_js import custom_css | |
# Hub access configuration. The token is read from the environment and is
# mandatory: the module refuses to import without it.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None or HF_TOKEN == "":
    raise ValueError("HF_TOKEN environment variable not found")

# Name of the dataset repository that stores the leaderboard rows.
DATASET_NAME = "airletters-leaderboard-results"

# Authenticated Hub client built from the token above.
api = HfApi(token=HF_TOKEN)
# Seed leaderboard rows, one tuple per entry: (name, top-1 accuracy, model type).
_SEED_ROWS = [
    ("ViT-B/16", 7.49, "Image"),
    ("MaxViT-T", 7.56, "Image"),
    ("ResNet-200", 11.44, "Image"),
    ("ResNeXt-101", 13.09, "Image"),
    ("SE-ResNeXt-26", 13.29, "Image"),
    ("ResNet-50", 13.87, "Image"),
    ("VideoMAE (16)", 57.96, "Video"),
    ("ResNet-101 + LSTM", 58.45, "Video"),
    ("ResNet-50 + LSTM", 63.24, "Video"),
    ("ResNext-152 3D", 65.77, "Video"),
    ("Strided Inflated EfficientNet 3D", 65.97, "Video"),
    ("ResNext-50 3D", 66.54, "Video"),
    ("ResNext-101 3D", 69.74, "Video"),
    ("ResNext-200 3D", 71.20, "Video"),
    ("Video-LLaVA (w/o contrast class)", 2.53, "Vision Language"),
    ("VideoLLaMA2 (w/o contrast class)", 2.47, "Vision Language"),
    ("Video-LLaVA", 7.29, "Vision Language"),
    ("VideoLLaMA2", 7.58, "Vision Language"),
    ("Human Performance (10 videos/class)", 96.67, "Human Evaluation"),
]

# Column-oriented view of the seed rows. Every seed entry shares the same
# URL and organization. Key order is kept as-is because it becomes the
# column order of the DataFrame built from this dict.
INITIAL_DATA = {
    "name": [entry[0] for entry in _SEED_ROWS],
    "url": ["https://arxiv.org/abs/2410.02921"] * len(_SEED_ROWS),
    "top1_accuracy": [entry[1] for entry in _SEED_ROWS],
    "organization": ["AirLetters Authors"] * len(_SEED_ROWS),
    "model_type": [entry[2] for entry in _SEED_ROWS],
}
def initialize_dataset():
    """Load the leaderboard dataset from the Hub, seeding it if loading fails.

    The dataset is force-redownloaded. If it still carries a ``model_url``
    column, each ``name`` is rewritten as an HTML link to that URL, the
    column is dropped, and the result is pushed back to the Hub.

    If anything in that path raises, a fresh dataset is built from
    ``INITIAL_DATA``, the repository is created (best effort), and the seed
    data is pushed.

    Returns:
        The loaded, migrated, or freshly seeded ``Dataset``.
    """
    try:
        dataset = load_dataset(
            f"rishitdagli/{DATASET_NAME}",
            split="train",
            token=HF_TOKEN,
            download_mode="force_redownload",
        )
        df = dataset.to_pandas()
        if 'model_url' in df.columns:
            # ``df`` is a pandas DataFrame, and ``DataFrame.map`` works on
            # single cells, not rows, so the previous row-lambda always
            # raised and sent a healthy dataset into the re-seeding branch
            # below. Build the linked names row by row instead.
            # NOTE(review): other readers in this file build the link at
            # display time from a ``url`` column — confirm this legacy
            # migration is still wanted.
            df["name"] = [
                model_hyperlink(url, name)
                for url, name in zip(df["model_url"], df["name"])
            ]
            df = df.drop('model_url', axis=1)
            dataset = Dataset.from_pandas(df)
            dataset.push_to_hub(DATASET_NAME, token=HF_TOKEN)
    except Exception as e:
        # NOTE(review): this fallback also fires on transient failures
        # (network, auth) and then pushes the seed data — presumably
        # replacing whatever the Hub holds. Confirm that is acceptable.
        print(f"Creating new dataset due to: {str(e)}")
        df = pd.DataFrame(INITIAL_DATA)
        dataset = Dataset.from_pandas(df)
        try:
            create_repo(DATASET_NAME, repo_type="dataset", token=HF_TOKEN)
        except Exception as e:
            # Deliberate best effort: creation fails when the repo exists.
            print(f"Repo might already exist: {str(e)}")
        dataset.push_to_hub(DATASET_NAME, token=HF_TOKEN)
    return dataset
def calculate_accuracy(test_file, submitted_file):
    """Return the top-1 accuracy (in percent) of a submission CSV.

    Both CSVs must have ``filename`` and ``label`` columns; surrounding
    whitespace in header names is ignored. Every ground-truth row counts
    towards the denominator, and a file missing from the submission counts
    as a wrong prediction.

    Args:
        test_file: Path or file-like object with the ground-truth CSV.
        submitted_file: Path or file-like object with the predictions CSV.

    Returns:
        Accuracy as a float in ``[0, 100]``; ``0.0`` when the ground truth
        has no rows.

    Raises:
        ValueError: If either CSV lacks a ``filename`` or ``label`` column.
    """
    truth = pd.read_csv(test_file)
    predictions = pd.read_csv(submitted_file)
    truth.columns = truth.columns.str.strip()
    predictions.columns = predictions.columns.str.strip()

    # Fail with a readable message instead of an opaque KeyError later on.
    required = ("filename", "label")
    for role, frame in (("ground-truth", truth), ("submitted", predictions)):
        missing = [column for column in required if column not in frame.columns]
        if missing:
            raise ValueError(
                f"The {role} CSV is missing required column(s): {', '.join(missing)}"
            )

    # Avoid a 0/0 division (NaN on the leaderboard) for an empty test set.
    if truth.empty:
        return 0.0

    # A filename listed several times in the submission would multiply rows
    # in the left merge and distort the denominator; only the first
    # prediction per file is scored.
    predictions = predictions[["filename", "label"]].drop_duplicates(
        subset="filename", keep="first"
    )

    merged = pd.merge(
        truth, predictions, on="filename", how="left", suffixes=("_true", "_pred")
    )
    # Files without a prediction get an empty label, which matches nothing.
    merged["label_pred"] = merged["label_pred"].fillna("")
    correct_predictions = (merged["label_true"] == merged["label_pred"]).sum()
    return float(correct_predictions / len(merged) * 100)
def model_hyperlink(link, model_name):
    """Return an HTML anchor that opens ``link`` in a new tab.

    Both values can come from the public submission form, so they are
    HTML-escaped before interpolation. This stops a name or URL from
    injecting markup or breaking out of the ``href`` attribute.

    Args:
        link: Target URL for the ``href`` attribute.
        model_name: Text shown as the link label.

    Returns:
        The anchor element as a string.
    """
    safe_link = html.escape(str(link))
    safe_name = html.escape(str(model_name))
    return f'<a target="_blank" href="{safe_link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{safe_name}</a>'
def update_leaderboard(name, organization, model_type, model_url, csv_file):
    """Score a submission and append it to the leaderboard dataset.

    The uploaded CSV is scored against ``test.csv``. The new row is added to
    the current Hub dataset, the rows are sorted by descending accuracy, and
    the result is pushed back.

    Args:
        name: Model name entered in the form.
        organization: Submitting organization.
        model_type: Category chosen in the form.
        model_url: Link stored in the ``url`` column.
        csv_file: Path of the uploaded predictions CSV.

    Returns:
        A status message for the result textbox: either the success message
        or a "Submission rejected" explanation.
    """
    name = (name or "").strip()
    organization = (organization or "").strip()
    model_url = (model_url or "").strip()

    # Reject incomplete or unsafe form input before touching the dataset.
    if not name:
        return "Submission rejected: please provide a model name."
    if not csv_file:
        return "Submission rejected: please upload a results CSV."
    # The URL ends up in an href attribute, so only web links are accepted.
    if model_url and not model_url.lower().startswith(("http://", "https://")):
        return "Submission rejected: the model URL must start with http:// or https://."

    try:
        top1_acc = calculate_accuracy("test.csv", csv_file)
    except (ValueError, KeyError) as err:
        # Malformed or unparsable upload: report it instead of crashing.
        return f"Submission rejected: could not score the results CSV ({err})."

    # Force a fresh download, as the other loaders in this file do, so the
    # push below is never based on a stale copy that lacks recent rows.
    dataset = load_dataset(
        f"rishitdagli/{DATASET_NAME}",
        split="train",
        token=HF_TOKEN,
        download_mode="force_redownload",
    )
    df = dataset.to_pandas()
    new_row = {
        'name': name,
        'url': model_url,
        'organization': organization,
        'model_type': model_type,
        'top1_accuracy': float(top1_acc),
    }
    df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True)
    df = df.sort_values('top1_accuracy', ascending=False)
    df = df.reset_index(drop=True)
    new_dataset = Dataset.from_pandas(df)
    new_dataset.push_to_hub(DATASET_NAME, token=HF_TOKEN)
    return "Successfully updated leaderboard"
def _column_datatypes(dataframe):
    """Return one display datatype per column of ``dataframe``, in order."""
    datatypes = []
    for column in dataframe.columns:
        if column == "name":
            # The name column holds HTML links and must render as markdown.
            datatypes.append("markdown")
        elif pd.api.types.is_numeric_dtype(dataframe[column]):
            datatypes.append("number")
        else:
            datatypes.append("str")
    return datatypes


def init_leaderboard(dataframe):
    """Build the ``Leaderboard`` component for ``dataframe``.

    Datatypes are derived from the frame itself. The previous hard-coded
    four-entry list did not line up with the five stored columns (name, url,
    top1_accuracy, organization, model_type), so they were mistyped.

    Args:
        dataframe: Leaderboard rows; expected to contain the ``name``,
            ``organization``, ``model_type`` and ``top1_accuracy`` columns
            referenced below.

    Returns:
        The configured ``Leaderboard`` component.
    """
    return Leaderboard(
        value=dataframe,
        datatype=_column_datatypes(dataframe),
        select_columns=SelectColumns(
            default_selection=["name", "organization", "model_type", "top1_accuracy"],
            cant_deselect=["name", "top1_accuracy"],
            label="Select Columns to Display",
        ),
        search_columns=["name", "organization"],
        filter_columns=[
            ColumnFilter("model_type", type="checkboxgroup", label="Model Type"),
        ],
    )
def refresh():
    """Re-download the leaderboard dataset and return it as a DataFrame.

    Each row's ``name`` is replaced by an HTML link built from the row's
    ``url`` and ``name`` before the conversion to pandas.
    """

    def _linked_name(row):
        return {"name": model_hyperlink(row["url"], row["name"])}

    latest = load_dataset(
        f"rishitdagli/{DATASET_NAME}",
        split="train",
        token=HF_TOKEN,
        download_mode="force_redownload",
    )
    return latest.map(_linked_name).to_pandas()
def create_interface():
    """Assemble the Gradio app and return the ``Blocks`` object.

    The page has a title, a looping video, three tabs (leaderboard, about,
    submit) and a collapsible citation box. The leaderboard tab loads the
    dataset while the UI is being built.
    """
    # NOTE(review): the leading "π" in the labels below looks like a
    # mis-encoded emoji — confirm against the deployed app before changing.
    with gr.Blocks(css=custom_css) as demo:
        gr.HTML(TITLE)
        gr.Video("30fps.mp4", autoplay=True, width=900, loop=True, include_audio=False)

        with gr.Tabs(elem_classes="tab-buttons"):
            with gr.TabItem("π Leaderboard", elem_id="leaderboard-tab"):
                # Names are turned into links only for display here.
                linked = initialize_dataset().map(
                    lambda row: {"name": model_hyperlink(row["url"], row["name"])}
                )
                leaderboard = init_leaderboard(linked.to_pandas())
                gr.Button("Refresh").click(
                    refresh,
                    inputs=[],
                    outputs=[leaderboard],
                )

            with gr.TabItem("π About", elem_id="about-tab"):
                gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text")

            with gr.TabItem("π Submit", elem_id="submit-tab"):
                gr.Markdown(EVALUATION_QUEUE_TEXT, elem_classes="markdown-text")
                with gr.Column():
                    name_box = gr.Textbox(label="Model Name")
                    url_box = gr.Textbox(
                        label="Model URL",
                        placeholder="https://huggingface.co/...",
                    )
                    org_box = gr.Textbox(label="Organization")
                    type_dropdown = gr.Dropdown(
                        choices=["Image", "Video", "Vision Language", "Tracking", "Other"],
                        label="Model Type",
                    )
                    upload = gr.File(label="Results CSV", type="filepath")
                    submit_btn = gr.Button("Submit")
                    status_box = gr.Textbox(label="Result")
                    # Input order matches update_leaderboard's parameters.
                    submit_btn.click(
                        update_leaderboard,
                        inputs=[name_box, org_box, type_dropdown, url_box, upload],
                        outputs=[status_box],
                    )

        with gr.Row():
            with gr.Accordion("π Citation", open=False):
                gr.Textbox(
                    value=CITATION_BUTTON_TEXT,
                    label=CITATION_BUTTON_LABEL,
                    lines=7,
                    elem_id="citation-button",
                    show_copy_button=True,
                )
    return demo
if __name__ == "__main__":
    # Build the UI, enable request queuing, then start the server.
    demo = create_interface()
    queued_demo = demo.queue()
    queued_demo.launch()