Fix bias logic to enable QLoRA finetuning
#5
by
winglian
- opened
- modeling_jamba.py +10 -4
modeling_jamba.py
CHANGED
@@ -943,10 +943,16 @@ class JambaMambaMixer(nn.Module):
|
|
943 |
# in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized
|
944 |
# linear layers, and requires to call the forward pass directly.
|
945 |
# The original code here was: ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)```
|
946 |
-
|
947 |
-
|
948 |
-
|
949 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
950 |
|
951 |
A = -torch.exp(self.A_log.float())
|
952 |
# 3.c perform the recurrence y β SSM(A, B, C)(x)
|
|
|
943 |
# in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized
|
944 |
# linear layers, and requires to call the forward pass directly.
|
945 |
# The original code here was: ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)```
|
946 |
+
if hasattr(self.dt_proj, "bias")
|
947 |
+
dt_proj_bias = self.dt_proj.bias
|
948 |
+
self.dt_proj.bias = None
|
949 |
+
discrete_time_step = self.dt_proj(time_step).transpose(1, 2)
|
950 |
+
self.dt_proj.bias = dt_proj_bias
|
951 |
+
else:
|
952 |
+
dt_proj_bias = self.dt_proj.base_layer.bias
|
953 |
+
self.dt_proj.base_layer.bias = None
|
954 |
+
discrete_time_step = self.dt_proj(time_step).transpose(1, 2)
|
955 |
+
self.dt_proj.base_layer.bias = dt_proj_bias
|
956 |
|
957 |
A = -torch.exp(self.A_log.float())
|
958 |
# 3.c perform the recurrence y β SSM(A, B, C)(x)
|