FZH1996 commited on
Commit
fe45bc3
·
1 Parent(s): cb2ad99

upload fed-lora

Browse files
Files changed (4) hide show
  1. loralib/__init__.py +4 -0
  2. loralib/layers.py +319 -0
  3. loralib/utils.py +49 -0
  4. setup.py +22 -0
loralib/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name = "lora"
2
+
3
+ from .layers import *
4
+ from .utils import *
loralib/layers.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # ------------------------------------------------------------------------------------------
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ import math
10
+ from typing import Optional, List
11
+
12
+ class LoRALayer():
13
+ def __init__(
14
+ self,
15
+ r: int,
16
+ lora_alpha: int,
17
+ lora_dropout: float,
18
+ merge_weights: bool,
19
+ ):
20
+ self.r = r
21
+ self.lora_alpha = lora_alpha
22
+ # Optional dropout
23
+ if lora_dropout > 0.:
24
+ self.lora_dropout = nn.Dropout(p=lora_dropout)
25
+ else:
26
+ self.lora_dropout = lambda x: x
27
+ # Mark the weight as unmerged
28
+ self.merged = False
29
+ self.merge_weights = merge_weights
30
+
31
+
32
+ class Embedding(nn.Embedding, LoRALayer):
33
+ # LoRA implemented in a dense layer
34
+ def __init__(
35
+ self,
36
+ num_embeddings: int,
37
+ embedding_dim: int,
38
+ r: int = 0,
39
+ lora_alpha: int = 1,
40
+ merge_weights: bool = True,
41
+ **kwargs
42
+ ):
43
+ nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
44
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0,
45
+ merge_weights=merge_weights)
46
+ # Actual trainable parameters
47
+ if r > 0:
48
+ self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
49
+ self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
50
+ self.scaling = self.lora_alpha / self.r
51
+ # Freezing the pre-trained weight matrix
52
+ self.weight.requires_grad = False
53
+ self.reset_parameters()
54
+
55
+ def reset_parameters(self):
56
+ nn.Embedding.reset_parameters(self)
57
+ if hasattr(self, 'lora_A'):
58
+ # initialize A the same way as the default for nn.Linear and B to zero
59
+ nn.init.zeros_(self.lora_A)
60
+ nn.init.normal_(self.lora_B)
61
+
62
+ def train(self, mode: bool = True):
63
+ nn.Embedding.train(self, mode)
64
+ if mode:
65
+ if self.merge_weights and self.merged:
66
+ # Make sure that the weights are not merged
67
+ if self.r > 0:
68
+ self.weight.data -= (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
69
+ self.merged = False
70
+ else:
71
+ if self.merge_weights and not self.merged:
72
+ # Merge the weights and mark it
73
+ if self.r > 0:
74
+ self.weight.data += (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
75
+ self.merged = True
76
+
77
+ def forward(self, x: torch.Tensor):
78
+ if self.r > 0 and not self.merged:
79
+ result = nn.Embedding.forward(self, x)
80
+ if self.r > 0:
81
+ after_A = F.embedding(
82
+ x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm,
83
+ self.norm_type, self.scale_grad_by_freq, self.sparse
84
+ )
85
+ result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
86
+ return result
87
+ else:
88
+ return nn.Embedding.forward(self, x)
89
+
90
+
91
+ class Linear(nn.Linear, LoRALayer):
92
+ # LoRA implemented in a dense layer
93
+ def __init__(
94
+ self,
95
+ in_features: int,
96
+ out_features: int,
97
+ r: int = 0,
98
+ lora_alpha: int = 1,
99
+ lora_dropout: float = 0.,
100
+ fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
101
+ merge_weights: bool = True,
102
+ **kwargs
103
+ ):
104
+ nn.Linear.__init__(self, in_features, out_features, **kwargs)
105
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
106
+ merge_weights=merge_weights)
107
+
108
+ self.fan_in_fan_out = fan_in_fan_out
109
+ # Actual trainable parameters
110
+ if r > 0:
111
+ self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
112
+ self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
113
+ self.scaling = self.lora_alpha / self.r
114
+ # Freezing the pre-trained weight matrix
115
+ self.weight.requires_grad = False
116
+ self.reset_parameters()
117
+ if fan_in_fan_out:
118
+ self.weight.data = self.weight.data.transpose(0, 1)
119
+
120
+ def reset_parameters(self):
121
+ nn.Linear.reset_parameters(self)
122
+ if hasattr(self, 'lora_A'):
123
+ # initialize A the same way as the default for nn.Linear and B to zero
124
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
125
+ nn.init.zeros_(self.lora_B)
126
+
127
+ def train(self, mode: bool = True):
128
+ def T(w):
129
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
130
+ nn.Linear.train(self, mode)
131
+ if mode:
132
+ if self.merge_weights and self.merged:
133
+ # Make sure that the weights are not merged
134
+ if self.r > 0:
135
+ self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
136
+ self.merged = False
137
+ else:
138
+ if self.merge_weights and not self.merged:
139
+ # Merge the weights and mark it
140
+ if self.r > 0:
141
+ self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
142
+ self.merged = True
143
+
144
+ def forward(self, x: torch.Tensor):
145
+ def T(w):
146
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
147
+ if self.r > 0 and not self.merged:
148
+ result = F.linear(x, T(self.weight), bias=self.bias)
149
+ if self.r > 0:
150
+ result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
151
+ return result
152
+ else:
153
+ return F.linear(x, T(self.weight), bias=self.bias)
154
+
155
+
156
+ class MergedLinear(nn.Linear, LoRALayer):
157
+ # LoRA implemented in a dense layer
158
+ def __init__(
159
+ self,
160
+ in_features: int,
161
+ out_features: int,
162
+ r: int = 0,
163
+ lora_alpha: int = 1,
164
+ lora_dropout: float = 0.,
165
+ enable_lora: List[bool] = [False],
166
+ fan_in_fan_out: bool = False,
167
+ merge_weights: bool = True,
168
+ **kwargs
169
+ ):
170
+ nn.Linear.__init__(self, in_features, out_features, **kwargs)
171
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
172
+ merge_weights=merge_weights)
173
+ assert out_features % len(enable_lora) == 0, \
174
+ 'The length of enable_lora must divide out_features'
175
+ self.enable_lora = enable_lora
176
+ self.fan_in_fan_out = fan_in_fan_out
177
+ # Actual trainable parameters
178
+ if r > 0 and any(enable_lora):
179
+ self.lora_A = nn.Parameter(
180
+ self.weight.new_zeros((r * sum(enable_lora), in_features)))
181
+ self.lora_B = nn.Parameter(
182
+ self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
183
+ ) # weights for Conv1D with groups=sum(enable_lora)
184
+ self.scaling = self.lora_alpha / self.r
185
+ # Freezing the pre-trained weight matrix
186
+ self.weight.requires_grad = False
187
+ # Compute the indices
188
+ self.lora_ind = self.weight.new_zeros(
189
+ (out_features, ), dtype=torch.bool
190
+ ).view(len(enable_lora), -1)
191
+ self.lora_ind[enable_lora, :] = True
192
+ self.lora_ind = self.lora_ind.view(-1)
193
+ self.reset_parameters()
194
+ if fan_in_fan_out:
195
+ self.weight.data = self.weight.data.transpose(0, 1)
196
+
197
+ def reset_parameters(self):
198
+ nn.Linear.reset_parameters(self)
199
+ if hasattr(self, 'lora_A'):
200
+ # initialize A the same way as the default for nn.Linear and B to zero
201
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
202
+ nn.init.zeros_(self.lora_B)
203
+
204
+ def zero_pad(self, x):
205
+ result = x.new_zeros((*x.shape[:-1], self.out_features))
206
+ result = result.view(-1, self.out_features)
207
+ result[:, self.lora_ind] = x.reshape(
208
+ -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
209
+ )
210
+ return result.view((*x.shape[:-1], self.out_features))
211
+
212
+ def train(self, mode: bool = True):
213
+ def T(w):
214
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
215
+ nn.Linear.train(self, mode)
216
+ print(f"lora.train, scaling = {self.scaling}, mode = {mode}, merge_weights = {self.merge_weights}, merged = {self.merged}")
217
+ if mode:
218
+ if self.merge_weights and self.merged:
219
+ # Make sure that the weights are not merged
220
+ if self.r > 0 and any(self.enable_lora):
221
+ delta_w = F.conv1d(
222
+ self.lora_A.data.unsqueeze(0),
223
+ self.lora_B.data.unsqueeze(-1),
224
+ groups=sum(self.enable_lora)
225
+ ).squeeze(0)
226
+ self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
227
+ self.merged = False
228
+ else:
229
+ if self.merge_weights and not self.merged:
230
+ # Merge the weights and mark it
231
+ if self.r > 0 and any(self.enable_lora):
232
+ delta_w = F.conv1d(
233
+ self.lora_A.data.unsqueeze(0),
234
+ self.lora_B.data.unsqueeze(-1),
235
+ groups=sum(self.enable_lora)
236
+ ).squeeze(0)
237
+ self.weight.data += self.zero_pad(T(delta_w * self.scaling))
238
+ self.merged = True
239
+
240
+ def forward(self, x: torch.Tensor):
241
+ def T(w):
242
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
243
+ if self.merged:
244
+ return F.linear(x, T(self.weight), bias=self.bias)
245
+ else:
246
+ result = F.linear(x, T(self.weight), bias=self.bias)
247
+ if self.r > 0:
248
+ after_A = F.linear(self.lora_dropout(x), self.lora_A)
249
+ after_B = F.conv1d(
250
+ after_A.transpose(-2, -1),
251
+ self.lora_B.unsqueeze(-1),
252
+ groups=sum(self.enable_lora)
253
+ ).transpose(-2, -1)
254
+ result += self.zero_pad(after_B) * self.scaling
255
+ return result
256
+
257
+
258
+ class ConvLoRA(nn.Module, LoRALayer):
259
+ def __init__(self, conv_module, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
260
+ super(ConvLoRA, self).__init__()
261
+ self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs)
262
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
263
+ assert isinstance(kernel_size, int)
264
+ # Actual trainable parameters
265
+ if r > 0:
266
+ self.lora_A = nn.Parameter(
267
+ self.conv.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
268
+ )
269
+ self.lora_B = nn.Parameter(
270
+ self.conv.weight.new_zeros((out_channels//self.conv.groups*kernel_size, r*kernel_size))
271
+ )
272
+ self.scaling = self.lora_alpha / self.r
273
+ # Freezing the pre-trained weight matrix
274
+ self.conv.weight.requires_grad = False
275
+ self.reset_parameters()
276
+ self.merged = False
277
+
278
+ def reset_parameters(self):
279
+ self.conv.reset_parameters()
280
+ if hasattr(self, 'lora_A'):
281
+ # initialize A the same way as the default for nn.Linear and B to zero
282
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
283
+ nn.init.zeros_(self.lora_B)
284
+
285
+ def train(self, mode=True):
286
+ super(ConvLoRA, self).train(mode)
287
+ if mode:
288
+ if self.merge_weights and self.merged:
289
+ # Make sure that the weights are not merged
290
+ self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
291
+ self.merged = False
292
+ else:
293
+ if self.merge_weights and not self.merged:
294
+ # Merge the weights and mark it
295
+ self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
296
+ self.merged = True
297
+
298
+ def forward(self, x):
299
+ if self.r > 0 and not self.merged:
300
+ return self.conv._conv_forward(
301
+ x,
302
+ self.conv.weight + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling,
303
+ self.conv.bias
304
+ )
305
+ return self.conv(x)
306
+
307
+ class Conv2d(ConvLoRA):
308
+ def __init__(self, *args, **kwargs):
309
+ super(Conv2d, self).__init__(nn.Conv2d, *args, **kwargs)
310
+
311
+ class Conv1d(ConvLoRA):
312
+ def __init__(self, *args, **kwargs):
313
+ super(Conv1d, self).__init__(nn.Conv1d, *args, **kwargs)
314
+
315
+ # Can Extend to other ones like this
316
+
317
+ class Conv3d(ConvLoRA):
318
+ def __init__(self, *args, **kwargs):
319
+ super(Conv3d, self).__init__(nn.Conv3d, *args, **kwargs)
loralib/utils.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # ------------------------------------------------------------------------------------------
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from typing import Dict
9
+
10
+ from .layers import LoRALayer
11
+
12
+
13
+ def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
14
+ for n, p in model.named_parameters():
15
+ if 'lora_' not in n:
16
+ p.requires_grad = False
17
+ if bias == 'none':
18
+ return
19
+ elif bias == 'all':
20
+ for n, p in model.named_parameters():
21
+ if 'bias' in n:
22
+ p.requires_grad = True
23
+ elif bias == 'lora_only':
24
+ for m in model.modules():
25
+ if isinstance(m, LoRALayer) and \
26
+ hasattr(m, 'bias') and \
27
+ m.bias is not None:
28
+ m.bias.requires_grad = True
29
+ else:
30
+ raise NotImplementedError
31
+
32
+
33
+ def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
34
+ my_state_dict = model.state_dict()
35
+ if bias == 'none':
36
+ return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k}
37
+ elif bias == 'all':
38
+ return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k}
39
+ elif bias == 'lora_only':
40
+ to_return = {}
41
+ for k in my_state_dict:
42
+ if 'lora_' in k:
43
+ to_return[k] = my_state_dict[k]
44
+ bias_name = k.split('lora_')[0]+'bias'
45
+ if bias_name in my_state_dict:
46
+ to_return[bias_name] = my_state_dict[bias_name]
47
+ return to_return
48
+ else:
49
+ raise NotImplementedError
setup.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import setuptools
2
+
3
+ with open("README.md", "r", encoding="utf-8") as fh:
4
+ long_description = fh.read()
5
+
6
+ setuptools.setup(
7
+ name="loralib",
8
+ version="0.1.0",
9
+ author="Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen",
10
+ author_email="[email protected]",
11
+ description="PyTorch implementation of low-rank adaptation (LoRA), a parameter-efficient approach to adapt a large pre-trained deep learning model which obtains performance on-par with full fine-tuning.",
12
+ long_description=long_description,
13
+ long_description_content_type="text/markdown",
14
+ url="https://github.com/microsoft/LoRA",
15
+ packages=setuptools.find_packages(),
16
+ classifiers=[
17
+ "Programming Language :: Python :: 3",
18
+ "License :: OSI Approved :: MIT License",
19
+ "Operating System :: OS Independent",
20
+ ],
21
+ python_requires='>=3.6',
22
+ )