jupyterjazz commited on
Commit
3703946
·
1 Parent(s): 851aaca

refactor: stuff

Browse files

Signed-off-by: jupyterjazz <[email protected]>

Files changed (1) hide show
  1. modeling_lora.py +37 -26
modeling_lora.py CHANGED
@@ -14,6 +14,9 @@ from transformers import PretrainedConfig
14
  from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel
15
 
16
 
 
 
 
17
  def initialized_weights(
18
  shape: Tuple[int], num_adaptations: int, init: str = "kaiming"
19
  ) -> torch.Tensor:
@@ -214,7 +217,17 @@ class XLMRobertaLoRA(XLMRobertaModel):
214
  ):
215
  super().__init__(config)
216
 
217
- self._num_adaptations = len(config.lora_adaptations)
 
 
 
 
 
 
 
 
 
 
218
  self._rank = config.lora_rank
219
  self._dropout_p = config.lora_dropout_p
220
  self._alpha = config.lora_alpha
@@ -294,14 +307,20 @@ class XLMRobertaLoRA(XLMRobertaModel):
294
  return self._task_idx
295
 
296
  @current_task.setter
297
- def current_task(self, task_idx: Union[None, int]):
298
  """Set the LoRA that is to be used.
299
  The LoRA is specified by `task_idx`, which may be an integer >= 0,
300
  indexing the available LoRAs. If it is None, no LoRA is used.
301
- :param task_idx: Which LoRA to use
302
  :return:
303
  """
304
- assert task_idx is None or 0 <= task_idx < self._num_adaptations
 
 
 
 
 
 
305
  if self._task_idx != task_idx:
306
  # In this case, we need to update the LoRAs everywhere
307
  self._task_idx = task_idx
@@ -309,9 +328,9 @@ class XLMRobertaLoRA(XLMRobertaModel):
309
  partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
310
  )
311
 
312
- def forward(self, *args, lora_adaptation: Union[None, int] = -1, **kwargs):
313
- if lora_adaptation is None or lora_adaptation >= 0:
314
- self.current_task = lora_adaptation
315
  return super().forward(*args, **kwargs)
316
 
317
  def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
@@ -331,35 +350,27 @@ class XLMRobertaLoRA(XLMRobertaModel):
331
  def encode(
332
  self,
333
  *args,
334
- task: Optional[str] = None,
335
  **kwargs,
336
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
337
  """
338
  Computes sentence embeddings
339
 
340
- task(`str`, *optional*, defaults to None):
341
- Specifies the task for which the encoding is intended. This
342
- controls the use of specialized LoRA adapters that are tuned for specific tasks.
343
- If provided, the corresponding LoRA adapter is enabled, enhancing the model's
344
- performance for that task. If `None` or not provided, LoRA is disabled, and the
345
- model uses its original, general-purpose weights.
 
346
  """
347
- lora_adapter_num = None
348
- if self.config.lora_adaptations:
349
- if task:
350
- if task in self.config.lora_adaptations:
351
- lora_adapter_num = self.config.lora_adaptations.index(task)
352
- else:
353
- raise ValueError(
354
- f"Unsupported task '{task}'. "
355
- f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
356
- )
357
- else:
358
  warnings.warn(
359
  f"Task-specific embeddings are disabled. To enable, specify the `task` "
360
  f"argument with one of the supported tasks: {', '.join(self.config.lora_adaptations)}",
361
  category=UserWarning,
362
  )
363
- self.current_task = lora_adapter_num
364
 
365
  return super().encode(*args, **kwargs)
 
14
  from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel
15
 
16
 
17
+ LORA_NO_UPDATE = '__lora_no_update__'
18
+
19
+
20
  def initialized_weights(
21
  shape: Tuple[int], num_adaptations: int, init: str = "kaiming"
22
  ) -> torch.Tensor:
 
217
  ):
218
  super().__init__(config)
219
 
220
+ self._lora_adaptations = config.lora_adaptations
221
+ if (
222
+ not isinstance(self._lora_adaptations, list)
223
+ or len(self._lora_adaptations) < 1
224
+ ):
225
+ raise ValueError(
226
+ f'`lora_adaptations` must be a list and contain at least one element'
227
+ )
228
+ self._adaptation_map = {
229
+ name: idx for idx, name in enumerate(self._lora_adaptations)
230
+ }
231
  self._rank = config.lora_rank
232
  self._dropout_p = config.lora_dropout_p
233
  self._alpha = config.lora_alpha
 
307
  return self._task_idx
308
 
309
  @current_task.setter
310
+ def current_task(self, task_name: Union[None, str]):
311
  """Set the LoRA that is to be used.
312
  The LoRA is specified by `task_idx`, which may be an integer >= 0,
313
  indexing the available LoRAs. If it is None, no LoRA is used.
314
+ :param task_name: Which LoRA to use
315
  :return:
316
  """
317
+ if task_name and task_name not in self._lora_adaptations:
318
+ raise ValueError(
319
+ f"Unsupported task '{task_name}'. "
320
+ f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
321
+ f"Alternatively, set `task` to `None` if you want to disable LoRA."
322
+ )
323
+ task_idx = self._adaptation_map[task_name] if task_name else None
324
  if self._task_idx != task_idx:
325
  # In this case, we need to update the LoRAs everywhere
326
  self._task_idx = task_idx
 
328
  partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
329
  )
330
 
331
+ def forward(self, *args, task: Union[str, None] = LORA_NO_UPDATE, **kwargs):
332
+ if task != LORA_NO_UPDATE:
333
+ self.current_task = task
334
  return super().forward(*args, **kwargs)
335
 
336
  def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
 
350
  def encode(
351
  self,
352
  *args,
353
+ task: Union[str, None] = LORA_NO_UPDATE,
354
  **kwargs,
355
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
356
  """
357
  Computes sentence embeddings
358
 
359
+ task(`str`, *optional*, defaults to `LORA_NO_UPDATE`):
360
+ Specifies the task for which the encoding is intended. This parameter controls the
361
+ use of specialized LoRA adapters that are tuned for specific tasks. If `task` is set
362
+ to `LORA_NO_UPDATE`, there will be no update to the current task, retaining the
363
+ existing adapter configuration. If `task` is explicitly set to `None`, all LoRA
364
+ adapters are disabled, and the model reverts to its original, general-purpose weights.
365
+ If `task` is set to a specific LoRA adaptation, that adaptation is activated.
366
  """
367
+ if task != LORA_NO_UPDATE:
368
+ if not task:
 
 
 
 
 
 
 
 
 
369
  warnings.warn(
370
  f"Task-specific embeddings are disabled. To enable, specify the `task` "
371
  f"argument with one of the supported tasks: {', '.join(self.config.lora_adaptations)}",
372
  category=UserWarning,
373
  )
374
+ self.current_task = task
375
 
376
  return super().encode(*args, **kwargs)