lewtun HF staff committed on
Commit
8dec3b6
·
1 Parent(s): 160673c

Refactor task names

Browse files
Files changed (2) hide show
  1. app.py +4 -20
  2. utils.py +14 -1
app.py CHANGED
@@ -29,19 +29,6 @@ TASK_TO_ID = {
29
  # "single_column_regression": 10,
30
  }
31
 
32
- AUTOTRAIN_TASK_TO_HUB_TASK = {
33
- "binary_classification": "text-classification",
34
- "multi_class_classification": "text-classification",
35
- # "multi_label_classification": "text-classification", # Not fully supported in AutoTrain
36
- "entity_extraction": "token-classification",
37
- "extractive_question_answering": "question-answering",
38
- "translation": "translation",
39
- "summarization": "summarization",
40
- # "single_column_regression": 10,
41
- }
42
-
43
- HUB_TASK_TO_AUTOTRAIN_TASK = {v: k for k, v in AUTOTRAIN_TASK_TO_HUB_TASK.items()}
44
-
45
  ###########
46
  ### APP ###
47
  ###########
@@ -74,7 +61,7 @@ if metadata is None:
74
 
75
  with st.expander("Advanced configuration"):
76
  ## Select task
77
- selected_task = st.selectbox("Select a task", list(AUTOTRAIN_TASK_TO_HUB_TASK.values()))
78
  ### Select config
79
  configs = get_dataset_config_names(selected_dataset)
80
  selected_config = st.selectbox("Select a config", configs)
@@ -84,9 +71,7 @@ with st.expander("Advanced configuration"):
84
  if splits_resp.status_code == 200:
85
  split_names = []
86
  all_splits = splits_resp.json()
87
- print(all_splits)
88
  for split in all_splits["splits"]:
89
- print(selected_config)
90
  if split["config"] == selected_config:
91
  split_names.append(split["split"])
92
 
@@ -120,7 +105,7 @@ with st.expander("Advanced configuration"):
120
  # TODO: make it task specific
121
  col_mapping = {}
122
  with col1:
123
- if selected_task == "text-classification":
124
  st.markdown("`text` column")
125
  st.text("")
126
  st.text("")
@@ -153,11 +138,10 @@ with st.form(key="form"):
153
 
154
  if submit_button:
155
  project_id = str(uuid.uuid4())[:3]
156
- autotrain_task_name = HUB_TASK_TO_AUTOTRAIN_TASK[selected_task]
157
  payload = {
158
  "username": AUTOTRAIN_USERNAME,
159
  "proj_name": f"my-eval-project-{project_id}",
160
- "task": TASK_TO_ID[autotrain_task_name],
161
  "config": {
162
  "language": "en",
163
  "max_models": 5,
@@ -181,7 +165,7 @@ with st.form(key="form"):
181
 
182
  if project_json_resp["created"]:
183
  payload = {
184
- "split": 4,
185
  "col_mapping": col_mapping,
186
  "load_config": {"max_size_bytes": 0, "shuffle": False},
187
  }
 
29
  # "single_column_regression": 10,
30
  }
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  ###########
33
  ### APP ###
34
  ###########
 
61
 
62
  with st.expander("Advanced configuration"):
63
  ## Select task
64
+ selected_task = st.selectbox("Select a task", list(TASK_TO_ID.keys()))
65
  ### Select config
66
  configs = get_dataset_config_names(selected_dataset)
67
  selected_config = st.selectbox("Select a config", configs)
 
71
  if splits_resp.status_code == 200:
72
  split_names = []
73
  all_splits = splits_resp.json()
 
74
  for split in all_splits["splits"]:
 
75
  if split["config"] == selected_config:
76
  split_names.append(split["split"])
77
 
 
105
  # TODO: make it task specific
106
  col_mapping = {}
107
  with col1:
108
+ if selected_task in ["binary_classification", "multi_class_classification"]:
109
  st.markdown("`text` column")
110
  st.text("")
111
  st.text("")
 
138
 
139
  if submit_button:
140
  project_id = str(uuid.uuid4())[:3]
 
141
  payload = {
142
  "username": AUTOTRAIN_USERNAME,
143
  "proj_name": f"my-eval-project-{project_id}",
144
+ "task": TASK_TO_ID[selected_task],
145
  "config": {
146
  "language": "en",
147
  "max_models": 5,
 
165
 
166
  if project_json_resp["created"]:
167
  payload = {
168
+ "split": 4, # use "auto" split choice in AutoTrain
169
  "col_mapping": col_mapping,
170
  "load_config": {"max_size_bytes": 0, "shuffle": False},
171
  }
utils.py CHANGED
@@ -3,6 +3,19 @@ from typing import Dict, Union
3
  import requests
4
  from huggingface_hub import DatasetFilter, HfApi, ModelFilter
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  api = HfApi()
7
 
8
 
@@ -44,6 +57,6 @@ def get_metadata(dataset_name: str) -> Union[Dict, None]:
44
 
45
 
46
  def get_compatible_models(task, dataset_name):
47
- filt = ModelFilter(task=task, trained_dataset=dataset_name, library="transformers")
48
  compatible_models = api.list_models(filter=filt)
49
  return [model.modelId for model in compatible_models]
 
3
  import requests
4
  from huggingface_hub import DatasetFilter, HfApi, ModelFilter
5
 
6
+ AUTOTRAIN_TASK_TO_HUB_TASK = {
7
+ "binary_classification": "text-classification",
8
+ "multi_class_classification": "text-classification",
9
+ # "multi_label_classification": "text-classification", # Not fully supported in AutoTrain
10
+ "entity_extraction": "token-classification",
11
+ "extractive_question_answering": "question-answering",
12
+ "translation": "translation",
13
+ "summarization": "summarization",
14
+ # "single_column_regression": 10,
15
+ }
16
+
17
+ HUB_TASK_TO_AUTOTRAIN_TASK = {v: k for k, v in AUTOTRAIN_TASK_TO_HUB_TASK.items()}
18
+
19
  api = HfApi()
20
 
21
 
 
57
 
58
 
59
  def get_compatible_models(task, dataset_name):
60
+ filt = ModelFilter(task=AUTOTRAIN_TASK_TO_HUB_TASK[task], trained_dataset=dataset_name, library="transformers")
61
  compatible_models = api.list_models(filter=filt)
62
  return [model.modelId for model in compatible_models]