Jackmin108
commited on
Commit
·
c35a42b
1
Parent(s):
7af97e7
fix: when sentences is one
Browse filesSigned-off-by: Meow <[email protected]>
- modeling_lora.py +18 -9
modeling_lora.py
CHANGED
@@ -11,8 +11,11 @@ from torch.nn import Parameter
|
|
11 |
from torch.nn import functional as F
|
12 |
from transformers import PretrainedConfig
|
13 |
|
14 |
-
from .modeling_xlm_roberta import (
|
15 |
-
|
|
|
|
|
|
|
16 |
|
17 |
|
18 |
def initialized_weights(
|
@@ -241,6 +244,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
241 |
"""
|
242 |
A wrapper class around the Jina XLM-RoBERTa model that integrates LoRA (Low-Rank Adaptation) adapters.
|
243 |
"""
|
|
|
244 |
def __init__(
|
245 |
self, config: XLMRobertaFlashConfig, roberta: Optional[XLMRobertaModel] = None
|
246 |
):
|
@@ -262,7 +266,9 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
262 |
if (
|
263 |
not isinstance(self._task_instructions, dict)
|
264 |
or len(self._task_instructions) != len(self._lora_adaptations)
|
265 |
-
or not all(
|
|
|
|
|
266 |
):
|
267 |
raise ValueError(
|
268 |
f"`task_instructions` must be a dict and contain the same number of elements "
|
@@ -325,11 +331,11 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
325 |
config = XLMRobertaFlashConfig.from_pretrained(
|
326 |
pretrained_model_name_or_path, *model_args, **kwargs
|
327 |
)
|
328 |
-
if config.load_trained_adapters:
|
329 |
return super().from_pretrained(
|
330 |
pretrained_model_name_or_path, *model_args, **kwargs
|
331 |
)
|
332 |
-
else:
|
333 |
roberta = XLMRobertaModel.from_pretrained(
|
334 |
pretrained_model_name_or_path, *model_args, **kwargs
|
335 |
)
|
@@ -387,14 +393,17 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
387 |
f"Alternatively, don't pass the `task_type` argument to disable LoRA."
|
388 |
)
|
389 |
adapter_mask = None
|
|
|
390 |
if task_type:
|
391 |
task_id = self._adaptation_map[task_type]
|
392 |
-
num_examples = 1 if isinstance(sentences, str) else len(sentences)
|
393 |
adapter_mask = torch.full(
|
394 |
-
(
|
395 |
)
|
396 |
-
if task_type in [
|
397 |
-
sentences = [
|
|
|
|
|
|
|
398 |
return self.roberta.encode(
|
399 |
sentences, *args, adapter_mask=adapter_mask, **kwargs
|
400 |
)
|
|
|
11 |
from torch.nn import functional as F
|
12 |
from transformers import PretrainedConfig
|
13 |
|
14 |
+
from .modeling_xlm_roberta import (
|
15 |
+
XLMRobertaFlashConfig,
|
16 |
+
XLMRobertaModel,
|
17 |
+
XLMRobertaPreTrainedModel,
|
18 |
+
)
|
19 |
|
20 |
|
21 |
def initialized_weights(
|
|
|
244 |
"""
|
245 |
A wrapper class around the Jina XLM-RoBERTa model that integrates LoRA (Low-Rank Adaptation) adapters.
|
246 |
"""
|
247 |
+
|
248 |
def __init__(
|
249 |
self, config: XLMRobertaFlashConfig, roberta: Optional[XLMRobertaModel] = None
|
250 |
):
|
|
|
266 |
if (
|
267 |
not isinstance(self._task_instructions, dict)
|
268 |
or len(self._task_instructions) != len(self._lora_adaptations)
|
269 |
+
or not all(
|
270 |
+
[v in self._lora_adaptations for v in self._task_instructions.keys()]
|
271 |
+
)
|
272 |
):
|
273 |
raise ValueError(
|
274 |
f"`task_instructions` must be a dict and contain the same number of elements "
|
|
|
331 |
config = XLMRobertaFlashConfig.from_pretrained(
|
332 |
pretrained_model_name_or_path, *model_args, **kwargs
|
333 |
)
|
334 |
+
if config.load_trained_adapters: # checkpoint already contains LoRA adapters
|
335 |
return super().from_pretrained(
|
336 |
pretrained_model_name_or_path, *model_args, **kwargs
|
337 |
)
|
338 |
+
else: # initializing new adapters
|
339 |
roberta = XLMRobertaModel.from_pretrained(
|
340 |
pretrained_model_name_or_path, *model_args, **kwargs
|
341 |
)
|
|
|
393 |
f"Alternatively, don't pass the `task_type` argument to disable LoRA."
|
394 |
)
|
395 |
adapter_mask = None
|
396 |
+
sentences = list(sentences) if isinstance(sentences, str) else sentences
|
397 |
if task_type:
|
398 |
task_id = self._adaptation_map[task_type]
|
|
|
399 |
adapter_mask = torch.full(
|
400 |
+
(len(sentences),), task_id, dtype=torch.int32, device=self.device
|
401 |
)
|
402 |
+
if task_type in ["query", "passage"]:
|
403 |
+
sentences = [
|
404 |
+
self._task_instructions[task_type] + " " + sentence
|
405 |
+
for sentence in sentences
|
406 |
+
]
|
407 |
return self.roberta.encode(
|
408 |
sentences, *args, adapter_mask=adapter_mask, **kwargs
|
409 |
)
|