Fix bias logic to enable QLoRA finetuning

#5
by winglian - opened
Files changed (1) hide show
  1. modeling_jamba.py +10 -4
modeling_jamba.py CHANGED
@@ -943,10 +943,16 @@ class JambaMambaMixer(nn.Module):
943
  # in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized
944
  # linear layers, and requires to call the forward pass directly.
945
  # The original code here was: ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)```
946
- dt_proj_bias = self.dt_proj.bias
947
- self.dt_proj.bias = None
948
- discrete_time_step = self.dt_proj(time_step).transpose(1, 2)
949
- self.dt_proj.bias = dt_proj_bias
 
 
 
 
 
 
950
 
951
  A = -torch.exp(self.A_log.float())
952
  # 3.c perform the recurrence y ← SSM(A, B, C)(x)
 
943
  # in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized
944
  # linear layers, and requires to call the forward pass directly.
945
  # The original code here was: ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)```
946
+ if hasattr(self.dt_proj, "bias")
947
+ dt_proj_bias = self.dt_proj.bias
948
+ self.dt_proj.bias = None
949
+ discrete_time_step = self.dt_proj(time_step).transpose(1, 2)
950
+ self.dt_proj.bias = dt_proj_bias
951
+ else:
952
+ dt_proj_bias = self.dt_proj.base_layer.bias
953
+ self.dt_proj.base_layer.bias = None
954
+ discrete_time_step = self.dt_proj(time_step).transpose(1, 2)
955
+ self.dt_proj.base_layer.bias = dt_proj_bias
956
 
957
  A = -torch.exp(self.A_log.float())
958
  # 3.c perform the recurrence y ← SSM(A, B, C)(x)