Add input for custom gene token dictionary and add str to valid options
#327
by hchen725 - opened

geneformer/classifier.py  +13 -3  CHANGED
@@ -86,6 +86,7 @@ class Classifier:
         "no_eval": {bool},
         "stratify_splits_col": {None, str},
         "forward_batch_size": {int},
+        "gene_token_dict_path": {None, str},
         "nproc": {int},
         "ngpu": {int},
     }
@@ -107,6 +108,7 @@ class Classifier:
         stratify_splits_col=None,
         no_eval=False,
         forward_batch_size=100,
+        gene_token_dict_path=None,
         nproc=4,
         ngpu=1,
     ):
@@ -175,6 +177,9 @@ class Classifier:
             | Otherwise, will perform eval during training.
         forward_batch_size : int
             | Batch size for forward pass (for evaluation, not training).
+        gene_token_dict_path : None, str
+            | Default is to use token dictionary file from Geneformer
+            | Otherwise, will load custom gene token dictionary.
         nproc : int
             | Number of CPU processes to use.
         ngpu : int
@@ -201,9 +206,10 @@ class Classifier:
         self.stratify_splits_col = stratify_splits_col
         self.no_eval = no_eval
         self.forward_batch_size = forward_batch_size
+        self.gene_token_dict_path = gene_token_dict_path
         self.nproc = nproc
         self.ngpu = ngpu
-
+
         if self.training_args is None:
             logger.warning(
                 "Hyperparameter tuning is highly recommended for optimal results. "
@@ -222,7 +228,10 @@ class Classifier:
             ] = self.cell_state_dict["states"]
 
         # load token dictionary (Ensembl IDs:token)
-        with open(TOKEN_DICTIONARY_FILE, "rb") as f:
+        if self.gene_token_dict_path is None:
+            self.gene_token_dict_path = TOKEN_DICTIONARY_FILE
+
+        with open(self.gene_token_dict_path, "rb") as f:
             self.gene_token_dict = pickle.load(f)
 
         self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
@@ -267,7 +276,7 @@ class Classifier:
                     continue
             valid_type = False
             for option in valid_options:
-                if (option in [int, float, list, dict, bool]) and isinstance(
+                if (option in [int, float, list, dict, bool, str]) and isinstance(
                     attr_value, option
                 ):
                     valid_type = True
@@ -1018,6 +1027,7 @@ class Classifier:
                 metric="eval_macro_f1",
                 metric_columns=["loss", "eval_loss", "eval_accuracy", "eval_macro_f1"],
             ),
+            local_dir = f"{output_directory}/ray_results", # HAN ADDED
         )
 
         return trainer
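For reviewers who want to try the change, here is a minimal usage sketch of the new argument. It is an illustration, not part of this PR: it assumes the package-level Classifier import and the existing classifier and cell_state_dict constructor arguments, and the dictionary contents and file name below are made up. The custom file is expected to hold a pickled mapping of Ensembl gene IDs (plus any special tokens) to integer token IDs, in the same format as the packaged TOKEN_DICTIONARY_FILE.

import pickle

from geneformer import Classifier

# Illustrative custom token dictionary (Ensembl ID -> token id); a real one
# must match the vocabulary of the pretrained model being fine-tuned.
custom_token_dict = {"<pad>": 0, "<mask>": 1, "ENSG00000141510": 2}
with open("custom_token_dictionary.pkl", "wb") as fp:
    pickle.dump(custom_token_dict, fp)

# gene_token_dict_path=None (the default) keeps the packaged dictionary;
# passing a path loads the custom dictionary instead.
cc = Classifier(
    classifier="cell",
    cell_state_dict={"state_key": "cell_type", "states": "all"},
    gene_token_dict_path="custom_token_dictionary.pkl",
    forward_batch_size=100,
    nproc=4,
)

Adding str to the type check in the option-validation loop is what lets the new path string pass validation, since "gene_token_dict_path": {None, str} is declared by type rather than by enumerated values.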