Christina Theodoris commited on
Commit
75c67a1
·
1 Parent(s): fb130e6

patch datasets save_to_disk

Browse files
geneformer/classifier.py CHANGED
@@ -445,8 +445,8 @@ class Classifier:
445
  test_data_output_path = (
446
  Path(output_directory) / f"{output_prefix}_labeled_test"
447
  ).with_suffix(".dataset")
448
- data_dict["train"].save_to_disk(train_data_output_path)
449
- data_dict["test"].save_to_disk(test_data_output_path)
450
  elif (test_size is not None) and (self.classifier == "cell"):
451
  if 1 > test_size > 0:
452
  if attr_to_split is None:
@@ -461,8 +461,8 @@ class Classifier:
461
  test_data_output_path = (
462
  Path(output_directory) / f"{output_prefix}_labeled_test"
463
  ).with_suffix(".dataset")
464
- data_dict["train"].save_to_disk(train_data_output_path)
465
- data_dict["test"].save_to_disk(test_data_output_path)
466
  else:
467
  data_dict, balance_df = cu.balance_attr_splits(
468
  data,
@@ -483,19 +483,19 @@ class Classifier:
483
  test_data_output_path = (
484
  Path(output_directory) / f"{output_prefix}_labeled_test"
485
  ).with_suffix(".dataset")
486
- data_dict["train"].save_to_disk(train_data_output_path)
487
- data_dict["test"].save_to_disk(test_data_output_path)
488
  else:
489
  data_output_path = (
490
  Path(output_directory) / f"{output_prefix}_labeled"
491
  ).with_suffix(".dataset")
492
- data.save_to_disk(data_output_path)
493
  print(data_output_path)
494
  else:
495
  data_output_path = (
496
  Path(output_directory) / f"{output_prefix}_labeled"
497
  ).with_suffix(".dataset")
498
- data.save_to_disk(data_output_path)
499
 
500
  def train_all_data(
501
  self,
 
445
  test_data_output_path = (
446
  Path(output_directory) / f"{output_prefix}_labeled_test"
447
  ).with_suffix(".dataset")
448
+ data_dict["train"].save_to_disk(str(train_data_output_path))
449
+ data_dict["test"].save_to_disk(str(test_data_output_path))
450
  elif (test_size is not None) and (self.classifier == "cell"):
451
  if 1 > test_size > 0:
452
  if attr_to_split is None:
 
461
  test_data_output_path = (
462
  Path(output_directory) / f"{output_prefix}_labeled_test"
463
  ).with_suffix(".dataset")
464
+ data_dict["train"].save_to_disk(str(train_data_output_path))
465
+ data_dict["test"].save_to_disk(str(test_data_output_path))
466
  else:
467
  data_dict, balance_df = cu.balance_attr_splits(
468
  data,
 
483
  test_data_output_path = (
484
  Path(output_directory) / f"{output_prefix}_labeled_test"
485
  ).with_suffix(".dataset")
486
+ data_dict["train"].save_to_disk(str(train_data_output_path))
487
+ data_dict["test"].save_to_disk(str(test_data_output_path))
488
  else:
489
  data_output_path = (
490
  Path(output_directory) / f"{output_prefix}_labeled"
491
  ).with_suffix(".dataset")
492
+ data.save_to_disk(str(data_output_path))
493
  print(data_output_path)
494
  else:
495
  data_output_path = (
496
  Path(output_directory) / f"{output_prefix}_labeled"
497
  ).with_suffix(".dataset")
498
+ data.save_to_disk(str(data_output_path))
499
 
500
  def train_all_data(
501
  self,
geneformer/tokenizer.py CHANGED
@@ -55,7 +55,6 @@ logger = logging.getLogger(__name__)
55
  GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
56
  TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
57
 
58
-
59
  def rank_genes(gene_vector, gene_tokens):
60
  """
61
  Rank gene expression vector.
@@ -176,7 +175,7 @@ class TranscriptomeTokenizer:
176
  )
177
 
178
  output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
179
- tokenized_dataset.save_to_disk(output_path)
180
 
181
  def tokenize_files(
182
  self, data_directory, file_format: Literal["loom", "h5ad"] = "loom"
 
55
  GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
56
  TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
57
 
 
58
  def rank_genes(gene_vector, gene_tokens):
59
  """
60
  Rank gene expression vector.
 
175
  )
176
 
177
  output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
178
+ tokenized_dataset.save_to_disk(str(output_path))
179
 
180
  def tokenize_files(
181
  self, data_directory, file_format: Literal["loom", "h5ad"] = "loom"