jupyterjazz Jackmin108 commited on
Commit
7c4a80c
·
verified ·
1 Parent(s): 98c3cd2

lora bugfix (#16)

Browse files

- fix: lora bug (4c504d33aca884a998533b089cb905b597a82467)


Co-authored-by: Jack Min Ong <[email protected]>

Files changed (1) hide show
  1. modeling_lora.py +13 -8
modeling_lora.py CHANGED
@@ -11,7 +11,7 @@ from torch import nn
11
  from torch.nn import Parameter
12
  from transformers import PretrainedConfig
13
 
14
- from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel
15
 
16
 
17
  LORA_NO_UPDATE = '__lora_no_update__'
@@ -210,13 +210,19 @@ class LoRAParametrization(nn.Module):
210
  layer.current_task = task_idx
211
 
212
 
213
- class XLMRobertaLoRA(XLMRobertaModel):
214
  def __init__(
215
  self,
216
  config: XLMRobertaFlashConfig,
 
217
  ):
218
  super().__init__(config)
219
 
 
 
 
 
 
220
  self._lora_adaptations = config.lora_adaptations
221
  if (
222
  not isinstance(self._lora_adaptations, list)
@@ -231,7 +237,6 @@ class XLMRobertaLoRA(XLMRobertaModel):
231
  self._rank = config.lora_rank
232
  self._dropout_p = config.lora_dropout_p
233
  self._alpha = config.lora_alpha
234
-
235
  self._register_lora(
236
  num_adaptations=len(self._lora_adaptations),
237
  rank=self._rank,
@@ -284,9 +289,8 @@ class XLMRobertaLoRA(XLMRobertaModel):
284
  pretrained_model_name_or_path, *model_args, **kwargs
285
  )
286
  else:
287
- dtype = config.torch_dtype if config.torch_dtype else torch.bfloat16
288
- torch.set_default_dtype(dtype)
289
- return cls(config)
290
 
291
  def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
292
  self.apply(
@@ -331,7 +335,8 @@ class XLMRobertaLoRA(XLMRobertaModel):
331
  def forward(self, *args, task: Union[str, None] = LORA_NO_UPDATE, **kwargs):
332
  if task != LORA_NO_UPDATE:
333
  self.current_task = task
334
- return super().forward(*args, **kwargs)
 
335
 
336
  def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
337
  for _, param in self.named_parameters(recurse=recurse):
@@ -373,4 +378,4 @@ class XLMRobertaLoRA(XLMRobertaModel):
373
  )
374
  self.current_task = task
375
 
376
- return super().encode(*args, **kwargs)
 
11
  from torch.nn import Parameter
12
  from transformers import PretrainedConfig
13
 
14
+ from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel, XLMRobertaPreTrainedModel
15
 
16
 
17
  LORA_NO_UPDATE = '__lora_no_update__'
 
210
  layer.current_task = task_idx
211
 
212
 
213
+ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
214
  def __init__(
215
  self,
216
  config: XLMRobertaFlashConfig,
217
+ roberta: Optional[XLMRobertaModel] = None
218
  ):
219
  super().__init__(config)
220
 
221
+ if roberta is None:
222
+ self.roberta = XLMRobertaModel(config)
223
+ else:
224
+ self.roberta = roberta
225
+
226
  self._lora_adaptations = config.lora_adaptations
227
  if (
228
  not isinstance(self._lora_adaptations, list)
 
237
  self._rank = config.lora_rank
238
  self._dropout_p = config.lora_dropout_p
239
  self._alpha = config.lora_alpha
 
240
  self._register_lora(
241
  num_adaptations=len(self._lora_adaptations),
242
  rank=self._rank,
 
289
  pretrained_model_name_or_path, *model_args, **kwargs
290
  )
291
  else:
292
+ roberta = XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
293
+ return cls(config, roberta=roberta)
 
294
 
295
  def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
296
  self.apply(
 
335
  def forward(self, *args, task: Union[str, None] = LORA_NO_UPDATE, **kwargs):
336
  if task != LORA_NO_UPDATE:
337
  self.current_task = task
338
+
339
+ return self.roberta(*args, **kwargs)
340
 
341
  def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
342
  for _, param in self.named_parameters(recurse=recurse):
 
378
  )
379
  self.current_task = task
380
 
381
+ return self.roberta.encode(*args, **kwargs)