[Fix bug] TypeError: argument of type 'XLMRobertaFlashConfig' is not iterable
Browse files- modeling_lora.py +15 -13
modeling_lora.py
CHANGED
@@ -11,16 +11,12 @@ from torch.nn import Parameter
|
|
11 |
from torch.nn import functional as F
|
12 |
from transformers import PretrainedConfig
|
13 |
|
14 |
-
from .rotary import RotaryEmbedding
|
15 |
-
from .mlp import FusedMLP, Mlp
|
16 |
-
from .xlm_padding import index_first_axis_residual, pad_input, unpad_input
|
17 |
-
from .stochastic_depth import stochastic_depth
|
18 |
-
from .mha import MHA
|
19 |
-
from .block import Block
|
20 |
from .configuration_xlm_roberta import XLMRobertaFlashConfig
|
21 |
-
from .
|
22 |
-
|
23 |
-
|
|
|
|
|
24 |
|
25 |
|
26 |
def initialized_weights(
|
@@ -336,7 +332,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
336 |
**kwargs,
|
337 |
):
|
338 |
for key in list(kwargs.keys()):
|
339 |
-
if key in config:
|
340 |
config.update({key: kwargs.pop(key)})
|
341 |
if config.load_trained_adapters: # checkpoint already contains LoRA adapters
|
342 |
return super().from_pretrained(
|
@@ -350,11 +346,14 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
350 |
token=token,
|
351 |
revision=revision,
|
352 |
use_safetensors=use_safetensors,
|
353 |
-
**kwargs
|
354 |
)
|
355 |
else: # initializing new adapters
|
356 |
roberta = XLMRobertaModel.from_pretrained(
|
357 |
-
pretrained_model_name_or_path,
|
|
|
|
|
|
|
358 |
)
|
359 |
return cls(config, roberta=roberta)
|
360 |
|
@@ -418,7 +417,10 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
418 |
if isinstance(sentences, str):
|
419 |
sentences = self._task_instructions[task] + sentences
|
420 |
else:
|
421 |
-
sentences = [
|
|
|
|
|
422 |
return self.roberta.encode(
|
423 |
sentences, *args, adapter_mask=adapter_mask, **kwargs
|
424 |
)
|
|
|
|
11 |
from torch.nn import functional as F
|
12 |
from transformers import PretrainedConfig
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
from .configuration_xlm_roberta import XLMRobertaFlashConfig
|
15 |
+
from .modeling_xlm_roberta import (
|
16 |
+
XLMRobertaFlashConfig,
|
17 |
+
XLMRobertaModel,
|
18 |
+
XLMRobertaPreTrainedModel,
|
19 |
+
)
|
20 |
|
21 |
|
22 |
def initialized_weights(
|
|
|
332 |
**kwargs,
|
333 |
):
|
334 |
for key in list(kwargs.keys()):
|
335 |
+
if key in config.to_dict():
|
336 |
config.update({key: kwargs.pop(key)})
|
337 |
if config.load_trained_adapters: # checkpoint already contains LoRA adapters
|
338 |
return super().from_pretrained(
|
|
|
346 |
token=token,
|
347 |
revision=revision,
|
348 |
use_safetensors=use_safetensors,
|
349 |
+
**kwargs,
|
350 |
)
|
351 |
else: # initializing new adapters
|
352 |
roberta = XLMRobertaModel.from_pretrained(
|
353 |
+
pretrained_model_name_or_path,
|
354 |
+
*model_args,
|
355 |
+
use_flash_attn=config.use_flash_attn,
|
356 |
+
**kwargs,
|
357 |
)
|
358 |
return cls(config, roberta=roberta)
|
359 |
|
|
|
417 |
if isinstance(sentences, str):
|
418 |
sentences = self._task_instructions[task] + sentences
|
419 |
else:
|
420 |
+
sentences = [
|
421 |
+
self._task_instructions[task] + sentence for sentence in sentences
|
422 |
+
]
|
423 |
return self.roberta.encode(
|
424 |
sentences, *args, adapter_mask=adapter_mask, **kwargs
|
425 |
)
|
426 |
+
|