jupyterjazz commited on
Commit
a2b7c86
·
1 Parent(s): c6a5a4d

refactor: restructure the class

Browse files

Signed-off-by: jupyterjazz <[email protected]>

Files changed (1) hide show
  1. modeling_lora.py +6 -13
modeling_lora.py CHANGED
@@ -12,7 +12,6 @@ from transformers import PretrainedConfig
12
  from .modeling_xlm_roberta import (
13
  XLMRobertaFlashConfig,
14
  XLMRobertaModel,
15
- XLMRobertaPreTrainedModel,
16
  )
17
 
18
 
@@ -209,19 +208,13 @@ class LoRAParametrization(nn.Module):
209
  layer.current_task = task_idx
210
 
211
 
212
- class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
213
  def __init__(
214
  self,
215
  config: XLMRobertaFlashConfig,
216
- roberta: Optional[XLMRobertaModel] = None,
217
  ):
218
  super().__init__(config)
219
 
220
- if roberta is None:
221
- self.roberta = XLMRobertaModel(config)
222
- else:
223
- self.roberta = roberta
224
-
225
  self._num_adaptations = len(config.lora_adaptations)
226
  self._rank = config.lora_rank
227
  self._dropout_p = config.lora_dropout_p
@@ -238,6 +231,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
238
  # By default, we select the first LoRA
239
  self.current_task = 0
240
 
 
241
  @property
242
  def main_params_trainable(self):
243
  return self._main_params_trainable
@@ -273,15 +267,14 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
273
  config = XLMRobertaFlashConfig.from_pretrained(
274
  pretrained_model_name_or_path, *model_args, **kwargs
275
  )
 
276
  if config.load_trained_adapters:
277
  return super().from_pretrained(
278
  pretrained_model_name_or_path, *model_args, **kwargs
279
  )
280
  else:
281
- roberta = XLMRobertaModel.from_pretrained(
282
- pretrained_model_name_or_path, *model_args, **kwargs
283
- )
284
- return cls(config, roberta=roberta)
285
 
286
  def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
287
  self.apply(
@@ -320,7 +313,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
320
  def forward(self, *args, lora_adaptation: Union[None, int] = -1, **kwargs):
321
  if lora_adaptation is None or lora_adaptation >= 0:
322
  self.current_task = lora_adaptation
323
- return self.roberta(*args, **kwargs)
324
 
325
  def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
326
  for _, param in self.named_parameters(recurse=recurse):
 
12
  from .modeling_xlm_roberta import (
13
  XLMRobertaFlashConfig,
14
  XLMRobertaModel,
 
15
  )
16
 
17
 
 
208
  layer.current_task = task_idx
209
 
210
 
211
+ class XLMRobertaLoRA(XLMRobertaModel):
212
  def __init__(
213
  self,
214
  config: XLMRobertaFlashConfig,
 
215
  ):
216
  super().__init__(config)
217
 
 
 
 
 
 
218
  self._num_adaptations = len(config.lora_adaptations)
219
  self._rank = config.lora_rank
220
  self._dropout_p = config.lora_dropout_p
 
231
  # By default, we select the first LoRA
232
  self.current_task = 0
233
 
234
+
235
  @property
236
  def main_params_trainable(self):
237
  return self._main_params_trainable
 
267
  config = XLMRobertaFlashConfig.from_pretrained(
268
  pretrained_model_name_or_path, *model_args, **kwargs
269
  )
270
+
271
  if config.load_trained_adapters:
272
  return super().from_pretrained(
273
  pretrained_model_name_or_path, *model_args, **kwargs
274
  )
275
  else:
276
+ torch.set_default_dtype(torch.float16)
277
+ return cls(config)
 
 
278
 
279
  def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
280
  self.apply(
 
313
  def forward(self, *args, lora_adaptation: Union[None, int] = -1, **kwargs):
314
  if lora_adaptation is None or lora_adaptation >= 0:
315
  self.current_task = lora_adaptation
316
+ return super().forward(*args, **kwargs)
317
 
318
  def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
319
  for _, param in self.named_parameters(recurse=recurse):