wenjiao commited on
Commit
2349769
Β·
verified Β·
1 Parent(s): 625f2d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -55
app.py CHANGED
@@ -51,6 +51,7 @@ import plotly.graph_objects as go
51
 
52
  selected_indices = []
53
  selected_values = {}
 
54
 
55
  # Start ephemeral Spaces on PRs (see config in README.md)
56
  #enable_space_ci()
@@ -59,52 +60,160 @@ precision_to_dtype = {
59
  "2bit": ["int2"],
60
  "3bit": ["int3"],
61
  "4bit": ["int4", "nf4", "fp4"],
62
- "?": ["?"]
 
 
 
63
  }
64
 
65
- current_weightDtype = ["All", "int2", "int3", "int4", "nf4", "fp4", "?"]
 
 
 
 
 
 
 
 
 
 
 
66
 
67
- # Global variable to store the selected dtypes
68
- selected_dtypes = ["All"]
69
- init_select = False
 
70
 
71
- def quant_update_Weight_Dtype(selected_precisions):
72
- global current_weightDtype
73
- if 'βœ– None' in selected_precisions:
74
- if not any(dtype in ['float16', 'bfloat16', 'float32'] for dtype in current_weightDtype):
75
- current_weightDtype += ['float16', 'bfloat16', 'float32']
76
- else:
77
- if any(dtype in ['float16', 'bfloat16', 'float32'] for dtype in current_weightDtype):
78
- current_weightDtype = [dtype for dtype in current_weightDtype if dtype not in ['float16', 'bfloat16', 'float32']]
79
- return gr.Dropdown(choices=current_weightDtype, value="All")
80
 
 
 
 
81
 
82
- def update_Weight_Dtype(selected_precisions):
83
- global selected_dtypes
 
 
 
84
  global current_weightDtype
85
- global init_select
86
- init_select = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
- if not selected_precisions: # If no precision is selected, return "All"
89
- selected_dtypes = ["All"]
90
- return gr.Dropdown(choices=["All"], value="All")
 
 
 
91
 
92
- selected_dtypes_set = set()
93
- for precision in selected_precisions:
94
- if precision in precision_to_dtype:
95
- selected_dtypes_set.update(precision_to_dtype[precision])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
- # Convert set to sorted list to maintain order
99
- selected_dtypes = sorted(selected_dtypes_set)
100
- if any(dtype in ['float16', 'bfloat16', 'float32'] for dtype in current_weightDtype) and not any(dtype in ['float16', 'bfloat16', 'float32'] for dtype in selected_dtypes):
101
- selected_dtypes += ['float16', 'bfloat16', 'float32']
102
- # Add "All" to the beginning of the list for display purposes
103
- display_choices = ["All"] + selected_dtypes
 
 
 
 
 
 
 
 
 
 
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
- current_weightDtype = display_choices
107
- return gr.Dropdown(choices=display_choices, value="All")
108
 
109
 
110
 
@@ -177,18 +286,15 @@ def update_table(
177
  ):
178
  global init_select
179
  global current_weightDtype
 
180
 
181
-
182
- if selected_dtypes == ['All']:
183
  weight_dtype = current_weightDtype
184
- elif weight_dtype == ['All'] or weight_dtype == 'All' or init_select:
185
- weight_dtype = selected_dtypes
186
- init_select = False
187
  else:
188
  weight_dtype = [weight_dtype]
189
 
190
  if compute_dtype == 'All':
191
- compute_dtype = ['bfloat16', 'float16', 'int8', 'float32']
192
  else:
193
  compute_dtype = [compute_dtype]
194
 
