fix metrics for LongEmbed

#124
by dwzhu - opened
Files changed (2) hide show
  1. app.py +15 -2
  2. config.yaml +1 -0
app.py CHANGED
@@ -116,8 +116,16 @@ for model in pbar:
116
  ds = ds.map(add_task)
117
  base_dict = {"Model": make_clickable_model(model, link=EXTERNAL_MODEL_TO_LINK.get(model, f"https://huggingface.co/spaces/{REPO_ID}"))}
118
  # For now only one metric per task - Could add more metrics lateron
 
 
 
 
 
 
 
 
119
  for task, metric in TASK_TO_METRIC.items():
120
- ds_dict = ds.filter(lambda x: (x["mteb_task"] == task) and (x["metric"] == metric))["test"].to_dict()
121
  ds_dict = {k: round(v, 2) for k, v in zip(ds_dict["mteb_dataset_name_with_lang"], ds_dict["score"])}
122
  EXTERNAL_MODEL_RESULTS[model][task][metric].append({**base_dict, **ds_dict})
123
 
@@ -457,6 +465,7 @@ for board, board_config in BOARDS_CONFIG.items():
457
  "data": boards_data[board]["data_tasks"][task_category],
458
  "refresh": get_refresh_function(task_category, task_category_list),
459
  "credits": credits,
 
460
  })
461
 
462
  dataframes = []
@@ -612,11 +621,15 @@ with gr.Blocks(css=css) as block:
612
  # For updating the 'language' in the URL
613
  item_tab.select(update_url_language, [current_task_language, language_per_task], [current_task_language, language_per_task], trigger_mode="always_last").then(None, [current_task_language], [], js=set_window_url_params)
614
 
 
 
 
 
615
  with gr.Row():
616
  gr.Markdown(f"""
617
  {item['description']}
618
 
619
- - **Metric:** {metric}
620
  - **Languages:** {item['language_long'] if 'language_long' in item else item['language']}
621
  {"- **Credits:** " + item['credits'] if ("credits" in item and item["credits"] is not None) else ''}
622
  """)
 
116
  ds = ds.map(add_task)
117
  base_dict = {"Model": make_clickable_model(model, link=EXTERNAL_MODEL_TO_LINK.get(model, f"https://huggingface.co/spaces/{REPO_ID}"))}
118
  # For now only one metric per task - Could add more metrics lateron
119
+
120
+ def filter_function(x, task, metric):
121
+ # This is a hack for the passkey and needle retrieval test, which reports ndcg_at_1 (i.e. accuracy), rather than the ndcg_at_10 that is commonly used for retrieval tasks.
122
+ if x['mteb_dataset_name'] in ['LEMBNeedleRetrieval', 'LEMBPasskeyRetrieval']:
123
+ return x["mteb_task"] == task and x['metric'] == 'ndcg_at_1'
124
+ else:
125
+ return x["mteb_task"] == task and x["metric"] == metric
126
+
127
  for task, metric in TASK_TO_METRIC.items():
128
+ ds_dict = ds.filter(lambda x: filter_function(x, task, metric))["test"].to_dict()
129
  ds_dict = {k: round(v, 2) for k, v in zip(ds_dict["mteb_dataset_name_with_lang"], ds_dict["score"])}
130
  EXTERNAL_MODEL_RESULTS[model][task][metric].append({**base_dict, **ds_dict})
131
 
 
465
  "data": boards_data[board]["data_tasks"][task_category],
466
  "refresh": get_refresh_function(task_category, task_category_list),
467
  "credits": credits,
468
+ "metric": board_config.get("metric", None),
469
  })
470
 
471
  dataframes = []
 
621
  # For updating the 'language' in the URL
622
  item_tab.select(update_url_language, [current_task_language, language_per_task], [current_task_language, language_per_task], trigger_mode="always_last").then(None, [current_task_language], [], js=set_window_url_params)
623
 
624
+ specific_metric = metric
625
+ if item.get("metric", None) is not None:
626
+ specific_metric = item['metric']
627
+
628
  with gr.Row():
629
  gr.Markdown(f"""
630
  {item['description']}
631
 
632
+ - **Metric:** {specific_metric}
633
  - **Languages:** {item['language_long'] if 'language_long' in item else item['language']}
634
  {"- **Credits:** " + item['credits'] if ("credits" in item and item["credits"] is not None) else ''}
635
  """)
config.yaml CHANGED
@@ -301,6 +301,7 @@ boards:
301
  icon: "πŸ“š"
302
  special_icons: null
303
  credits: "[LongEmbed](https://arxiv.org/abs/2404.12096v2)"
 
304
  tasks:
305
  Retrieval:
306
  - LEMBNarrativeQARetrieval
 
301
  icon: "πŸ“š"
302
  special_icons: null
303
  credits: "[LongEmbed](https://arxiv.org/abs/2404.12096v2)"
304
+ metric: nDCG@10 (for NarrativeQA, QMSum, SummScreenFD, WikimQA) & nDCG@1 (for passkey and needle)
305
  tasks:
306
  Retrieval:
307
  - LEMBNarrativeQARetrieval