# Header text copied from the web file view along with the source:
# Shanshan Wang
# color change
# 6cdf45d
# raw
# history blame
# 5.48 kB
import pandas as pd
import plotly.express as px
import gradio as gr
# Input CSV files (relative paths, so they resolve against the working directory).
data_path = '0926-OCRBench-opensource.csv'
data_mmlm_path = 'filtered_opencompass.csv'
# Raw OCRBench table; every missing cell is replaced by 0.
data = pd.read_csv(data_path).fillna(0)

######## OCRBench ########
# Column dtypes applied to the OCRBench rows: the model name stays a string,
# the parameter count is a float and every benchmark score is an integer.
dtype_dict = {
    "Model": str,
    "Param (B)": float,
    "OCRBench": int,
    "Text Recognition": int,
    "Scene Text-centric VQA": int,
    "Document Oriented VQA": int,
    "KIE": int,
    "Handwritten Math Expression Recognition": int,
}

# Keep only the first 25 rows, enforce the dtypes above and discard the
# header-less "Unnamed: 11" column that pandas creates for it.
# NOTE(review): presumably rows past index 24 are not model results; confirm against the CSV.
data_valid = (
    data.iloc[:25]
    .copy()
    .astype(dtype_dict)
    .drop(columns=["Unnamed: 11"])
)
# Add a new column that assigns categories to Model A, Model B, and Model C, and 'Other' to the rest
def categorize_model(model):
if model in ["H2OVL-Mississippi-2B", "H2OVL-Mississippi-0.8B"]:
return "H2OVLs"
elif model.startswith("doctr"): # Third group for ocr models
return "traditional ocr models"
else:
return "Other"
# Define a color map with yellow for "H2OVLs"
color_map = {"H2OVLs": "#FFE600", "Other": "#9F9F9D", "traditional ocr models": "#54585A"}
# Apply the categorization to create a new column
data_valid["Category"] = data_valid["Model"].apply(categorize_model)
# Plotting
def plot_metric(selected_metric):
    """Build the OCRBench scatter plot of one metric against model size.

    Rows whose score for ``selected_metric`` is 0 are left out. Missing CSV
    cells were filled with 0 at load time, so a 0 most likely means
    "not reported".

    Args:
        selected_metric: Name of a numeric column of ``data_valid`` to put on the y-axis.

    Returns:
        The plotly figure built by ``px.scatter``, with one labelled marker per model.
    """
    filtered_data = data_valid[data_valid[selected_metric] != 0]

    # Colour by category (H2OVLs / traditional ocr models / Other) using color_map.
    fig = px.scatter(
        filtered_data,
        x="Param (B)",
        y=selected_metric,
        text="Model",
        color="Category",
        title=f"{selected_metric} vs Model Size",
        color_discrete_map=color_map,
    )
    # Draw each marker with its model name to the right; texttemplate '%{text}'
    # makes the label show the full model name.
    fig.update_traces(
        marker=dict(size=10),
        mode="markers+text",
        textposition="middle right",
        textfont=dict(size=10),
        texttemplate="%{text}",
    )

    # Pad the x-axis so labels next to the largest models are not clipped.
    # If no row survives the filter, max() is NaN, which would give an invalid
    # axis range, so fall back to 0.
    max_x_value = filtered_data["Param (B)"].max()
    if pd.isna(max_x_value):
        max_x_value = 0
    fig.update_layout(
        xaxis_range=[0, max_x_value + 5],
        xaxis_title="Model Size (B)",
        yaxis_title=selected_metric,
        showlegend=False,
        height=800,
        margin=dict(t=50, l=50, r=100, b=50),  # wider right margin leaves room for labels
    )
    return fig
####### OpenCompass ########
# OpenCompass results table; missing scores become 0.
data_mmlm = pd.read_csv(data_mmlm_path).fillna(0)
# Shorter label for the aggregate score column.
data_mmlm = data_mmlm.rename(
    columns={"Avg. Score (8 single image benchmarks)": "Average Score"}
)
# Every column from position 6 onward is plotted as a benchmark metric.
# NOTE(review): assumes the first six columns are descriptive, not scores; confirm against the CSV.
metrics_column = data_mmlm.columns[6:].tolist()
def plot_metric_mmlm_grouped(category):
    """Draw a grouped bar chart of every OpenCompass metric for one model category.

    Args:
        category: Value of the ``Category`` column whose models are shown.

    Returns:
        A plotly bar figure: one bar group per metric, one coloured bar per model.
    """
    in_category = data_mmlm.loc[data_mmlm["Category"] == category]

    # Long format: one row per (model, metric) pair so px.bar can group it directly.
    long_scores = in_category.melt(
        id_vars=["Models"],
        value_vars=metrics_column,
        var_name="Metrics",
        value_name="Score",
    )

    chart = px.bar(
        long_scores,
        x="Metrics",
        y="Score",
        color="Models",  # one colour per model inside each metric group
        barmode="group",
        title=f"Scores for All Metrics in {category} Category",
    )
    chart.update_layout(
        xaxis_title="Metrics",
        yaxis_title="Score",
        height=600,
        margin=dict(t=50, l=50, r=100, b=50),
    )
    return chart
# Gradio Blocks interface with two tabs
def create_interface():
    """Build the Gradio app: an OCRBench scatter-plot tab and an OpenCompass bar-chart tab.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """
    default_metric = "Text Recognition"
    # Metric choices: data_valid columns from index 5 on, minus the last one
    # (the "Category" column appended at module load).
    # NOTE(review): assumes the first five columns are not metrics; confirm against the CSV.
    metric_choices = list(data_valid.columns[5:-1])
    category_choices = data_mmlm["Category"].unique().tolist()

    with gr.Blocks() as app:
        with gr.Tabs():
            with gr.Tab("OCRBench"):
                with gr.Row():
                    # Plot takes 4/5 of the row width, the metric selector 1/5.
                    with gr.Column(scale=4):
                        ocr_plot = gr.Plot(
                            value=plot_metric(default_metric),
                            label="OCR Benchmark Metrics",
                        )
                    with gr.Column(scale=1):
                        metric_dropdown = gr.Dropdown(
                            metric_choices,
                            label="Select Metric",
                            value=default_metric,
                        )
                # Redraw the scatter plot whenever another metric is selected.
                metric_dropdown.change(fn=plot_metric, inputs=metric_dropdown, outputs=ocr_plot)

            with gr.Tab("8 Multi-modal Benchmarks"):
                with gr.Row():
                    category_dropdown = gr.Dropdown(
                        category_choices,
                        label="Select Category",
                        value=category_choices[0],
                    )
                with gr.Row():
                    mm_plot = gr.Plot(
                        value=plot_metric_mmlm_grouped(category_choices[0]),
                        label="Grouped Metrics for Models",
                    )
                # Redraw the grouped bar chart whenever another category is selected.
                category_dropdown.change(
                    fn=plot_metric_mmlm_grouped,
                    inputs=category_dropdown,
                    outputs=mm_plot,
                )
    return app
# Build and serve the app only when run as a script, not on import.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()