comment out "def save_model_without_heads(original_model_save_directory)"; redundant for ISP/Emb extractor (#382)
Browse files- comment out "def save_model_without_heads(original_model_save_directory)"; redundant for ISP/Emb extractor (0a4e8a4ba076513934876cf83b1d3e727c26d46e)
Co-authored-by: Madhavan Venkatesh <[email protected]>
- geneformer/mtl/utils.py +38 -38
geneformer/mtl/utils.py
CHANGED
@@ -73,44 +73,44 @@ def calculate_combined_f1(combined_labels, combined_preds):
|
|
73 |
return f1, accuracy
|
74 |
|
75 |
|
76 |
-
def save_model_without_heads(original_model_save_directory):
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
|
115 |
|
116 |
def get_layer_freeze_range(pretrained_path):
|
|
|
73 |
return f1, accuracy
|
74 |
|
75 |
|
76 |
+
# def save_model_without_heads(original_model_save_directory):
|
77 |
+
# # Create a new directory for the model without heads
|
78 |
+
# new_model_save_directory = original_model_save_directory + "_No_Heads"
|
79 |
+
# if not os.path.exists(new_model_save_directory):
|
80 |
+
# os.makedirs(new_model_save_directory)
|
81 |
+
|
82 |
+
# # Load the model state dictionary
|
83 |
+
# model_state_dict = torch.load(
|
84 |
+
# os.path.join(original_model_save_directory, "pytorch_model.bin")
|
85 |
+
# )
|
86 |
+
|
87 |
+
# # Initialize a new BERT model without the classification heads
|
88 |
+
# config = BertConfig.from_pretrained(
|
89 |
+
# os.path.join(original_model_save_directory, "config.json")
|
90 |
+
# )
|
91 |
+
# model_without_heads = BertModel(config)
|
92 |
+
|
93 |
+
# # Filter the state dict to exclude classification heads
|
94 |
+
# model_without_heads_state_dict = {
|
95 |
+
# k: v
|
96 |
+
# for k, v in model_state_dict.items()
|
97 |
+
# if not k.startswith("classification_heads")
|
98 |
+
# }
|
99 |
+
|
100 |
+
# # Load the filtered state dict into the model
|
101 |
+
# model_without_heads.load_state_dict(model_without_heads_state_dict, strict=False)
|
102 |
+
|
103 |
+
# # Save the model without heads
|
104 |
+
# model_save_path = os.path.join(new_model_save_directory, "pytorch_model.bin")
|
105 |
+
# torch.save(model_without_heads.state_dict(), model_save_path)
|
106 |
+
|
107 |
+
# # Copy the configuration file
|
108 |
+
# shutil.copy(
|
109 |
+
# os.path.join(original_model_save_directory, "config.json"),
|
110 |
+
# new_model_save_directory,
|
111 |
+
# )
|
112 |
+
|
113 |
+
# print(f"Model without classification heads saved to {new_model_save_directory}")
|
114 |
|
115 |
|
116 |
def get_layer_freeze_range(pretrained_path):
|