Jellyfish042 committed on
Commit
6d540bf
·
1 Parent(s): f276a79
Files changed (1) hide show
  1. app.py +50 -0
app.py CHANGED
@@ -8,6 +8,9 @@ from huggingface_hub.utils._errors import EntryNotFoundError, RepositoryNotFound
8
  from dotenv import load_dotenv
9
  from matplotlib.colors import LinearSegmentedColormap
10
  import plotly.express as px
 
 
 
11
 
12
  load_dotenv()
13
  webhook_url = os.environ.get("WEBHOOK_URL")
@@ -271,6 +274,29 @@ for folder in get_folders_matching_format('data'):
271
  pd.read_excel(final_file_name + '.xlsx', sheet_name=sheet_name))
272
 
273
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  def create_scaling_plot(all_data, period):
275
  selected_columns = ['Name', 'Parameters Count (B)', 'Average (The lower the better)']
276
  target_data = all_data[period]
@@ -284,12 +310,36 @@ def create_scaling_plot(all_data, period):
284
  'Average (The lower the better)': 'Compression Rate (%)'
285
  }, inplace=True)
286
 
 
287
  fig = px.scatter(new_df,
288
  x='Params(B)',
289
  y='Compression Rate (%)',
290
  title='Compression Rate Scaling Law',
291
  hover_name='Name'
292
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  fig.update_traces(marker=dict(size=12))
294
  return fig
295
 
 
8
  from dotenv import load_dotenv
9
  from matplotlib.colors import LinearSegmentedColormap
10
  import plotly.express as px
11
+ import plotly.graph_objects as go
12
+ from sklearn.linear_model import LinearRegression
13
+ import numpy as np
14
 
15
  load_dotenv()
16
  webhook_url = os.environ.get("WEBHOOK_URL")
 
274
  pd.read_excel(final_file_name + '.xlsx', sheet_name=sheet_name))
275
 
276
 
277
+ # def create_scaling_plot(all_data, period):
278
+ # selected_columns = ['Name', 'Parameters Count (B)', 'Average (The lower the better)']
279
+ # target_data = all_data[period]
280
+ # new_df = pd.DataFrame()
281
+ #
282
+ # for size in target_data.keys():
283
+ # new_df = pd.concat([new_df, target_data[size]['cr'].loc[:, selected_columns]], axis=0)
284
+ #
285
+ # new_df.rename(columns={
286
+ # 'Parameters Count (B)': 'Params(B)',
287
+ # 'Average (The lower the better)': 'Compression Rate (%)'
288
+ # }, inplace=True)
289
+ #
290
+ # fig = px.scatter(new_df,
291
+ # x='Params(B)',
292
+ # y='Compression Rate (%)',
293
+ # title='Compression Rate Scaling Law',
294
+ # hover_name='Name'
295
+ # )
296
+ # fig.update_traces(marker=dict(size=12))
297
+ # return fig
298
+
299
+
300
  def create_scaling_plot(all_data, period):
301
  selected_columns = ['Name', 'Parameters Count (B)', 'Average (The lower the better)']
302
  target_data = all_data[period]
 
310
  'Average (The lower the better)': 'Compression Rate (%)'
311
  }, inplace=True)
312
 
313
+ # Create scatter plot
314
  fig = px.scatter(new_df,
315
  x='Params(B)',
316
  y='Compression Rate (%)',
317
  title='Compression Rate Scaling Law',
318
  hover_name='Name'
319
  )
320
+
321
+ # Add logarithmic trendline
322
+ X = new_df[['Params(B)']].values
323
+ y = new_df['Compression Rate (%)'].values
324
+
325
+ # Perform log transformation on X
326
+ X_log = np.log(X)
327
+
328
+ model = LinearRegression()
329
+ model.fit(X_log, y)
330
+
331
+ # Create trendline data for plot
332
+ X_plot = np.linspace(X_log.min() - 1, X_log.max() + 0.1, 100)
333
+ y_plot = model.predict(X_plot.reshape(-1, 1))
334
+
335
+ fig.add_trace(go.Scatter(
336
+ x=np.exp(X_plot),
337
+ y=y_plot,
338
+ mode='lines',
339
+ name='Trend',
340
+ line=dict(color='#39C5BB')
341
+ ))
342
+
343
  fig.update_traces(marker=dict(size=12))
344
  return fig
345