Jellyfish042 commited on
Commit
f276a79
Β·
1 Parent(s): c96965d
Files changed (2) hide show
  1. app.py +41 -3
  2. requirements.txt +2 -1
app.py CHANGED
@@ -7,6 +7,7 @@ import huggingface_hub
7
  from huggingface_hub.utils._errors import EntryNotFoundError, RepositoryNotFoundError
8
  from dotenv import load_dotenv
9
  from matplotlib.colors import LinearSegmentedColormap
 
10
 
11
  load_dotenv()
12
  webhook_url = os.environ.get("WEBHOOK_URL")
@@ -269,8 +270,34 @@ for folder in get_folders_matching_format('data'):
269
  all_data[folder_name][file_name][sheet_name] = rename_columns(
270
  pd.read_excel(final_file_name + '.xlsx', sheet_name=sheet_name))
271
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  initial_period = time_list[-1]
273
- initial_models = model_size_list[:1]
274
  initial_metric = metric_list[0]
275
  initial_columns = get_unique_column_names(all_data)
276
  initial_colors = ['Average']
@@ -300,7 +327,7 @@ with gr.Blocks(css=css) as demo:
300
  with gr.Row():
301
  with gr.Column():
302
  period_selector = gr.Dropdown(label="Period", choices=time_list, value=time_list[0])
303
- model_selector = gr.CheckboxGroup(label="Model", choices=model_size_list, value=model_size_list[0])
304
  metric_selector = gr.Dropdown(label="Metric", choices=metric_list, value=metric_list[0])
305
  with gr.Column():
306
  color_selector = gr.CheckboxGroup(label="Colored Columns",
@@ -312,7 +339,7 @@ with gr.Blocks(css=css) as demo:
312
 
313
  table = gr.Dataframe(initial_data, column_widths=[130, 60, 60, 35, 35, 40, 40, 35, 35, 35],
314
  wrap=True,
315
- height=800
316
  )
317
 
318
  period_selector.change(update_table,
@@ -333,6 +360,17 @@ with gr.Blocks(css=css) as demo:
333
 
334
  with gr.Tab("🌍 MultiLang"):
335
  gr.Markdown("## Coming soon...")
 
 
 
 
 
 
 
 
 
 
 
336
  with gr.Tab("ℹ️ About"):
337
  gr.Markdown(about_md)
338
  with gr.Tab("πŸš€ Submit"):
 
7
  from huggingface_hub.utils._errors import EntryNotFoundError, RepositoryNotFoundError
8
  from dotenv import load_dotenv
9
  from matplotlib.colors import LinearSegmentedColormap
10
+ import plotly.express as px
11
 
12
  load_dotenv()
13
  webhook_url = os.environ.get("WEBHOOK_URL")
 
270
  all_data[folder_name][file_name][sheet_name] = rename_columns(
271
  pd.read_excel(final_file_name + '.xlsx', sheet_name=sheet_name))
272
 
273
+
274
+ def create_scaling_plot(all_data, period):
275
+ selected_columns = ['Name', 'Parameters Count (B)', 'Average (The lower the better)']
276
+ target_data = all_data[period]
277
+ new_df = pd.DataFrame()
278
+
279
+ for size in target_data.keys():
280
+ new_df = pd.concat([new_df, target_data[size]['cr'].loc[:, selected_columns]], axis=0)
281
+
282
+ new_df.rename(columns={
283
+ 'Parameters Count (B)': 'Params(B)',
284
+ 'Average (The lower the better)': 'Compression Rate (%)'
285
+ }, inplace=True)
286
+
287
+ fig = px.scatter(new_df,
288
+ x='Params(B)',
289
+ y='Compression Rate (%)',
290
+ title='Compression Rate Scaling Law',
291
+ hover_name='Name'
292
+ )
293
+ fig.update_traces(marker=dict(size=12))
294
+ return fig
295
+
296
+
297
+ initial_fig = create_scaling_plot(all_data, time_list[-1])
298
+
299
  initial_period = time_list[-1]
300
+ initial_models = model_size_list
301
  initial_metric = metric_list[0]
302
  initial_columns = get_unique_column_names(all_data)
303
  initial_colors = ['Average']
 
327
  with gr.Row():
328
  with gr.Column():
329
  period_selector = gr.Dropdown(label="Period", choices=time_list, value=time_list[0])
330
+ model_selector = gr.CheckboxGroup(label="Model", choices=model_size_list, value=model_size_list)
331
  metric_selector = gr.Dropdown(label="Metric", choices=metric_list, value=metric_list[0])
332
  with gr.Column():
333
  color_selector = gr.CheckboxGroup(label="Colored Columns",
 
339
 
340
  table = gr.Dataframe(initial_data, column_widths=[130, 60, 60, 35, 35, 40, 40, 35, 35, 35],
341
  wrap=True,
342
+ height=800,
343
  )
344
 
345
  period_selector.change(update_table,
 
360
 
361
  with gr.Tab("🌍 MultiLang"):
362
  gr.Markdown("## Coming soon...")
363
+ with gr.Tab("πŸ“ˆ Scaling Law"):
364
+ period_selector_2 = gr.Dropdown(label="Period", choices=time_list, value=time_list[0])
365
+
366
+ def update_plot(period):
367
+ new_fig = create_scaling_plot(all_data, period)
368
+ return new_fig
369
+
370
+
371
+ plot = gr.Plot(initial_fig)
372
+ period_selector_2.change(update_plot, inputs=period_selector_2, outputs=plot)
373
+
374
  with gr.Tab("ℹ️ About"):
375
  gr.Markdown(about_md)
376
  with gr.Tab("πŸš€ Submit"):
requirements.txt CHANGED
@@ -1 +1,2 @@
1
- openpyxl
 
 
1
+ openpyxl
2
+ plotly