jupyterjazz
commited on
Commit
·
3eb20d0
1
Parent(s):
509511d
refactor: modify encode
Browse filesSigned-off-by: jupyterjazz <[email protected]>
- modeling_lora.py +7 -9
- modeling_xlm_roberta.py +5 -2
modeling_lora.py
CHANGED
@@ -337,7 +337,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
337 |
def encode(
|
338 |
self,
|
339 |
*args,
|
340 |
-
task:
|
341 |
**kwargs,
|
342 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
343 |
"""
|
@@ -351,13 +351,11 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
351 |
adapters are disabled, and the model reverts to its original, general-purpose weights.
|
352 |
If `task` is set to a specific LoRA adaptation, that adaptation is activated.
|
353 |
"""
|
354 |
-
if task
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
)
|
361 |
-
self.current_task = task
|
362 |
|
363 |
return self.roberta.encode(*args, **kwargs)
|
|
|
337 |
def encode(
|
338 |
self,
|
339 |
*args,
|
340 |
+
task: Optional[str] = None,
|
341 |
**kwargs,
|
342 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
343 |
"""
|
|
|
351 |
adapters are disabled, and the model reverts to its original, general-purpose weights.
|
352 |
If `task` is set to a specific LoRA adaptation, that adaptation is activated.
|
353 |
"""
|
354 |
+
if task and task not in self._lora_adaptations:
|
355 |
+
raise ValueError(
|
356 |
+
f"Unsupported task '{task}'. "
|
357 |
+
f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
|
358 |
+
f"Alternatively, don't pass the `task` argument to disable LoRA."
|
359 |
+
)
|
|
|
|
|
360 |
|
361 |
return self.roberta.encode(*args, **kwargs)
|
modeling_xlm_roberta.py
CHANGED
@@ -459,6 +459,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
459 |
device: Optional[torch.device] = None,
|
460 |
normalize_embeddings: bool = False,
|
461 |
truncate_dim: Optional[int] = None,
|
|
|
462 |
**tokenizer_kwargs,
|
463 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
464 |
"""
|
@@ -549,14 +550,16 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
549 |
)
|
550 |
else:
|
551 |
range_iter = range(0, len(sentences), batch_size)
|
552 |
-
|
|
|
|
|
553 |
for i in range_iter:
|
554 |
encoded_input = self.tokenizer(
|
555 |
sentences[i : i + batch_size],
|
556 |
return_tensors='pt',
|
557 |
**tokenizer_kwargs,
|
558 |
).to(self.device)
|
559 |
-
token_embs = self.forward(**encoded_input)[0]
|
560 |
|
561 |
# Accumulate in fp32 to avoid overflow
|
562 |
token_embs = token_embs.float()
|
|
|
459 |
device: Optional[torch.device] = None,
|
460 |
normalize_embeddings: bool = False,
|
461 |
truncate_dim: Optional[int] = None,
|
462 |
+
task: Optional[str] = None,
|
463 |
**tokenizer_kwargs,
|
464 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
465 |
"""
|
|
|
550 |
)
|
551 |
else:
|
552 |
range_iter = range(0, len(sentences), batch_size)
|
553 |
+
lora_kwargs = {}
|
554 |
+
if task:
|
555 |
+
lora_kwargs['task'] = task
|
556 |
for i in range_iter:
|
557 |
encoded_input = self.tokenizer(
|
558 |
sentences[i : i + batch_size],
|
559 |
return_tensors='pt',
|
560 |
**tokenizer_kwargs,
|
561 |
).to(self.device)
|
562 |
+
token_embs = self.forward(**encoded_input, **lora_kwargs)[0]
|
563 |
|
564 |
# Accumulate in fp32 to avoid overflow
|
565 |
token_embs = token_embs.float()
|