Christina Theodoris commited on
Commit
50e921d
·
1 Parent(s): 268e566

Add sorting for aggregating data for goal state shifts

Browse files
geneformer/in_silico_perturber_stats.py CHANGED
@@ -37,15 +37,15 @@ def invert_dict(dictionary):
37
  return {v: k for k, v in dictionary.items()}
38
 
39
  # read raw dictionary files
40
- def read_dictionaries(dir, cell_or_gene_emb, anchor_token):
41
  file_found = 0
42
  file_path_list = []
43
  dict_list = []
44
- for file in os.listdir(dir):
45
  # process only _raw.pickle files
46
  if file.endswith("_raw.pickle"):
47
  file_found = 1
48
- file_path_list += [f"{dir}/{file}"]
49
  for file_path in tqdm(file_path_list):
50
  with open(file_path, "rb") as fp:
51
  cos_sims_dict = pickle.load(fp)
@@ -146,6 +146,10 @@ def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model, genes_
146
  if alt_end_state_exists == True:
147
  cos_sims_full_df["Shift_to_goal_end"] = [goal_end for start_state,goal_end,alt_end in cos_shift_data]
148
  cos_sims_full_df["Shift_to_alt_end"] = [alt_end for start_state,goal_end,alt_end in cos_shift_data]
 
 
 
 
149
  return cos_sims_full_df
150
 
151
  elif genes_perturbed == "all":
 
37
  return {v: k for k, v in dictionary.items()}
38
 
39
  # read raw dictionary files
40
+ def read_dictionaries(input_data_directory, cell_or_gene_emb, anchor_token):
41
  file_found = 0
42
  file_path_list = []
43
  dict_list = []
44
+ for file in os.listdir(input_data_directory):
45
  # process only _raw.pickle files
46
  if file.endswith("_raw.pickle"):
47
  file_found = 1
48
+ file_path_list += [f"{input_data_directory}/{file}"]
49
  for file_path in tqdm(file_path_list):
50
  with open(file_path, "rb") as fp:
51
  cos_sims_dict = pickle.load(fp)
 
146
  if alt_end_state_exists == True:
147
  cos_sims_full_df["Shift_to_goal_end"] = [goal_end for start_state,goal_end,alt_end in cos_shift_data]
148
  cos_sims_full_df["Shift_to_alt_end"] = [alt_end for start_state,goal_end,alt_end in cos_shift_data]
149
+
150
+ # sort by shift to desired state
151
+ cos_sims_full_df = cos_sims_full_df.sort_values(by=["Shift_to_goal_end"],
152
+ ascending=[False])
153
  return cos_sims_full_df
154
 
155
  elif genes_perturbed == "all":