Fill-Mask
Transformers
PyTorch
Safetensors
English
nomic_bert
custom_code
zpn commited on
Commit
b9c36f5
·
verified ·
1 Parent(s): 5d66d02

Update modeling_hf_nomic_bert.py

Browse files
Files changed (1) hide show
  1. modeling_hf_nomic_bert.py +15 -21
modeling_hf_nomic_bert.py CHANGED
@@ -105,13 +105,7 @@ def filter_shapes(state_dict, model):
105
  return filtered_state_dict
106
 
107
 
108
- def remap_bert_state_dict(
109
- state_dict,
110
- config,
111
- remove_bert=False,
112
- remove_cls_weights=False,
113
- add_pooling_layer=False,
114
- ):
115
  """
116
  Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
117
  """
@@ -311,12 +305,13 @@ class NomicBertPreTrainedModel(PreTrainedModel):
311
  if config is None:
312
  config = cls.config_class.from_pretrained(model_name)
313
  remove_cls = cls != NomicBertForPreTraining
314
- remove_bert_prefix = cls != NomicBertForPreTraining and cls != NomicBertForSequenceClassification
315
  ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
316
  num_labels = kwargs.pop("num_labels", None)
317
  rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
318
- strict = kwargs.pop("strict", True)
319
- config.rotary_scaling_factor = rotary_scaling_factor
 
320
  if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
321
  config.n_positions = 2048
322
  if num_labels:
@@ -325,7 +320,10 @@ class NomicBertPreTrainedModel(PreTrainedModel):
325
  if "add_pooling_layer" in kwargs:
326
  model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
327
  else:
328
- model = cls(config, *inputs)
 
 
 
329
  # TODO: fix this
330
  # Assuming we know what we're doing when loading from disk
331
  # Prob a bad assumption but i'm tired and want to train this asap
@@ -344,7 +342,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
344
  load_return = model.load_state_dict(state_dict, strict=False)
345
  else:
346
  # TODO: can probably check config class and see if we need to remap from a bert model
347
- state_dict = state_dict_from_pretrained(model_name)
348
  state_dict = remap_bert_state_dict(
349
  state_dict,
350
  config,
@@ -355,7 +353,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
355
  if ignore_mismatched_shapes:
356
  state_dict = filter_shapes(state_dict, model)
357
 
358
- load_return = model.load_state_dict(state_dict, strict=strict)
359
  logger.warning(load_return)
360
  return model
361
 
@@ -726,7 +724,7 @@ class NomicBertAttention(nn.Module):
726
 
727
  self.rotary_emb_dim = self.head_dim * config.rotary_emb_fraction
728
  if self.rotary_emb_dim > 0:
729
- if getattr(config, "rotary_scaling_factor", None):
730
  self.rotary_emb = NomicBertDynamicNTKRotaryEmbedding(
731
  dim=self.rotary_emb_dim,
732
  base=config.rotary_emb_base,
@@ -1057,11 +1055,10 @@ class NomicBertModel(NomicBertPreTrainedModel):
1057
  def forward(
1058
  self,
1059
  input_ids,
1060
- position_ids=None,
1061
- token_type_ids=None,
1062
  attention_mask=None,
 
 
1063
  return_dict=None,
1064
- matryoshka_dim=None,
1065
  ):
1066
  if token_type_ids is None:
1067
  token_type_ids = torch.zeros_like(input_ids)
@@ -1074,9 +1071,6 @@ class NomicBertModel(NomicBertPreTrainedModel):
1074
 
1075
  pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1076
 
1077
- if matryoshka_dim:
1078
- sequence_output = sequence_output[:, :matryoshka_dim]
1079
-
1080
  return BaseModelOutputWithPoolingAndCrossAttentions(
1081
  last_hidden_state=sequence_output,
1082
  pooler_output=pooled_output,
@@ -1224,4 +1218,4 @@ class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
1224
  logits=logits,
1225
  hidden_states=outputs.hidden_states,
1226
  attentions=outputs.attentions,
1227
- )
 
105
  return filtered_state_dict
106
 
107
 
108
+ def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weights=False, add_pooling_layer=False):
 
 
 
 
 
 
109
  """
110
  Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
111
  """
 
305
  if config is None:
306
  config = cls.config_class.from_pretrained(model_name)
307
  remove_cls = cls != NomicBertForPreTraining
308
+ remove_bert_prefix = cls != NomicBertForPreTraining
309
  ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
310
  num_labels = kwargs.pop("num_labels", None)
311
  rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
312
+ if rotary_scaling_factor:
313
+ config.rotary_scaling_factor = rotary_scaling_factor
314
+
315
  if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
316
  config.n_positions = 2048
317
  if num_labels:
 
320
  if "add_pooling_layer" in kwargs:
321
  model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
322
  else:
323
+ if cls == NomicBertModel:
324
+ model = cls(config, *inputs, add_pooling_layer=False)
325
+ else:
326
+ model = cls(config, *inputs)
327
  # TODO: fix this
328
  # Assuming we know what we're doing when loading from disk
329
  # Prob a bad assumption but i'm tired and want to train this asap
 
342
  load_return = model.load_state_dict(state_dict, strict=False)
343
  else:
344
  # TODO: can probably check config class and see if we need to remap from a bert model
345
+ state_dict = state_dict_from_pretrained(model_name, safe_serialization=kwargs.get("safe_serialization", False))
346
  state_dict = remap_bert_state_dict(
347
  state_dict,
348
  config,
 
353
  if ignore_mismatched_shapes:
354
  state_dict = filter_shapes(state_dict, model)
355
 
356
+ load_return = model.load_state_dict(state_dict, strict=True)
357
  logger.warning(load_return)
358
  return model
359
 
 
724
 
725
  self.rotary_emb_dim = self.head_dim * config.rotary_emb_fraction
726
  if self.rotary_emb_dim > 0:
727
+ if config.rotary_scaling_factor:
728
  self.rotary_emb = NomicBertDynamicNTKRotaryEmbedding(
729
  dim=self.rotary_emb_dim,
730
  base=config.rotary_emb_base,
 
1055
  def forward(
1056
  self,
1057
  input_ids,
 
 
1058
  attention_mask=None,
1059
+ token_type_ids=None,
1060
+ position_ids=None,
1061
  return_dict=None,
 
1062
  ):
1063
  if token_type_ids is None:
1064
  token_type_ids = torch.zeros_like(input_ids)
 
1071
 
1072
  pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1073
 
 
 
 
1074
  return BaseModelOutputWithPoolingAndCrossAttentions(
1075
  last_hidden_state=sequence_output,
1076
  pooler_output=pooled_output,
 
1218
  logits=logits,
1219
  hidden_states=outputs.hidden_states,
1220
  attentions=outputs.attentions,
1221
+ )