jupyterjazz
commited on
truncate-embedding-dimension (#10)
Browse files- feat: matryoshka embeddings (ee8863c9cbae496b327c9d5761b54f02a2f90954)
- refactor: optional arg (3f72891549ab3f1a6a3cd4e8b40dff6d5c50d1b1)
- fix: var name (fd34c40e6fcbb638e225ccd8b47f5b9c487bd8a4)
- fix: another one (b27fa557459cf35a2520c39da441b5e79e455068)
- refactor: truncation fn (c55e59156fa5b02100f7a7707324f3ce4f92714f)
- feat: truncation option during init (943cec246f8df968b3c6b2bd10e89f9529797b25)
- configuration_xlm_roberta.py +4 -0
- modeling_xlm_roberta.py +20 -0
configuration_xlm_roberta.py
CHANGED
@@ -31,6 +31,8 @@ class XLMRobertaFlashConfig(PretrainedConfig):
|
|
31 |
use_flash_attn=True,
|
32 |
torch_dtype=None,
|
33 |
emb_pooler=None,
|
|
|
|
|
34 |
**kwargs,
|
35 |
):
|
36 |
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
@@ -59,6 +61,8 @@ class XLMRobertaFlashConfig(PretrainedConfig):
|
|
59 |
self.lora_main_params_trainable = lora_main_params_trainable
|
60 |
self.use_flash_attn = use_flash_attn
|
61 |
self.emb_pooler = emb_pooler
|
|
|
|
|
62 |
if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
|
63 |
self.torch_dtype = getattr(torch, torch_dtype)
|
64 |
else:
|
|
|
31 |
use_flash_attn=True,
|
32 |
torch_dtype=None,
|
33 |
emb_pooler=None,
|
34 |
+
matryoshka_dimensions=None,
|
35 |
+
truncate_dim=None,
|
36 |
**kwargs,
|
37 |
):
|
38 |
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
|
|
61 |
self.lora_main_params_trainable = lora_main_params_trainable
|
62 |
self.use_flash_attn = use_flash_attn
|
63 |
self.emb_pooler = emb_pooler
|
64 |
+
self.matryoshka_dimensions = matryoshka_dimensions
|
65 |
+
self.truncate_dim = truncate_dim
|
66 |
if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
|
67 |
self.torch_dtype = getattr(torch, torch_dtype)
|
68 |
else:
|
modeling_xlm_roberta.py
CHANGED
@@ -452,6 +452,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
452 |
convert_to_tensor: bool = False,
|
453 |
device: Optional[torch.device] = None,
|
454 |
normalize_embeddings: bool = False,
|
|
|
455 |
**tokenizer_kwargs,
|
456 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
457 |
"""
|
@@ -481,6 +482,8 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
481 |
If set to true, returned vectors will have length 1. In that case, the
|
482 |
faster dot-product (util.dot_score) instead of cosine similarity can
|
483 |
be used.
|
|
|
|
|
484 |
tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
|
485 |
Keyword arguments for the tokenizer
|
486 |
Returns:
|
@@ -575,6 +578,10 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
575 |
|
576 |
all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
|
577 |
|
|
|
|
|
|
|
|
|
578 |
if convert_to_tensor:
|
579 |
all_embeddings = torch.stack(all_embeddings)
|
580 |
elif convert_to_numpy:
|
@@ -586,6 +593,19 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
586 |
self.train(is_training)
|
587 |
return all_embeddings
|
588 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
589 |
def mean_pooling(
|
590 |
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
591 |
):
|
|
|
452 |
convert_to_tensor: bool = False,
|
453 |
device: Optional[torch.device] = None,
|
454 |
normalize_embeddings: bool = False,
|
455 |
+
truncate_dim: Optional[int] = None,
|
456 |
**tokenizer_kwargs,
|
457 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
458 |
"""
|
|
|
482 |
If set to true, returned vectors will have length 1. In that case, the
|
483 |
faster dot-product (util.dot_score) instead of cosine similarity can
|
484 |
be used.
|
485 |
+
truncate_dim(`int`, *optional*, defaults to None):
|
486 |
+
The dimension to truncate sentence embeddings to. `None` does no truncation.
|
487 |
tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
|
488 |
Keyword arguments for the tokenizer
|
489 |
Returns:
|
|
|
578 |
|
579 |
all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
|
580 |
|
581 |
+
truncate_dim = truncate_dim or self.config.truncate_dim
|
582 |
+
if truncate_dim:
|
583 |
+
all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
|
584 |
+
|
585 |
if convert_to_tensor:
|
586 |
all_embeddings = torch.stack(all_embeddings)
|
587 |
elif convert_to_numpy:
|
|
|
593 |
self.train(is_training)
|
594 |
return all_embeddings
|
595 |
|
596 |
+
|
597 |
+
def truncate_embeddings(self, embeddings, truncate_dim):
|
598 |
+
if not self.config.matryoshka_dimensions:
|
599 |
+
logger.warning(
|
600 |
+
'Matryoshka embeddings are not supported, so dimension truncation will not be performed.'
|
601 |
+
)
|
602 |
+
return embeddings
|
603 |
+
elif truncate_dim in self.config.matryoshka_dimensions:
|
604 |
+
return [tensor[:truncate_dim] for tensor in embeddings]
|
605 |
+
else:
|
606 |
+
raise ValueError(f'The provided `truncate_dim` value of {truncate_dim} is not supported. '
|
607 |
+
f'Supported dimensions are {self.config.matryoshka_dimensions}.')
|
608 |
+
|
609 |
def mean_pooling(
|
610 |
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
611 |
):
|