ArielUW committed on
Commit c170e18 · verified · 1 Parent(s): eb58b6f

Add RNN model with attention

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1 @@
+ {"model_type": "RNN", "vocab_size": 250002, "hidden_size": 256, "output_size": 2, "cell_type": "RNN", "architecture": "SimpleRecurrentNetworkWithAttention"}
modeling.py ADDED
@@ -0,0 +1,301 @@
+
+ import time
+ from datasets import load_dataset
+ import numpy as np
+ import transformers
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import DataLoader, TensorDataset
+
+ class custom_RNNCell(nn.Module):
+     def __init__(self, input_size: int, hidden_size: int, device='cpu'):
+         """
+         Initialize a basic RNN cell with Xavier-initialized weights.
+         :param input_size: Number of input features.
+         :param hidden_size: Number of units in the hidden layer.
+         :param device: Device to place the tensors on.
+         """
+         super(custom_RNNCell, self).__init__()
+         self.hidden_size = hidden_size
+         self.device = device
+
+         # Xavier initialization limits
+         fan_in_Wx = input_size
+         fan_out_Wx = hidden_size
+         limit_Wx = np.sqrt(6 / (fan_in_Wx + fan_out_Wx))
+
+         fan_in_Wh = hidden_size
+         fan_out_Wh = hidden_size
+         limit_Wh = np.sqrt(6 / (fan_in_Wh + fan_out_Wh))
+
+         # Weights and bias as PyTorch Parameters
+         self.Wx = nn.Parameter(torch.empty(input_size, hidden_size, device=device))
+         self.Wh = nn.Parameter(torch.empty(hidden_size, hidden_size, device=device))
+         self.bh = nn.Parameter(torch.zeros(hidden_size, device=device))
+
+         # Initialize using Xavier uniform
+         nn.init.uniform_(self.Wx, -limit_Wx, limit_Wx)
+         nn.init.uniform_(self.Wh, -limit_Wh, limit_Wh)
+
+     def forward(self, input_t: torch.Tensor, h_prev: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass for a basic RNN cell.
+         :param input_t: Input at time step t (batch_size x input_size).
+         :param h_prev: Hidden state from previous time step (batch_size x hidden_size).
+         :return: Updated hidden state.
+         """
+         h_t = torch.tanh(torch.mm(input_t, self.Wx) + torch.mm(h_prev, self.Wh) + self.bh)
+         return h_t
+
+ class custom_GRUCell(nn.Module):
+     def __init__(self, input_size: int, hidden_size: int, device='cpu'):
+         """
+         Initialize a GRU cell with Xavier-initialized weights.
+         :param input_size: Number of input features.
+         :param hidden_size: Number of units in the hidden layer.
+         :param device: The device to run the computations on.
+         """
+         super(custom_GRUCell, self).__init__()
+         self.hidden_size = hidden_size
+         self.device = device
+
+         # Xavier initialization limits
+         fan_in = input_size + hidden_size
+         fan_out = hidden_size
+         limit = (6 / (fan_in + fan_out)) ** 0.5
+
+         # Weight matrices for update gate, reset gate, and candidate hidden state
+         self.Wz = nn.Parameter(torch.empty(input_size + hidden_size, hidden_size, device=device))  # Update gate
+         self.Wr = nn.Parameter(torch.empty(input_size + hidden_size, hidden_size, device=device))  # Reset gate
+         self.Wh = nn.Parameter(torch.empty(input_size + hidden_size, hidden_size, device=device))  # Candidate hidden state
+
+         # Apply Xavier initialization
+         nn.init.uniform_(self.Wz, -limit, limit)
+         nn.init.uniform_(self.Wr, -limit, limit)
+         nn.init.uniform_(self.Wh, -limit, limit)
+
+         # Biases for each gate, initialized to zeros
+         self.bz = nn.Parameter(torch.zeros(hidden_size, device=device))
+         self.br = nn.Parameter(torch.zeros(hidden_size, device=device))
+         self.bh = nn.Parameter(torch.zeros(hidden_size, device=device))
+
+     def forward(self, input_t: torch.Tensor, h_prev: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass for a single GRU cell.
+         :param input_t: Input at time step t (batch_size x input_size).
+         :param h_prev: Hidden state from the previous time step (batch_size x hidden_size).
+         :return: Updated hidden state.
+         """
+         # Concatenate input and previous hidden state
+         concat = torch.cat((input_t, h_prev), dim=1)
+
+         # Update gate
+         z_t = torch.sigmoid(torch.matmul(concat, self.Wz) + self.bz)
+
+         # Reset gate
+         r_t = torch.sigmoid(torch.matmul(concat, self.Wr) + self.br)
+
+         # Candidate hidden state
+         concat_reset = torch.cat((input_t, r_t * h_prev), dim=1)
+         h_hat_t = torch.tanh(torch.matmul(concat_reset, self.Wh) + self.bh)
+
+         # Compute final hidden state
+         h_t = (1 - z_t) * h_prev + z_t * h_hat_t
+
+         return h_t
+
+ class custom_LSTMCell(nn.Module):
+     def __init__(self, input_size: int, hidden_size: int, device='cpu'):
+         super(custom_LSTMCell, self).__init__()
+         self.hidden_size = hidden_size
+         self.device = device
+
+         # Initialize LSTM weight matrices (Xavier) and biases (zeros)
+         self.Wf = nn.Parameter(torch.empty(input_size + hidden_size, hidden_size, device=device))  # Forget gate (W_f)
+         self.Wi = nn.Parameter(torch.empty(input_size + hidden_size, hidden_size, device=device))  # Input gate (W_i)
+         self.Wc = nn.Parameter(torch.empty(input_size + hidden_size, hidden_size, device=device))  # Candidate cell state (W_c~)
+         self.Wo = nn.Parameter(torch.empty(input_size + hidden_size, hidden_size, device=device))  # Output gate (W_o)
+
+         # Apply Xavier initialization
+         nn.init.xavier_uniform_(self.Wf)
+         nn.init.xavier_uniform_(self.Wi)
+         nn.init.xavier_uniform_(self.Wc)
+         nn.init.xavier_uniform_(self.Wo)
+
+         # Initialize biases
+         self.bf = nn.Parameter(torch.zeros(hidden_size, device=device))  # Forget gate bias (b_f)
+         self.bi = nn.Parameter(torch.zeros(hidden_size, device=device))  # Input gate bias (b_i)
+         self.bc = nn.Parameter(torch.zeros(hidden_size, device=device))  # Candidate state bias (b_c~)
+         self.bo = nn.Parameter(torch.zeros(hidden_size, device=device))  # Output gate bias (b_o)
+
+         # Initialize forget gate bias to a positive value to help with training
+         nn.init.constant_(self.bf, 1.0)
+
+     def forward(self, x_t: torch.Tensor, h_prev: torch.Tensor, c_prev: torch.Tensor) -> tuple:
+         """
+         Forward pass for a single LSTM cell.
+         :param x_t: Input at the current time step (batch_size x input_size).
+         :param h_prev: Previous hidden state (batch_size x hidden_size).
+         :param c_prev: Previous cell state (batch_size x hidden_size).
+         :return: Tuple of new hidden state (h_t) and new cell state (c_t).
+         """
+         # Concatenate input and previous hidden state (x_t and h_{t-1})
+         concat = torch.cat((x_t, h_prev), dim=1)
+
+         # Forget gate: decides what to remove from the cell state
+         f_t = torch.sigmoid(torch.matmul(concat, self.Wf) + self.bf)  # Forget gate (σ) -> f_t
+
+         # Input gate: decides what to add to the cell state
+         i_t = torch.sigmoid(torch.matmul(concat, self.Wi) + self.bi)  # Input gate (σ) -> i_t
+
+         # Candidate cell state: new information to potentially add to the cell state
+         c_hat_t = torch.tanh(torch.matmul(concat, self.Wc) + self.bc)  # Candidate cell state (tanh) -> c~_t
+
+         # Update cell state: new cell state based on the previous state and the gates
+         c_t = f_t * c_prev + i_t * c_hat_t  # Cell state update -> c_t
+
+         # Output gate: decides what the next hidden state should be
+         o_t = torch.sigmoid(torch.matmul(concat, self.Wo) + self.bo)  # Output gate (σ) -> o_t
+
+         # New hidden state: based on the cell state and the output gate
+         h_t = o_t * torch.tanh(c_t)  # Hidden state update -> h_t
+
+         # Return the new hidden state (h_t) and cell state (c_t)
+         return h_t, c_t
+
+ class RecurrentLayer(nn.Module):
+     def __init__(self, input_size: int, hidden_size: int, cell_type: str = 'RNN', device='cpu'):
+         super(RecurrentLayer, self).__init__()
+         self.hidden_size = hidden_size
+         self.device = device
+         self.cell_type = cell_type
+
+         # Initialize the appropriate cell type. forward() below iterates one time step
+         # at a time, so every branch uses a single-step cell.
+         if cell_type == 'RNN':
+             self.cell = nn.RNNCell(input_size, hidden_size)
+         elif cell_type == 'custom_RNN':
+             self.cell = custom_RNNCell(input_size, hidden_size, device)
+         elif cell_type == 'GRU':
+             self.cell = nn.GRUCell(input_size, hidden_size)
+         elif cell_type == 'custom_GRU':
+             self.cell = custom_GRUCell(input_size, hidden_size, device)
+         elif cell_type == 'LSTM':
+             self.cell = nn.LSTMCell(input_size, hidden_size)
+         elif cell_type == 'custom_LSTM':
+             self.cell = custom_LSTMCell(input_size, hidden_size, device)
+         else:
+             raise ValueError("Unsupported cell type")
+
+     def forward(self, inputs: torch.Tensor) -> tuple:
+         """
+         Forward pass through the recurrent layer for a sequence of inputs.
+         The same cell (shared weights) is run over the sequence in both directions.
+         Returns a tuple of (output, last_hidden_state) to match PyTorch's interface.
+         """
+         batch_size, seq_len, _ = inputs.shape
+         device = inputs.device  # keep the hidden states on the same device as the inputs
+
+         # Initialize hidden states
+         h_forward = torch.zeros(batch_size, self.hidden_size, device=device)
+         h_backward = torch.zeros(batch_size, self.hidden_size, device=device)
+
+         needs_cell_state = self.cell_type in ('LSTM', 'custom_LSTM')
+         if needs_cell_state:
+             c_forward = torch.zeros(batch_size, self.hidden_size, device=device)
+             c_backward = torch.zeros(batch_size, self.hidden_size, device=device)
+
+         # Lists to store outputs for both directions
+         forward_outputs = []
+         backward_outputs = []
+
+         # Forward pass
+         h = h_forward
+         c = c_forward if needs_cell_state else None
+         for t in range(seq_len):
+             if self.cell_type == 'custom_LSTM':
+                 h, c = self.cell(inputs[:, t], h, c)
+             elif self.cell_type == 'LSTM':
+                 h, c = self.cell(inputs[:, t], (h, c))  # nn.LSTMCell expects the state as a (h, c) tuple
+             else:
+                 h = self.cell(inputs[:, t], h)
+             forward_outputs.append(h)
+
+         # Backward pass
+         h = h_backward
+         c = c_backward if needs_cell_state else None
+         for t in range(seq_len - 1, -1, -1):
+             if self.cell_type == 'custom_LSTM':
+                 h, c = self.cell(inputs[:, t], h, c)
+             elif self.cell_type == 'LSTM':
+                 h, c = self.cell(inputs[:, t], (h, c))
+             else:
+                 h = self.cell(inputs[:, t], h)
+             backward_outputs.insert(0, h)
+
+         # Stack and concatenate outputs
+         forward_output = torch.stack(forward_outputs, dim=1)    # [batch_size, seq_len, hidden_size]
+         backward_output = torch.stack(backward_outputs, dim=1)  # [batch_size, seq_len, hidden_size]
+         output = torch.cat((forward_output, backward_output), dim=2)  # [batch_size, seq_len, 2*hidden_size]
+
+         # Final hidden state (stacked forward and backward directions)
+         final_hidden = torch.stack([forward_outputs[-1], backward_outputs[-1]], dim=0)  # [2, batch_size, hidden_size]
+
+         return output, final_hidden
+
+ class Attention(nn.Module):
+     def __init__(self, hidden_size):
+         super(Attention, self).__init__()
+         self.W1 = nn.Linear(hidden_size, hidden_size)
+         self.W2 = nn.Linear(hidden_size, hidden_size)
+         self.v = nn.Linear(hidden_size, 1, bias=False)
+
+     def forward(self, hidden, encoder_outputs):
+         # hidden: [batch_size, hidden_size]
+         # encoder_outputs: [batch_size, sequence_len, hidden_size]
+         sequence_len = encoder_outputs.shape[1]
+         hidden = hidden.unsqueeze(1).repeat(1, sequence_len, 1)
+
+         # Additive (Bahdanau-style) scoring over all encoder positions
+         energy = torch.tanh(self.W1(encoder_outputs) + self.W2(hidden))
+         attention = self.v(energy).squeeze(2)  # [batch_size, sequence_len]
+         attention_weights = torch.softmax(attention, dim=1)
+
+         # Apply attention weights to encoder outputs to get the context vector
+         context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1)
+         return context, attention_weights
+
+ class SimpleRecurrentNetworkWithAttention(nn.Module):
+     def __init__(self, input_size, hidden_size, output_size, cell_type='RNN'):
+         super(SimpleRecurrentNetworkWithAttention, self).__init__()
+
+         self.embedding = nn.Embedding(input_size, hidden_size)
+         self.attention = Attention(hidden_size * 2)  # hidden_size * 2 because the encoder is bidirectional
+         self.cell_type = cell_type
+
+         if cell_type == 'RNN':
+             self.cell = nn.RNN(hidden_size, hidden_size, batch_first=True, bidirectional=True)
+         elif cell_type == 'custom_RNN':
+             self.cell = RecurrentLayer(hidden_size, hidden_size, cell_type="custom_RNN")
+         elif cell_type == 'GRU':
+             self.cell = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
+         elif cell_type == 'custom_GRU':
+             self.cell = RecurrentLayer(hidden_size, hidden_size, cell_type="custom_GRU")
+         elif cell_type == 'LSTM':
+             self.cell = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True)
+         elif cell_type == 'custom_LSTM':
+             self.cell = RecurrentLayer(hidden_size, hidden_size, cell_type="custom_LSTM")
+         else:
+             raise ValueError("Unsupported cell type. Choose from 'RNN', 'custom_RNN', 'GRU', 'custom_GRU', 'LSTM' or 'custom_LSTM'.")
+
+         self.fc = nn.Linear(hidden_size * 2, output_size)  # hidden_size * 2 for bidirectional
+
+     def forward(self, inputs):
+         embedded = self.embedding(inputs)
+         rnn_output, hidden = self.cell(embedded)
+
+         if isinstance(hidden, tuple):  # nn.LSTM returns (hidden_state, cell_state)
+             hidden = hidden[0]
+
+         # Since the encoder is bidirectional, concatenate the last forward and backward hidden states
+         hidden = torch.cat((hidden[-2], hidden[-1]), dim=1)
+
+         # Apply attention over the encoder outputs, queried by the concatenated hidden state
+         context, attention_weights = self.attention(hidden, rnn_output)
+
+         # Pass the context vector to the fully connected layer
+         output = self.fc(context)
+
+         return output, attention_weights
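As a rough usage sketch (not part of the commit), the forward pass takes a batch of token IDs and returns logits plus the attention weights over the sequence; the shapes follow from the code above, and the dummy inputs are illustrative only:

import torch
from modeling import SimpleRecurrentNetworkWithAttention

model = SimpleRecurrentNetworkWithAttention(
    input_size=250002, hidden_size=256, output_size=2, cell_type="RNN"
)

token_ids = torch.randint(0, 250002, (4, 12))   # [batch_size=4, seq_len=12], dummy IDs
logits, attention_weights = model(token_ids)

print(logits.shape)             # torch.Size([4, 2])
print(attention_weights.shape)  # torch.Size([4, 12])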
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3d75a21bee6da43507a4799f32ffc505bccaca9c1d9f514341ec53c5d2a95e9e
+ size 259167710
special_tokens_map.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "bos_token": "<s>",
+   "cls_token": "<s>",
+   "eos_token": "</s>",
+   "mask_token": {
+     "content": "<mask>",
+     "lstrip": true,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": "<pad>",
+   "sep_token": "</s>",
+   "unk_token": "<unk>"
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:affcfb1f45c4b14a70a6589c3d153b430ed4309e5a6613a88dab64d5a923a5d6
+ size 17082925
tokenizer_config.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "bos_token": "<s>",
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "<s>",
+   "eos_token": "</s>",
+   "mask_token": {
+     "__type": "AddedToken",
+     "content": "<mask>",
+     "lstrip": true,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "model_max_length": 512,
+   "pad_token": "<pad>",
+   "sep_token": "</s>",
+   "tokenizer_class": "XLMRobertaTokenizer",
+   "unk_token": "<unk>"
+ }
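For completeness, a small sketch (mine, not the commit's) of preparing inputs with the tokenizer shipped here; tokenizer_config.json names XLMRobertaTokenizer with model_max_length 512, whose vocabulary size matches the 250002 in config.json, and loading from the current directory is my assumption:

from transformers import AutoTokenizer

# Assumption: the tokenizer files above have been downloaded into the working directory.
tokenizer = AutoTokenizer.from_pretrained(".")

batch = tokenizer(
    ["an example sentence", "another one"],
    padding=True,
    truncation=True,
    max_length=512,   # model_max_length from tokenizer_config.json
    return_tensors="pt",
)
print(batch["input_ids"].shape)  # [batch_size, seq_len]; these IDs are what the model's forward() expects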