jupyterjazz commited on
Commit
f8300e5
·
1 Parent(s): da863dd

refactor: normalization after truncation

Browse files

Signed-off-by: jupyterjazz <[email protected]>

Files changed (1) hide show
  1. modeling_xlm_roberta.py +5 -7
modeling_xlm_roberta.py CHANGED
@@ -588,12 +588,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
588
  embeddings = self.mean_pooling(
589
  token_embs, encoded_input["attention_mask"]
590
  )
591
-
592
- if normalize_embeddings:
593
- embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
594
-
595
- if convert_to_numpy:
596
- embeddings = embeddings.cpu()
597
  all_embeddings.extend(embeddings)
598
 
599
  all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
@@ -601,11 +596,14 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
601
  truncate_dim = truncate_dim or self.config.truncate_dim
602
  if truncate_dim:
603
  all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
 
 
 
604
 
605
  if convert_to_tensor:
606
  all_embeddings = torch.stack(all_embeddings)
607
  elif convert_to_numpy:
608
- all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
609
 
610
  if input_was_string:
611
  all_embeddings = all_embeddings[0]
 
588
  embeddings = self.mean_pooling(
589
  token_embs, encoded_input["attention_mask"]
590
  )
591
+
 
 
 
 
 
592
  all_embeddings.extend(embeddings)
593
 
594
  all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
 
596
  truncate_dim = truncate_dim or self.config.truncate_dim
597
  if truncate_dim:
598
  all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
599
+
600
+ if normalize_embeddings:
601
+ all_embeddings = [torch.nn.functional.normalize(embedding, p=2, dim=0) for embedding in all_embeddings]
602
 
603
  if convert_to_tensor:
604
  all_embeddings = torch.stack(all_embeddings)
605
  elif convert_to_numpy:
606
+ all_embeddings = np.asarray([emb.cpu().numpy() for emb in all_embeddings])
607
 
608
  if input_was_string:
609
  all_embeddings = all_embeddings[0]