phuonglk commited on
Commit
7207e6d
·
verified ·
1 Parent(s): 82b68d6

[Fix bug] TypeError: argument of type 'XLMRobertaFlashConfig' is not iterable

Browse files
Files changed (1) hide show
  1. modeling_lora.py +15 -13
modeling_lora.py CHANGED
@@ -11,16 +11,12 @@ from torch.nn import Parameter
11
  from torch.nn import functional as F
12
  from transformers import PretrainedConfig
13
 
14
- from .rotary import RotaryEmbedding
15
- from .mlp import FusedMLP, Mlp
16
- from .xlm_padding import index_first_axis_residual, pad_input, unpad_input
17
- from .stochastic_depth import stochastic_depth
18
- from .mha import MHA
19
- from .block import Block
20
  from .configuration_xlm_roberta import XLMRobertaFlashConfig
21
- from .embedding import XLMRobertaEmbeddings
22
- from .modeling_xlm_roberta import (XLMRobertaFlashConfig, XLMRobertaModel,
23
- XLMRobertaPreTrainedModel)
 
 
24
 
25
 
26
  def initialized_weights(
@@ -336,7 +332,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
336
  **kwargs,
337
  ):
338
  for key in list(kwargs.keys()):
339
- if key in config:
340
  config.update({key: kwargs.pop(key)})
341
  if config.load_trained_adapters: # checkpoint already contains LoRA adapters
342
  return super().from_pretrained(
@@ -350,11 +346,14 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
350
  token=token,
351
  revision=revision,
352
  use_safetensors=use_safetensors,
353
- **kwargs
354
  )
355
  else: # initializing new adapters
356
  roberta = XLMRobertaModel.from_pretrained(
357
- pretrained_model_name_or_path, *model_args, use_flash_attn=config.use_flash_attn, **kwargs
 
 
 
358
  )
359
  return cls(config, roberta=roberta)
360
 
@@ -418,7 +417,10 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
418
  if isinstance(sentences, str):
419
  sentences = self._task_instructions[task] + sentences
420
  else:
421
- sentences = [self._task_instructions[task] + sentence for sentence in sentences]
 
 
422
  return self.roberta.encode(
423
  sentences, *args, adapter_mask=adapter_mask, **kwargs
424
  )
 
 
11
  from torch.nn import functional as F
12
  from transformers import PretrainedConfig
13
 
 
 
 
 
 
 
14
  from .configuration_xlm_roberta import XLMRobertaFlashConfig
15
+ from .modeling_xlm_roberta import (
16
+ XLMRobertaFlashConfig,
17
+ XLMRobertaModel,
18
+ XLMRobertaPreTrainedModel,
19
+ )
20
 
21
 
22
  def initialized_weights(
 
332
  **kwargs,
333
  ):
334
  for key in list(kwargs.keys()):
335
+ if key in config.to_dict():
336
  config.update({key: kwargs.pop(key)})
337
  if config.load_trained_adapters: # checkpoint already contains LoRA adapters
338
  return super().from_pretrained(
 
346
  token=token,
347
  revision=revision,
348
  use_safetensors=use_safetensors,
349
+ **kwargs,
350
  )
351
  else: # initializing new adapters
352
  roberta = XLMRobertaModel.from_pretrained(
353
+ pretrained_model_name_or_path,
354
+ *model_args,
355
+ use_flash_attn=config.use_flash_attn,
356
+ **kwargs,
357
  )
358
  return cls(config, roberta=roberta)
359
 
 
417
  if isinstance(sentences, str):
418
  sentences = self._task_instructions[task] + sentences
419
  else:
420
+ sentences = [
421
+ self._task_instructions[task] + sentence for sentence in sentences
422
+ ]
423
  return self.roberta.encode(
424
  sentences, *args, adapter_mask=adapter_mask, **kwargs
425
  )
426
+