Fill-Mask
Transformers
PyTorch
Safetensors
English
nomic_bert
custom_code
zpn jxm commited on
Commit
c1b1fd7
·
verified ·
1 Parent(s): 53c3a30

add full support for inputs_embeds (#10)

Browse files

- add full support for inputs_embeds (2fd43c9a0641a75fa975f3257d97e5c55b3fa940)


Co-authored-by: Jack Morris <[email protected]>

Files changed (1) hide show
  1. modeling_hf_nomic_bert.py +5 -8
modeling_hf_nomic_bert.py CHANGED
@@ -983,22 +983,21 @@ class NomicBertEmbeddings(nn.Module):
983
  position_ids: (batch, seqlen)
984
  token_type_ids: (batch, seqlen)
985
  """
986
- batch_size, seqlen = input_ids.shape
987
-
988
  if inputs_embeds is None:
989
  embeddings = self.word_embeddings(input_ids)
990
  else:
991
  embeddings = inputs_embeds
992
-
 
993
  if self.type_vocab_size > 0:
994
  if token_type_ids is None:
995
- token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
996
  token_type_embeddings = self.token_type_embeddings(token_type_ids)
997
  embeddings = embeddings + token_type_embeddings
998
 
999
  if self.max_position_embeddings > 0:
1000
  if position_ids is None:
1001
- position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
1002
  position_embeddings = self.position_embeddings(position_ids)
1003
  embeddings = embeddings + position_embeddings
1004
  return embeddings
@@ -1688,8 +1687,6 @@ class NomicBertModel(NomicBertPreTrainedModel):
1688
  ):
1689
  if input_ids is not None and inputs_embeds is not None:
1690
  raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1691
- if token_type_ids is None:
1692
- token_type_ids = torch.zeros_like(input_ids)
1693
  hidden_states = self.embeddings(
1694
  input_ids=input_ids,
1695
  position_ids=position_ids,
@@ -1699,7 +1696,7 @@ class NomicBertModel(NomicBertPreTrainedModel):
1699
  hidden_states = self.emb_ln(hidden_states)
1700
  hidden_states = self.emb_drop(hidden_states)
1701
 
1702
- attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.shape)
1703
  sequence_output = self.encoder(hidden_states, attention_mask=attention_mask, return_dict=return_dict)
1704
 
1705
  pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
 
983
  position_ids: (batch, seqlen)
984
  token_type_ids: (batch, seqlen)
985
  """
 
 
986
  if inputs_embeds is None:
987
  embeddings = self.word_embeddings(input_ids)
988
  else:
989
  embeddings = inputs_embeds
990
+ batch_size, seqlen, _ = embeddings.shape
991
+
992
  if self.type_vocab_size > 0:
993
  if token_type_ids is None:
994
+ token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=embeddings.device)
995
  token_type_embeddings = self.token_type_embeddings(token_type_ids)
996
  embeddings = embeddings + token_type_embeddings
997
 
998
  if self.max_position_embeddings > 0:
999
  if position_ids is None:
1000
+ position_ids = torch.arange(seqlen, dtype=torch.long, device=embeddings.device)
1001
  position_embeddings = self.position_embeddings(position_ids)
1002
  embeddings = embeddings + position_embeddings
1003
  return embeddings
 
1687
  ):
1688
  if input_ids is not None and inputs_embeds is not None:
1689
  raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
 
 
1690
  hidden_states = self.embeddings(
1691
  input_ids=input_ids,
1692
  position_ids=position_ids,
 
1696
  hidden_states = self.emb_ln(hidden_states)
1697
  hidden_states = self.emb_drop(hidden_states)
1698
 
1699
+ attention_mask = self.get_extended_attention_mask(attention_mask, hidden_states.shape[:-1])
1700
  sequence_output = self.encoder(hidden_states, attention_mask=attention_mask, return_dict=return_dict)
1701
 
1702
  pooled_output = self.pooler(sequence_output) if self.pooler is not None else None