Spaces:
Running
Running
Jellyfish042
commited on
Commit
·
848ffbd
1
Parent(s):
31d5e85
update
Browse files
app.py
CHANGED
@@ -11,6 +11,7 @@ import plotly.express as px
|
|
11 |
import plotly.graph_objects as go
|
12 |
from sklearn.linear_model import LinearRegression
|
13 |
import numpy as np
|
|
|
14 |
|
15 |
load_dotenv()
|
16 |
webhook_url = os.environ.get("WEBHOOK_URL")
|
@@ -310,37 +311,70 @@ def create_scaling_plot(all_data, period):
|
|
310 |
'Average (The lower the better)': 'Compression Rate (%)'
|
311 |
}, inplace=True)
|
312 |
|
313 |
-
|
|
|
|
|
314 |
fig = px.scatter(new_df,
|
315 |
-
x='Params(B)',
|
316 |
-
y='Compression Rate (%)',
|
317 |
title='Compression Rate Scaling Law',
|
318 |
hover_name='Name'
|
319 |
)
|
320 |
|
321 |
-
|
322 |
-
|
323 |
-
|
|
|
|
|
|
|
|
|
324 |
|
325 |
-
|
326 |
-
X_log = np.log(X)
|
327 |
|
328 |
-
|
329 |
-
|
|
|
330 |
|
331 |
-
|
332 |
-
|
333 |
-
|
|
|
|
|
|
|
334 |
|
335 |
-
|
336 |
-
x=
|
337 |
-
y=
|
338 |
mode='lines',
|
339 |
-
|
340 |
-
|
341 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
342 |
|
343 |
fig.update_traces(marker=dict(size=12))
|
|
|
344 |
return fig
|
345 |
|
346 |
|
|
|
11 |
import plotly.graph_objects as go
|
12 |
from sklearn.linear_model import LinearRegression
|
13 |
import numpy as np
|
14 |
+
from scipy.spatial import ConvexHull
|
15 |
|
16 |
load_dotenv()
|
17 |
webhook_url = os.environ.get("WEBHOOK_URL")
|
|
|
311 |
'Average (The lower the better)': 'Compression Rate (%)'
|
312 |
}, inplace=True)
|
313 |
|
314 |
+
new_df['Log Params(B)'] = np.log(new_df['Params(B)'])
|
315 |
+
new_df['Log Compression Rate (%)'] = np.log(new_df['Compression Rate (%)'])
|
316 |
+
|
317 |
fig = px.scatter(new_df,
|
318 |
+
x='Log Params(B)',
|
319 |
+
y='Log Compression Rate (%)',
|
320 |
title='Compression Rate Scaling Law',
|
321 |
hover_name='Name'
|
322 |
)
|
323 |
|
324 |
+
names_to_connect = ['Meta-Llama-3-8B',
|
325 |
+
'stablelm-3b-4e1t',
|
326 |
+
'stablelm-2-1_6b',
|
327 |
+
'TinyLlama-1.1B-intermediate-step-1431k-3T']
|
328 |
+
connection_points = new_df[new_df['Name'].isin(names_to_connect)]
|
329 |
+
|
330 |
+
new_df['Color'] = new_df['Name'].apply(lambda name: '#39C5BB' if name in names_to_connect else '#636efa')
|
331 |
|
332 |
+
fig.update_traces(marker=dict(color=new_df['Color']))
|
|
|
333 |
|
334 |
+
X = connection_points['Log Params(B)'].values.reshape(-1, 1)
|
335 |
+
y = connection_points['Log Compression Rate (%)'].values
|
336 |
+
model = LinearRegression().fit(X, y)
|
337 |
|
338 |
+
x_min = connection_points['Log Params(B)'].min()
|
339 |
+
x_max = connection_points['Log Params(B)'].max()
|
340 |
+
extended_x = np.linspace(x_min, x_max * 1.5, 100)
|
341 |
+
extended_x_original = np.exp(extended_x)
|
342 |
+
trend_line_y = model.predict(extended_x.reshape(-1, 1))
|
343 |
+
trend_line_y_original = np.exp(trend_line_y)
|
344 |
|
345 |
+
trend_line = go.Scatter(
|
346 |
+
x=extended_x,
|
347 |
+
y=trend_line_y,
|
348 |
mode='lines',
|
349 |
+
line=dict(color='skyblue', dash='dash'),
|
350 |
+
name='Trend Line',
|
351 |
+
hovertemplate='<b>Params(B):</b> %{customdata[0]:.2f}<br>' +
|
352 |
+
'<b>Compression Rate (%):</b> %{customdata[1]:.2f}<extra></extra>',
|
353 |
+
customdata=np.stack((extended_x_original, trend_line_y_original), axis=-1)
|
354 |
+
)
|
355 |
+
|
356 |
+
fig.add_trace(trend_line)
|
357 |
+
|
358 |
+
x_min = new_df['Params(B)'].min()
|
359 |
+
x_max = new_df['Params(B)'].max()
|
360 |
+
x_tick_vals = np.geomspace(x_min, x_max, num=5)
|
361 |
+
x_tick_text = [f"{val:.1f}" for val in x_tick_vals]
|
362 |
+
|
363 |
+
y_min = new_df['Compression Rate (%)'].min()
|
364 |
+
y_max = new_df['Compression Rate (%)'].max()
|
365 |
+
y_tick_vals = np.geomspace(y_min, y_max, num=5)
|
366 |
+
y_tick_text = [f"{val:.1f}" for val in y_tick_vals]
|
367 |
+
|
368 |
+
fig.update_xaxes(tickvals=np.log(x_tick_vals), ticktext=x_tick_text, title='Params(B)')
|
369 |
+
fig.update_yaxes(tickvals=np.log(y_tick_vals), ticktext=y_tick_text, title='Compression Rate (%)', autorange='reversed')
|
370 |
+
|
371 |
+
fig.update_layout(
|
372 |
+
xaxis=dict(showgrid=True, zeroline=False),
|
373 |
+
yaxis=dict(showgrid=True, zeroline=False)
|
374 |
+
)
|
375 |
|
376 |
fig.update_traces(marker=dict(size=12))
|
377 |
+
|
378 |
return fig
|
379 |
|
380 |
|