suayptalha commited on
Commit
9359333
·
verified ·
1 Parent(s): 84e7a9a

Update modeling_minGRU.py

Browse files
Files changed (1) hide show
  1. modeling_minGRU.py +6 -4
modeling_minGRU.py CHANGED
@@ -109,13 +109,14 @@ class MinGRUForSequenceClassification(PreTrainedModel):
109
  model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
110
  return model
111
 
112
- def save_pretrained(self, save_directory, safe_serialization: Optional[bool] = True):
113
  """
114
  Save the model and configuration to a directory.
115
 
116
  Args:
117
  save_directory (str): Directory to save the model.
118
  safe_serialization (bool, optional): Whether to use safe serialization. Defaults to True.
 
119
  """
120
  import os
121
  os.makedirs(save_directory, exist_ok=True)
@@ -124,13 +125,13 @@ class MinGRUForSequenceClassification(PreTrainedModel):
124
  print("Saving with safe serialization.")
125
 
126
  state_dict = {}
127
-
128
  for name, param in self.model.min_gru_model.named_parameters():
129
  state_dict[f"model.{name}"] = param
130
-
131
  for name, param in self.classifier.named_parameters():
132
  state_dict[f"classifier.{name}"] = param
133
-
134
  state_dict['config'] = self.config.__dict__
135
  torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
136
 
@@ -138,3 +139,4 @@ class MinGRUForSequenceClassification(PreTrainedModel):
138
  else:
139
  print("Saving without safe serialization.")
140
  super().save_pretrained(save_directory)
 
 
109
  model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
110
  return model
111
 
112
+ def save_pretrained(self, save_directory, safe_serialization: Optional[bool] = True, **kwargs):
113
  """
114
  Save the model and configuration to a directory.
115
 
116
  Args:
117
  save_directory (str): Directory to save the model.
118
  safe_serialization (bool, optional): Whether to use safe serialization. Defaults to True.
119
+ kwargs: Additional arguments like max_shard_size (ignored in this implementation).
120
  """
121
  import os
122
  os.makedirs(save_directory, exist_ok=True)
 
125
  print("Saving with safe serialization.")
126
 
127
  state_dict = {}
128
+
129
  for name, param in self.model.min_gru_model.named_parameters():
130
  state_dict[f"model.{name}"] = param
131
+
132
  for name, param in self.classifier.named_parameters():
133
  state_dict[f"classifier.{name}"] = param
134
+
135
  state_dict['config'] = self.config.__dict__
136
  torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
137
 
 
139
  else:
140
  print("Saving without safe serialization.")
141
  super().save_pretrained(save_directory)
142
+