lewtun HF staff commited on
Commit
348bdab
·
1 Parent(s): ecefcd1

Add cross-linking of QA models

Browse files
Files changed (2) hide show
  1. app.py +7 -2
  2. utils.py +21 -12
app.py CHANGED
@@ -67,6 +67,7 @@ def get_supported_metrics():
67
  # in the same environment. Refactor to avoid needing to actually load
68
  # the metric.
69
  try:
 
70
  metric_func = load(metric)
71
  except Exception as e:
72
  print(e)
@@ -103,7 +104,7 @@ st.markdown(
103
  Welcome to Hugging Face's automatic model evaluator! This application allows
104
  you to evaluate 🤗 Transformers
105
  [models](https://huggingface.co/models?library=transformers&sort=downloads)
106
- across a wide variety of datasets on the Hub -- all for free! Please select
107
  the dataset and configuration below. The results of your evaluation will be
108
  displayed on the [public
109
  leaderboard](https://huggingface.co/spaces/autoevaluate/leaderboards).
@@ -345,8 +346,12 @@ with st.expander("Advanced configuration"):
345
  )
346
 
347
  with st.form(key="form"):
 
 
 
 
 
348
 
349
- compatible_models = get_compatible_models(selected_task, selected_dataset)
350
  selected_models = st.multiselect(
351
  "Select the models you wish to evaluate",
352
  compatible_models,
 
67
  # in the same environment. Refactor to avoid needing to actually load
68
  # the metric.
69
  try:
70
+ print(f"INFO -- Attempting to load metric: {metric}")
71
  metric_func = load(metric)
72
  except Exception as e:
73
  print(e)
 
104
  Welcome to Hugging Face's automatic model evaluator! This application allows
105
  you to evaluate 🤗 Transformers
106
  [models](https://huggingface.co/models?library=transformers&sort=downloads)
107
+ across a wide variety of datasets on the Hub. Please select
108
  the dataset and configuration below. The results of your evaluation will be
109
  displayed on the [public
110
  leaderboard](https://huggingface.co/spaces/autoevaluate/leaderboards).
 
346
  )
347
 
348
  with st.form(key="form"):
349
+ # Grab all models fine-tuned on SQuAD for question answering tasks
350
+ if selected_task == "extractive_question_answering":
351
+ compatible_models = get_compatible_models(selected_task, [selected_dataset, "squad", "squad_v2"])
352
+ else:
353
+ compatible_models = get_compatible_models(selected_task, [selected_dataset])
354
 
 
355
  selected_models = st.multiselect(
356
  "Select the models you wish to evaluate",
357
  compatible_models,
utils.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Dict, Union
2
 
3
  import jsonlines
4
  import requests
@@ -19,9 +19,6 @@ HUB_TASK_TO_AUTOTRAIN_TASK = {v: k for k, v in AUTOTRAIN_TASK_TO_HUB_TASK.items(
19
  LOGS_REPO = "evaluation-job-logs"
20
 
21
 
22
- api = HfApi()
23
-
24
-
25
  def get_auth_headers(token: str, prefix: str = "autonlp"):
26
  return {"Authorization": f"{prefix} {token}"}
27
 
@@ -65,14 +62,26 @@ def get_metadata(dataset_name: str) -> Union[Dict, None]:
65
  return None
66
 
67
 
68
- def get_compatible_models(task, dataset_name):
69
- # TODO: relax filter on PyTorch models once supported in AutoTrain
70
- filt = ModelFilter(
71
- task=AUTOTRAIN_TASK_TO_HUB_TASK[task],
72
- trained_dataset=dataset_name,
73
- library=["transformers", "pytorch"],
74
- )
75
- compatible_models = api.list_models(filter=filt)
 
 
 
 
 
 
 
 
 
 
 
 
76
  return sorted([model.modelId for model in compatible_models])
77
 
78
 
 
1
+ from typing import Dict, List, Union
2
 
3
  import jsonlines
4
  import requests
 
19
  LOGS_REPO = "evaluation-job-logs"
20
 
21
 
 
 
 
22
  def get_auth_headers(token: str, prefix: str = "autonlp"):
23
  return {"Authorization": f"{prefix} {token}"}
24
 
 
62
  return None
63
 
64
 
65
+ def get_compatible_models(task: str, dataset_ids: List[str]) -> List[str]:
66
+ """
67
+ Returns all model IDs that are compatible with the given task and dataset names.
68
+
69
+ Args:
70
+ task (`str`): The task to search for.
71
+ dataset_names (`List[str]`): A list of dataset names to search for.
72
+
73
+ Returns:
74
+ A list of model IDs, sorted alphabetically.
75
+ """
76
+ # TODO: relax filter on PyTorch models if TensorFlow supported in AutoTrain
77
+ compatible_models = []
78
+ for dataset_id in dataset_ids:
79
+ model_filter = ModelFilter(
80
+ task=AUTOTRAIN_TASK_TO_HUB_TASK[task],
81
+ trained_dataset=dataset_id,
82
+ library=["transformers", "pytorch"],
83
+ )
84
+ compatible_models.extend(HfApi().list_models(filter=model_filter))
85
  return sorted([model.modelId for model in compatible_models])
86
 
87