@@ -265,12 +371,20 @@ def filter_models(
265
  filtered_df = filtered_df[filtered_df[AutoEvalColumn.flagged.name] == False]
266
 
267
  type_emoji = [t[0] for t in type_query]
 
 
 
 
 
268
  filtered_df = filtered_df.loc[df[AutoEvalColumn.model_type_symbol.name].isin(type_emoji)]
269
  filtered_df = filtered_df.loc[df[AutoEvalColumn.precision.name].isin(precision_query + ["None"])]
270
-
271
  filtered_df = filtered_df.loc[df[AutoEvalColumn.weight_dtype.name].isin(weight_dtype)]
 
272
  filtered_df = filtered_df.loc[df[AutoEvalColumn.compute_dtype.name].isin(compute_dtype)]
 
273
  filtered_df = filtered_df.loc[df[AutoEvalColumn.double_quant.name].isin(double_quant)]
 
274
  filtered_df = filtered_df.loc[df[AutoEvalColumn.group_size.name].isin(group_dtype)]
275
 
276
  numeric_interval = pd.IntervalIndex(sorted([NUMERIC_INTERVALS[s] for s in size_query]))
@@ -310,7 +424,6 @@ def select(df, data: gr.SelectData):
310
  text_content = match.group(1)
311
  selected_values[text_content] = value
312
 
313
- print('selected_values', selected_values, selected_indices)
314
  return gr.CheckboxGroup(list(selected_values.keys()), value=list(selected_values.keys()))
315
 
316
  def init_comparison_data():
@@ -319,7 +432,6 @@ def init_comparison_data():
319
 
320
  def generate_spider_chart(df, selected_keys):
321
  global selected_values
322
- print('generate_spider_chart', selected_values, selected_keys)
323
  current_selected_values = [selected_values[key] for key in selected_keys if key in selected_values]
324
  selected_rows = df[df.iloc[:, 1].isin(current_selected_values)]
325
 
@@ -344,7 +456,7 @@ def generate_spider_chart(df, selected_keys):
344
 
345
  leaderboard_df = filter_models(
346
  df=leaderboard_df,
347
- type_query=[t.to_str(" : ") for t in QuantType],
348
  size_query=list(NUMERIC_INTERVALS.keys()),
349
  params_query=list(NUMERIC_MODELSIZE.keys()),
350
  precision_query=[i.value.name for i in Precision],
@@ -407,7 +519,7 @@ with demo:
407
  #with gr.Box(elem_id="box-filter"):
408
  filter_columns_type = gr.CheckboxGroup(
409
  label="Quantization types",
410
- choices=[t.to_str() for t in QuantType],
411
  value=[t.to_str() for t in QuantType if t != QuantType.QuantType_None],
412
  interactive=True,
413
  elem_id="filter-columns-type",
@@ -415,7 +527,7 @@ with demo:
415
  filter_columns_precision = gr.CheckboxGroup(
416
  label="Weight precision",
417
  choices=[i.value.name for i in Precision],
418
- value=[i.value.name for i in Precision],
419
  interactive=True,
420
  elem_id="filter-columns-precision",
421
  )
@@ -430,7 +542,7 @@ with demo:
430
 
431
  with gr.Row():
432
  with gr.Column():
433
- model_comparison = gr.CheckboxGroup(label="Accuracy Comparison on Selected Models", choices=list(selected_values.keys()), value=list(selected_values.keys()), interactive=True, elem_id="model_comparison")
434
  with gr.Column():
435
  spider_btn = gr.Button("Compare")
436
 
@@ -514,21 +626,30 @@ with demo:
514
  demo.load(load_query, inputs=[], outputs=[search_bar, hidden_search_bar])
515
 
516
  """
 
 
 
 
 
 
517
  filter_columns_precision.change(
518
- update_Weight_Dtype,
519
  [filter_columns_precision],
520
- [filter_columns_weightDtype]
521
  )
522
 
523
- filter_columns_type.change(
524
- quant_update_Weight_Dtype,
525
- [filter_columns_type],
526
- [filter_columns_weightDtype]
527
  )
528
-
 
 
 
 
529
 
530
 
531
-
532
 
533
  for selector in [shown_columns, filter_columns_type, filter_columns_precision, filter_columns_size, filter_columns_parameters, hide_models, filter_columns_computeDtype, filter_columns_weightDtype, filter_columns_doubleQuant, filter_columns_groupDtype]:
534
  selector.change(
@@ -680,4 +801,4 @@ scheduler.add_job(update_dynamic_files, "interval", hours=12) # launched every 2
680
  scheduler.start()
681
 
682
  demo.queue(default_concurrency_limit=40).launch()
683
- # demo.queue(concurrency_count=40).launch()
 
51
 
52
  selected_indices = []
53
  selected_values = {}
54
+ selected_dropdown_weight = 'All'
55
 
56
  # Start ephemeral Spaces on PRs (see config in README.md)
57
  #enable_space_ci()
 
60
  "2bit": ["int2"],
61
  "3bit": ["int3"],
62
  "4bit": ["int4", "nf4", "fp4"],
63
+ "8bit": ["int8"],
64
+ "16bit": ['float16', 'bfloat16'],
65
+ "32bit": ["float32"],
66
+ "?": ["?"],
67
  }
68
 
69
+ dtype_to_precision = {
70
+ "int2": ["2bit"],
71
+ "int3": ["3bit"],
72
+ "int4": ["4bit"],
73
+ "nf4": ["4bit"],
74
+ "fp4": ["4bit"],
75
+ "int8": ["8bit"],
76
+ "float16": ["16bit"],
77
+ "bfloat16": ["16bit"],
78
+ "float32": ["32bit"],
79
+ "?": ["?"],
80
+ }
81
 
82
+ current_weightDtype = ["int2", "int3", "int4", "nf4", "fp4", "?"]
83
+ current_computeDtype = ['int8', 'bfloat16', 'float16', 'float32']
84
+ current_quant = [t.to_str() for t in QuantType if t != QuantType.QuantType_None]
85
+ current_precision = ['2bit', '3bit', '4bit', '8bit', '?']
86
 
 
 
 
 
 
 
 
 
 
87
 
88
+ def display_sort(key):
89
+ order = {"All": 0, "?": 1, "int2": 2, "int3": 3, "int4": 4, "fp4": 5, "nf4": 6, "float16": 7, "bfloat16": 8, "float32": 9}
90
+ return order.get(key, float('inf'))
91
 
92
+ def comp_display_sort(key):
93
+ order = {"All": 0, "?": 1, "int8": 2, "float16": 3, "bfloat16": 4, "float32": 5}
94
+ return order.get(key, float('inf'))
95
+
96
+ def update_quantization_types(selected_quant):
97
  global current_weightDtype
98
+ global current_computeDtype
99
+ global current_quant
100
+ global current_precision
101
+
102
+ if set(current_quant) == set(selected_quant):
103
+ return [
104
+ gr.Dropdown(choices=current_weightDtype, value=selected_dropdown_weight),
105
+ gr.Dropdown(choices=current_computeDtype, value="All"),
106
+ gr.CheckboxGroup(value=current_precision),
107
+ ]
108
+
109
+ print('update_quantization_types', selected_quant, current_quant)
110
+ if any(value != 'βœ– None' for value in selected_quant):
111
+ selected_weight = ['All', '?', 'int2', 'int3', 'int4', 'nf4', 'fp4', 'int8']
112
+ selected_compute = ['All', '?', 'int8', 'float16', 'bfloat16', 'float32']
113
+ selected_precision = ["2bit", "3bit", "4bit", "8bit", "?"]
114
+
115
+ current_weightDtype = selected_weight
116
+ current_computeDtype = selected_compute
117
+ current_quant = selected_quant
118
+ current_precision = selected_precision
119
+
120
+ return [
121
+ gr.Dropdown(choices=selected_weight, value="All"),
122
+ gr.Dropdown(choices=selected_compute, value="All"),
123
+ gr.CheckboxGroup(value=selected_precision),
124
+ ]
125
 
126
+ def update_Weight_Precision(temp_precisions):
127
+ global current_weightDtype
128
+ global current_computeDtype
129
+ global current_quant
130
+ global current_precision
131
+ global selected_dropdown_weight
132
 
133
+
134
+ if set(current_precision) == set(temp_precisions):
135
+ return [
136
+ gr.Dropdown(choices=current_weightDtype, value=selected_dropdown_weight),
137
+ gr.Dropdown(choices=current_computeDtype, value="All"),
138
+ gr.CheckboxGroup(value=current_precision),
139
+ gr.CheckboxGroup(value=current_quant),
140
+ ] # No update needed
141
+ if len(temp_precisions) > 1:
142
+ selected_dropdown_weight = 'All'
143
+ precisions = [precision for precision in temp_precisions if precision not in current_precision]
144
+
145
+ selected_weight = []
146
+ selected_compute = ['All', '?', 'int8', 'float16', 'bfloat16', 'float32']
147
+ selected_quant = [t.to_str() for t in QuantType if t != QuantType.QuantType_None]
148
+
149
+ if set(temp_precisions) == {"16bit", "32bit"}:
150
+ selected_precisions = temp_precisions.copy()
151
+ elif (temp_precisions == ["16bit"] or temp_precisions == ["32bit"]) and current_precision == ["16bit", "32bit"]:
152
+ selected_precisions = temp_precisions
153
+ else:
154
+ if "16bit" in precisions or "32bit" in precisions:
155
+ selected_precisions = precisions
156
+ else:
157
+ filtered_precisions = [p for p in temp_precisions if p not in ["16bit", "32bit"]]
158
+ selected_precisions = filtered_precisions.copy()
159
+
160
+ current_precision = selected_precisions
161
 
162
+ # Map selected_precisions to corresponding weights
163
+ for precision in current_precision:
164
+ if precision in precision_to_dtype:
165
+ selected_weight.extend(precision_to_dtype[precision])
166
+
167
+ # Special rules for 16bit and 32bit
168
+ if "16bit" in current_precision:
169
+ selected_weight = [option for option in selected_weight if option in ["All", "?", "float16", "bfloat16"]]
170
+ if "int8" in selected_compute:
171
+ selected_compute.remove("int8")
172
+
173
+ if "32bit" in current_precision:
174
+ selected_weight = [option for option in selected_weight if option in ["All", "?", "float32"]]
175
+ if "int8" in selected_compute:
176
+ selected_compute.remove("int8")
177
+
178
+ if "16bit" in current_precision or "32bit" in current_precision:
179
+ selected_quant = ['βœ– None']
180
+ if "16bit" in current_precision and "32bit" in current_precision:
181
+ selected_weight = ["All", "?", "float16", "bfloat16", "float32"]
182
+ # Ensure "All" and "?" options are included
183
+ selected_weight = ["All", "?"] + [opt for opt in selected_weight if opt not in ["All", "?"]]
184
+ selected_compute = ["All", "?"] + [opt for opt in selected_compute if opt not in ["All", "?"]]
185
 
186
+ # Remove duplicates
187
+ selected_weight = list(set(selected_weight))
188
+ selected_compute = list(set(selected_compute))
189
+
190
+ # Update global variables
191
+ current_weightDtype = selected_weight
192
+ current_computeDtype = selected_compute
193
+ current_quant = selected_quant
194
+
195
+ # Return updated components
196
+ return [
197
+ gr.Dropdown(choices=selected_weight, value=selected_dropdown_weight),
198
+ gr.Dropdown(choices=selected_compute, value="All"),
199
+ gr.CheckboxGroup(value=selected_precisions),
200
+ gr.CheckboxGroup(value=selected_quant),
201
+ ]
202
 
203
+ def update_Weight_Dtype(weight):
204
+ global selected_dropdown_weight
205
+ print('update_Weight_Dtype', weight)
206
+ # Initialize selected_precisions
207
+ if weight == selected_dropdown_weight or weight == 'All':
208
+ return current_precision
209
+ else:
210
+ selected_precisions = []
211
+ selected_precisions.extend(dtype_to_precision[weight])
212
+ selected_dropdown_weight = weight
213
+ print('selected_precisions', selected_precisions)
214
+ # Return updated components
215
+ return selected_precisions
216
 
 
 
217
 
218
 
219
 
 
286
  ):
287
  global init_select
288
  global current_weightDtype
289
+ global current_computeDtype
290
 
291
+ if weight_dtype == ['All'] or weight_dtype == 'All':
 
292
  weight_dtype = current_weightDtype
 
 
 
293
  else:
294
  weight_dtype = [weight_dtype]
295
 
296
  if compute_dtype == 'All':
297
+ compute_dtype = current_computeDtype
298
  else:
299
  compute_dtype = [compute_dtype]
300
 
 
371
  filtered_df = filtered_df[filtered_df[AutoEvalColumn.flagged.name] == False]
372
 
373
  type_emoji = [t[0] for t in type_query]
374
+ if any(emoji != 'βœ–' for emoji in type_emoji):
375
+ type_emoji = [emoji for emoji in type_emoji if emoji != 'βœ–']
376
+ else:
377
+ type_emoji = ['βœ–']
378
+
379
  filtered_df = filtered_df.loc[df[AutoEvalColumn.model_type_symbol.name].isin(type_emoji)]
380
  filtered_df = filtered_df.loc[df[AutoEvalColumn.precision.name].isin(precision_query + ["None"])]
381
+
382
  filtered_df = filtered_df.loc[df[AutoEvalColumn.weight_dtype.name].isin(weight_dtype)]
383
+
384
  filtered_df = filtered_df.loc[df[AutoEvalColumn.compute_dtype.name].isin(compute_dtype)]
385
+
386
  filtered_df = filtered_df.loc[df[AutoEvalColumn.double_quant.name].isin(double_quant)]
387
+
388
  filtered_df = filtered_df.loc[df[AutoEvalColumn.group_size.name].isin(group_dtype)]
389
 
390
  numeric_interval = pd.IntervalIndex(sorted([NUMERIC_INTERVALS[s] for s in size_query]))
 
424
  text_content = match.group(1)
425
  selected_values[text_content] = value
426
 
 
427
  return gr.CheckboxGroup(list(selected_values.keys()), value=list(selected_values.keys()))
428
 
429
  def init_comparison_data():
 
432
 
433
  def generate_spider_chart(df, selected_keys):
434
  global selected_values
 
435
  current_selected_values = [selected_values[key] for key in selected_keys if key in selected_values]
436
  selected_rows = df[df.iloc[:, 1].isin(current_selected_values)]
437
 
 
456
 
457
  leaderboard_df = filter_models(
458
  df=leaderboard_df,
459
+ type_query=[t.to_str(" : ") for t in QuantType if t != QuantType.QuantType_None],
460
  size_query=list(NUMERIC_INTERVALS.keys()),
461
  params_query=list(NUMERIC_MODELSIZE.keys()),
462
  precision_query=[i.value.name for i in Precision],
 
519
  #with gr.Box(elem_id="box-filter"):
520
  filter_columns_type = gr.CheckboxGroup(
521
  label="Quantization types",
522
+ choices=[t.to_str() for t in QuantType if t != QuantType.QuantType_None],
523
  value=[t.to_str() for t in QuantType if t != QuantType.QuantType_None],
524
  interactive=True,
525
  elem_id="filter-columns-type",
 
527
  filter_columns_precision = gr.CheckboxGroup(
528
  label="Weight precision",
529
  choices=[i.value.name for i in Precision],
530
+ value=[i.value.name for i in Precision if ( i.value.name != '16bit' and i.value.name != '32bit')],
531
  interactive=True,
532
  elem_id="filter-columns-precision",
533
  )
 
542
 
543
  with gr.Row():
544
  with gr.Column():
545
+ model_comparison = gr.CheckboxGroup(label="Accuracy Comparison (Selected Models from Table)", choices=list(selected_values.keys()), value=list(selected_values.keys()), interactive=True, elem_id="model_comparison")
546
  with gr.Column():
547
  spider_btn = gr.Button("Compare")
548
 
 
626
  demo.load(load_query, inputs=[], outputs=[search_bar, hidden_search_bar])
627
 
628
  """
629
+ filter_columns_type.change(
630
+ update_quantization_types,
631
+ [filter_columns_type],
632
+ [filter_columns_weightDtype, filter_columns_computeDtype, filter_columns_precision]
633
+ )
634
+
635
  filter_columns_precision.change(
636
+ update_Weight_Precision,
637
  [filter_columns_precision],
638
+ [filter_columns_weightDtype, filter_columns_computeDtype, filter_columns_precision, filter_columns_type]
639
  )
640
 
641
+ filter_columns_weightDtype.change(
642
+ update_Weight_Dtype,
643
+ [filter_columns_weightDtype],
644
+ [filter_columns_precision]
645
  )
646
+ # filter_columns_computeDtype.change(
647
+ # Compute_Dtype_update,
648
+ # [filter_columns_computeDtype, filter_columns_precision],
649
+ # [filter_columns_precision, filter_columns_type]
650
+ # )
651
 
652
 
 
653
 
654
  for selector in [shown_columns, filter_columns_type, filter_columns_precision, filter_columns_size, filter_columns_parameters, hide_models, filter_columns_computeDtype, filter_columns_weightDtype, filter_columns_doubleQuant, filter_columns_groupDtype]:
655
  selector.change(
 
801
  scheduler.start()
802
 
803
  demo.queue(default_concurrency_limit=40).launch()
804
+ # demo.queue(concurrency_count=40).launch()