Add input for custom gene token dictionary and add str to valid options

#327
by hchen725 - opened
Files changed (1) hide show
  1. geneformer/classifier.py +13 -3
geneformer/classifier.py CHANGED
@@ -86,6 +86,7 @@ class Classifier:
86
  "no_eval": {bool},
87
  "stratify_splits_col": {None, str},
88
  "forward_batch_size": {int},
 
89
  "nproc": {int},
90
  "ngpu": {int},
91
  }
@@ -107,6 +108,7 @@ class Classifier:
107
  stratify_splits_col=None,
108
  no_eval=False,
109
  forward_batch_size=100,
 
110
  nproc=4,
111
  ngpu=1,
112
  ):
@@ -175,6 +177,9 @@ class Classifier:
175
  | Otherwise, will perform eval during training.
176
  forward_batch_size : int
177
  | Batch size for forward pass (for evaluation, not training).
 
 
 
178
  nproc : int
179
  | Number of CPU processes to use.
180
  ngpu : int
@@ -201,9 +206,10 @@ class Classifier:
201
  self.stratify_splits_col = stratify_splits_col
202
  self.no_eval = no_eval
203
  self.forward_batch_size = forward_batch_size
 
204
  self.nproc = nproc
205
  self.ngpu = ngpu
206
-
207
  if self.training_args is None:
208
  logger.warning(
209
  "Hyperparameter tuning is highly recommended for optimal results. "
@@ -222,7 +228,10 @@ class Classifier:
222
  ] = self.cell_state_dict["states"]
223
 
224
  # load token dictionary (Ensembl IDs:token)
225
- with open(TOKEN_DICTIONARY_FILE, "rb") as f:
 
 
 
226
  self.gene_token_dict = pickle.load(f)
227
 
228
  self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
@@ -267,7 +276,7 @@ class Classifier:
267
  continue
268
  valid_type = False
269
  for option in valid_options:
270
- if (option in [int, float, list, dict, bool]) and isinstance(
271
  attr_value, option
272
  ):
273
  valid_type = True
@@ -1018,6 +1027,7 @@ class Classifier:
1018
  metric="eval_macro_f1",
1019
  metric_columns=["loss", "eval_loss", "eval_accuracy", "eval_macro_f1"],
1020
  ),
 
1021
  )
1022
 
1023
  return trainer
 
86
  "no_eval": {bool},
87
  "stratify_splits_col": {None, str},
88
  "forward_batch_size": {int},
89
+ "gene_token_dict_path": {None, str},
90
  "nproc": {int},
91
  "ngpu": {int},
92
  }
 
108
  stratify_splits_col=None,
109
  no_eval=False,
110
  forward_batch_size=100,
111
+ gene_token_dict_path=None,
112
  nproc=4,
113
  ngpu=1,
114
  ):
 
177
  | Otherwise, will perform eval during training.
178
  forward_batch_size : int
179
  | Batch size for forward pass (for evaluation, not training).
180
+ gene_token_dict_path : None, str
181
+ | If None (default), uses the token dictionary file from Geneformer.
182
+ | Otherwise, path to a pickled custom gene token dictionary to load instead.
183
  nproc : int
184
  | Number of CPU processes to use.
185
  ngpu : int
 
206
  self.stratify_splits_col = stratify_splits_col
207
  self.no_eval = no_eval
208
  self.forward_batch_size = forward_batch_size
209
+ self.gene_token_dict_path = gene_token_dict_path
210
  self.nproc = nproc
211
  self.ngpu = ngpu
212
+
213
  if self.training_args is None:
214
  logger.warning(
215
  "Hyperparameter tuning is highly recommended for optimal results. "
 
228
  ] = self.cell_state_dict["states"]
229
 
230
  # load token dictionary (Ensembl IDs:token)
231
+ if self.gene_token_dict_path is None:
232
+ self.gene_token_dict_path = TOKEN_DICTIONARY_FILE
233
+
234
+ with open(self.gene_token_dict_path, "rb") as f:
235
  self.gene_token_dict = pickle.load(f)
236
 
237
  self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
 
276
  continue
277
  valid_type = False
278
  for option in valid_options:
279
+ if (option in [int, float, list, dict, bool, str]) and isinstance(
280
  attr_value, option
281
  ):
282
  valid_type = True
 
1027
  metric="eval_macro_f1",
1028
  metric_columns=["loss", "eval_loss", "eval_accuracy", "eval_macro_f1"],
1029
  ),
1030
+ local_dir = f"{output_directory}/ray_results", # HAN ADDED
1031
  )
1032
 
1033
  return trainer