Jellyfish042 commited on
Commit
848ffbd
·
1 Parent(s): 31d5e85
Files changed (1) hide show
  1. app.py +53 -19
app.py CHANGED
@@ -11,6 +11,7 @@ import plotly.express as px
11
  import plotly.graph_objects as go
12
  from sklearn.linear_model import LinearRegression
13
  import numpy as np
 
14
 
15
  load_dotenv()
16
  webhook_url = os.environ.get("WEBHOOK_URL")
@@ -310,37 +311,70 @@ def create_scaling_plot(all_data, period):
310
  'Average (The lower the better)': 'Compression Rate (%)'
311
  }, inplace=True)
312
 
313
- # Create scatter plot
 
 
314
  fig = px.scatter(new_df,
315
- x='Params(B)',
316
- y='Compression Rate (%)',
317
  title='Compression Rate Scaling Law',
318
  hover_name='Name'
319
  )
320
 
321
- # Add logarithmic trendline
322
- X = new_df[['Params(B)']].values
323
- y = new_df['Compression Rate (%)'].values
 
 
 
 
324
 
325
- # Perform log transformation on X
326
- X_log = np.log(X)
327
 
328
- model = LinearRegression()
329
- model.fit(X_log, y)
 
330
 
331
- # Create trendline data for plot
332
- X_plot = np.linspace(X_log.min() - 1, X_log.max() + 0.1, 100)
333
- y_plot = model.predict(X_plot.reshape(-1, 1))
 
 
 
334
 
335
- fig.add_trace(go.Scatter(
336
- x=np.exp(X_plot),
337
- y=y_plot,
338
  mode='lines',
339
- name='Trend',
340
- line=dict(color='#39C5BB')
341
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
 
343
  fig.update_traces(marker=dict(size=12))
 
344
  return fig
345
 
346
 
 
11
  import plotly.graph_objects as go
12
  from sklearn.linear_model import LinearRegression
13
  import numpy as np
14
+ from scipy.spatial import ConvexHull
15
 
16
  load_dotenv()
17
  webhook_url = os.environ.get("WEBHOOK_URL")
 
311
  'Average (The lower the better)': 'Compression Rate (%)'
312
  }, inplace=True)
313
 
314
+ new_df['Log Params(B)'] = np.log(new_df['Params(B)'])
315
+ new_df['Log Compression Rate (%)'] = np.log(new_df['Compression Rate (%)'])
316
+
317
  fig = px.scatter(new_df,
318
+ x='Log Params(B)',
319
+ y='Log Compression Rate (%)',
320
  title='Compression Rate Scaling Law',
321
  hover_name='Name'
322
  )
323
 
324
+ names_to_connect = ['Meta-Llama-3-8B',
325
+ 'stablelm-3b-4e1t',
326
+ 'stablelm-2-1_6b',
327
+ 'TinyLlama-1.1B-intermediate-step-1431k-3T']
328
+ connection_points = new_df[new_df['Name'].isin(names_to_connect)]
329
+
330
+ new_df['Color'] = new_df['Name'].apply(lambda name: '#39C5BB' if name in names_to_connect else '#636efa')
331
 
332
+ fig.update_traces(marker=dict(color=new_df['Color']))
 
333
 
334
+ X = connection_points['Log Params(B)'].values.reshape(-1, 1)
335
+ y = connection_points['Log Compression Rate (%)'].values
336
+ model = LinearRegression().fit(X, y)
337
 
338
+ x_min = connection_points['Log Params(B)'].min()
339
+ x_max = connection_points['Log Params(B)'].max()
340
+ extended_x = np.linspace(x_min, x_max * 1.5, 100)
341
+ extended_x_original = np.exp(extended_x)
342
+ trend_line_y = model.predict(extended_x.reshape(-1, 1))
343
+ trend_line_y_original = np.exp(trend_line_y)
344
 
345
+ trend_line = go.Scatter(
346
+ x=extended_x,
347
+ y=trend_line_y,
348
  mode='lines',
349
+ line=dict(color='skyblue', dash='dash'),
350
+ name='Trend Line',
351
+ hovertemplate='<b>Params(B):</b> %{customdata[0]:.2f}<br>' +
352
+ '<b>Compression Rate (%):</b> %{customdata[1]:.2f}<extra></extra>',
353
+ customdata=np.stack((extended_x_original, trend_line_y_original), axis=-1)
354
+ )
355
+
356
+ fig.add_trace(trend_line)
357
+
358
+ x_min = new_df['Params(B)'].min()
359
+ x_max = new_df['Params(B)'].max()
360
+ x_tick_vals = np.geomspace(x_min, x_max, num=5)
361
+ x_tick_text = [f"{val:.1f}" for val in x_tick_vals]
362
+
363
+ y_min = new_df['Compression Rate (%)'].min()
364
+ y_max = new_df['Compression Rate (%)'].max()
365
+ y_tick_vals = np.geomspace(y_min, y_max, num=5)
366
+ y_tick_text = [f"{val:.1f}" for val in y_tick_vals]
367
+
368
+ fig.update_xaxes(tickvals=np.log(x_tick_vals), ticktext=x_tick_text, title='Params(B)')
369
+ fig.update_yaxes(tickvals=np.log(y_tick_vals), ticktext=y_tick_text, title='Compression Rate (%)', autorange='reversed')
370
+
371
+ fig.update_layout(
372
+ xaxis=dict(showgrid=True, zeroline=False),
373
+ yaxis=dict(showgrid=True, zeroline=False)
374
+ )
375
 
376
  fig.update_traces(marker=dict(size=12))
377
+
378
  return fig
379
 
380