suayptalha
commited on
Update modeling_minGRU.py
Browse files- modeling_minGRU.py +6 -4
modeling_minGRU.py
CHANGED
@@ -109,13 +109,14 @@ class MinGRUForSequenceClassification(PreTrainedModel):
|
|
109 |
model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
110 |
return model
|
111 |
|
112 |
-
def save_pretrained(self, save_directory, safe_serialization: Optional[bool] = True):
|
113 |
"""
|
114 |
Save the model and configuration to a directory.
|
115 |
|
116 |
Args:
|
117 |
save_directory (str): Directory to save the model.
|
118 |
safe_serialization (bool, optional): Whether to use safe serialization. Defaults to True.
|
|
|
119 |
"""
|
120 |
import os
|
121 |
os.makedirs(save_directory, exist_ok=True)
|
@@ -124,13 +125,13 @@ class MinGRUForSequenceClassification(PreTrainedModel):
|
|
124 |
print("Saving with safe serialization.")
|
125 |
|
126 |
state_dict = {}
|
127 |
-
|
128 |
for name, param in self.model.min_gru_model.named_parameters():
|
129 |
state_dict[f"model.{name}"] = param
|
130 |
-
|
131 |
for name, param in self.classifier.named_parameters():
|
132 |
state_dict[f"classifier.{name}"] = param
|
133 |
-
|
134 |
state_dict['config'] = self.config.__dict__
|
135 |
torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
|
136 |
|
@@ -138,3 +139,4 @@ class MinGRUForSequenceClassification(PreTrainedModel):
|
|
138 |
else:
|
139 |
print("Saving without safe serialization.")
|
140 |
super().save_pretrained(save_directory)
|
|
|
|
109 |
model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
110 |
return model
|
111 |
|
112 |
+
def save_pretrained(self, save_directory, safe_serialization: Optional[bool] = True, **kwargs):
|
113 |
"""
|
114 |
Save the model and configuration to a directory.
|
115 |
|
116 |
Args:
|
117 |
save_directory (str): Directory to save the model.
|
118 |
safe_serialization (bool, optional): Whether to use safe serialization. Defaults to True.
|
119 |
+
kwargs: Additional arguments like max_shard_size (ignored in this implementation).
|
120 |
"""
|
121 |
import os
|
122 |
os.makedirs(save_directory, exist_ok=True)
|
|
|
125 |
print("Saving with safe serialization.")
|
126 |
|
127 |
state_dict = {}
|
128 |
+
|
129 |
for name, param in self.model.min_gru_model.named_parameters():
|
130 |
state_dict[f"model.{name}"] = param
|
131 |
+
|
132 |
for name, param in self.classifier.named_parameters():
|
133 |
state_dict[f"classifier.{name}"] = param
|
134 |
+
|
135 |
state_dict['config'] = self.config.__dict__
|
136 |
torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
|
137 |
|
|
|
139 |
else:
|
140 |
print("Saving without safe serialization.")
|
141 |
super().save_pretrained(save_directory)
|
142 |
+
|