import time
from datasets import load_dataset
import numpy as np
import transformers
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


class custom_RNNCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, device='cpu'):
        # Initialize a basic RNN cell with Xavier-initialized weights.
        # :param input_size: Number of input features.
        # :param hidden_size: Number of units in the hidden layer.
        # :param device: Device to place the tensors on.
        super(custom_RNNCell, self).__init__()
        self.hidden_size = hidden_size
        self.device = device

        # Xavier initialization limits
        fan_in_Wx = input_size
        fan_out_Wx = hidden_size
        limit_Wx = np.sqrt(6 / (fan_in_Wx + fan_out_Wx))

        fan_in_Wh = hidden_size
        fan_out_Wh = hidden_size
        limit_Wh = np.sqrt(6 / (fan_in_Wh + fan_out_Wh))

        # Weights and bias as PyTorch Parameters
        self.Wx = nn.Parameter(torch.empty(input_size, hidden_size, device=device))
        self.Wh = nn.Parameter(torch.empty(hidden_size, hidden_size, device=device))
        self.bh = nn.Parameter(torch.zeros(hidden_size, device=device))

        # Initialize using Xavier uniform
        nn.init.uniform_(self.Wx, -limit_Wx, limit_Wx)
        nn.init.uniform_(self.Wh, -limit_Wh, limit_Wh)

    def forward(self, input_t: torch.Tensor, h_prev: torch.Tensor) -> torch.Tensor:
        # Forward pass for a basic RNN cell.
        # :param input_t: Input at time step t (batch_size x input_size).
        # :param h_prev: Hidden state from previous time step (batch_size x hidden_size).
        # :return: Updated hidden state.
        h_t = torch.tanh(torch.mm(input_t, self.Wx) + torch.mm(h_prev, self.Wh) + self.bh)
        return h_t
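
# A minimal sanity check for custom_RNNCell. This is a hedged sketch: the batch
# size, input size, and hidden size below are illustrative values, not settings
# taken from the original experiments, and the _demo_* name is new.
def _demo_custom_rnn_cell():
    batch_size, input_size, hidden_size = 4, 8, 16
    cell = custom_RNNCell(input_size, hidden_size)
    x_t = torch.randn(batch_size, input_size)
    h_prev = torch.zeros(batch_size, hidden_size)
    h_t = cell(x_t, h_prev)
    # tanh keeps every hidden unit in [-1, 1]
    assert h_t.shape == (batch_size, hidden_size)
    assert h_t.abs().max() <= 1.0
    return h_t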

class custom_GRUCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, device='cpu'):
        # Initialize a GRU cell with Xavier-initialized weights.
        # :param input_size: Number of input features.
        # :param hidden_size: Number of units in the hidden layer.
        # :param device: The device to run the computations on.
        super(custom_GRUCell, self).__init__()
        self.hidden_size = hidden_size
        self.device = device

        # Xavier initialization limits
        fan_in = input_size + hidden_size
        fan_out = hidden_size
        limit = (6 / (fan_in + fan_out)) ** 0.5

        # Weight matrices for update gate, reset gate, and candidate hidden state
        self.Wz = nn.Parameter(torch.empty(input_size + hidden_size, hidden_size, device=device))  # Update gate
        self.Wr = nn.Parameter(torch.empty(input_size + hidden_size, hidden_size, device=device))  # Reset gate
        self.Wh = nn.Parameter(torch.empty(input_size + hidden_size, hidden_size, device=device))  # Candidate hidden state

        # Apply Xavier initialization
        nn.init.uniform_(self.Wz, -limit, limit)
        nn.init.uniform_(self.Wr, -limit, limit)
        nn.init.uniform_(self.Wh, -limit, limit)

        # Biases for each gate, initialized to zeros
        self.bz = nn.Parameter(torch.zeros(hidden_size, device=device))
        self.br = nn.Parameter(torch.zeros(hidden_size, device=device))
        self.bh = nn.Parameter(torch.zeros(hidden_size, device=device))

    def forward(self, input_t: torch.Tensor, h_prev: torch.Tensor) -> torch.Tensor:
        # Forward pass for a single GRU cell.
        # :param input_t: Input at time step t (batch_size x input_size).
        # :param h_prev: Hidden state from the previous time step (batch_size x hidden_size).
        # :return: Updated hidden state.

        # Concatenate input and previous hidden state
        concat = torch.cat((input_t, h_prev), dim=1)

        # Update gate
        z_t = torch.sigmoid(torch.matmul(concat, self.Wz) + self.bz)

        # Reset gate
        r_t = torch.sigmoid(torch.matmul(concat, self.Wr) + self.br)

        # Candidate hidden state
        concat_reset = torch.cat((input_t, r_t * h_prev), dim=1)
        h_hat_t = torch.tanh(torch.matmul(concat_reset, self.Wh) + self.bh)

        # Compute final hidden state
        h_t = (1 - z_t) * h_prev + z_t * h_hat_t
        return h_t
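
# A quick sketch of stepping custom_GRUCell over a short sequence, assuming
# random data; all dimensions here are made up for illustration.
def _demo_custom_gru_cell():
    batch_size, seq_len, input_size, hidden_size = 4, 5, 8, 16
    cell = custom_GRUCell(input_size, hidden_size)
    x = torch.randn(batch_size, seq_len, input_size)
    h = torch.zeros(batch_size, hidden_size)
    for t in range(seq_len):
        h = cell(x[:, t], h)  # one gated update per time step
    assert h.shape == (batch_size, hidden_size)
    return h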

class custom_LSTMCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, device='cpu'):
        # Initialize an LSTM cell with Xavier-initialized weights.
        # :param input_size: Number of input features.
        # :param hidden_size: Number of units in the hidden layer.
        # :param device: Device to place the tensors on.
        super(custom_LSTMCell, self).__init__()
        self.hidden_size = hidden_size
        self.device = device

        # Initialize LSTM weights using Xavier initialization
        self.Wf = nn.Parameter(torch.empty(input_size + hidden_size, hidden_size, device=device))  # Forget gate (W_f)
        self.Wi = nn.Parameter(torch.empty(input_size + hidden_size, hidden_size, device=device))  # Input gate (W_i)
        self.Wc = nn.Parameter(torch.empty(input_size + hidden_size, hidden_size, device=device))  # Candidate cell state (W_c~)
        self.Wo = nn.Parameter(torch.empty(input_size + hidden_size, hidden_size, device=device))  # Output gate (W_o)

        # Apply Xavier initialization
        nn.init.xavier_uniform_(self.Wf)
        nn.init.xavier_uniform_(self.Wi)
        nn.init.xavier_uniform_(self.Wc)
        nn.init.xavier_uniform_(self.Wo)

        # Initialize biases
        self.bf = nn.Parameter(torch.zeros(hidden_size, device=device))  # Forget gate bias (b_f)
        self.bi = nn.Parameter(torch.zeros(hidden_size, device=device))  # Input gate bias (b_i)
        self.bc = nn.Parameter(torch.zeros(hidden_size, device=device))  # Candidate state bias (b_c~)
        self.bo = nn.Parameter(torch.zeros(hidden_size, device=device))  # Output gate bias (b_o)

        # Initialize forget gate bias to a positive value to help with training
        nn.init.constant_(self.bf, 1.0)

    def forward(self, x_t: torch.Tensor, h_prev: torch.Tensor, c_prev: torch.Tensor) -> tuple:
        # Forward pass for a single LSTM cell.
        # :param x_t: Input at the current time step (batch_size x input_size).
        # :param h_prev: Previous hidden state (batch_size x hidden_size).
        # :param c_prev: Previous cell state (batch_size x hidden_size).
        # :return: Tuple of new hidden state (h_t) and new cell state (c_t).

        # Concatenate input and previous hidden state (x_t and h_{t-1})
        concat = torch.cat((x_t, h_prev), dim=1)

        # Forget gate: decides what to remove from the cell state
        f_t = torch.sigmoid(torch.matmul(concat, self.Wf) + self.bf)

        # Input gate: decides what to add to the cell state
        i_t = torch.sigmoid(torch.matmul(concat, self.Wi) + self.bi)

        # Candidate cell state: new information to potentially add to the cell state
        c_hat_t = torch.tanh(torch.matmul(concat, self.Wc) + self.bc)

        # Update cell state: combine retained old state and gated new candidate
        c_t = f_t * c_prev + i_t * c_hat_t

        # Output gate: decides what the next hidden state should be
        o_t = torch.sigmoid(torch.matmul(concat, self.Wo) + self.bo)

        # New hidden state: based on cell state and output gate
        h_t = o_t * torch.tanh(c_t)

        # Return the new hidden state (h_t) and cell state (c_t)
        return h_t, c_t
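
# A minimal sketch of a single custom_LSTMCell step with random inputs; the
# sizes are placeholders chosen only to show the expected shapes.
def _demo_custom_lstm_cell():
    batch_size, input_size, hidden_size = 4, 8, 16
    cell = custom_LSTMCell(input_size, hidden_size)
    x_t = torch.randn(batch_size, input_size)
    h_prev = torch.zeros(batch_size, hidden_size)
    c_prev = torch.zeros(batch_size, hidden_size)
    h_t, c_t = cell(x_t, h_prev, c_prev)  # the cell returns both states
    assert h_t.shape == c_t.shape == (batch_size, hidden_size)
    return h_t, c_t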

class RecurrentLayer(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, cell_type: str = 'RNN', device='cpu'):
        super(RecurrentLayer, self).__init__()
        self.hidden_size = hidden_size
        self.device = device
        self.cell_type = cell_type

        # Initialize the appropriate cell type. This layer unrolls a single cell
        # over time itself, so the built-in variants use the *Cell classes rather
        # than the full nn.RNN/nn.GRU/nn.LSTM modules.
        if cell_type == 'RNN':
            self.cell = nn.RNNCell(input_size, hidden_size)
        elif cell_type == 'custom_RNN':
            self.cell = custom_RNNCell(input_size, hidden_size, device)
        elif cell_type == 'GRU':
            self.cell = nn.GRUCell(input_size, hidden_size)
        elif cell_type == 'custom_GRU':
            self.cell = custom_GRUCell(input_size, hidden_size, device)
        elif cell_type == 'LSTM':
            self.cell = nn.LSTMCell(input_size, hidden_size)
        elif cell_type == 'custom_LSTM':
            self.cell = custom_LSTMCell(input_size, hidden_size, device)
        else:
            raise ValueError("Unsupported cell type")

    def forward(self, inputs: torch.Tensor) -> tuple:
        # Forward pass through the recurrent layer for a sequence of inputs.
        # Returns a tuple of (output, last_hidden_state) to match PyTorch's interface.
        batch_size, seq_len, _ = inputs.shape
        device = inputs.device
        needs_cell_state = self.cell_type in ('LSTM', 'custom_LSTM')

        # Lists to store outputs for both directions
        forward_outputs = []
        backward_outputs = []

        # Forward pass (note: the same cell is shared by both directions)
        h = torch.zeros(batch_size, self.hidden_size, device=device)
        c = torch.zeros(batch_size, self.hidden_size, device=device) if needs_cell_state else None
        for t in range(seq_len):
            if self.cell_type == 'custom_LSTM':
                h, c = self.cell(inputs[:, t], h, c)
            elif self.cell_type == 'LSTM':
                h, c = self.cell(inputs[:, t], (h, c))
            else:
                h = self.cell(inputs[:, t], h)
            forward_outputs.append(h)

        # Backward pass
        h = torch.zeros(batch_size, self.hidden_size, device=device)
        c = torch.zeros(batch_size, self.hidden_size, device=device) if needs_cell_state else None
        for t in range(seq_len - 1, -1, -1):
            if self.cell_type == 'custom_LSTM':
                h, c = self.cell(inputs[:, t], h, c)
            elif self.cell_type == 'LSTM':
                h, c = self.cell(inputs[:, t], (h, c))
            else:
                h = self.cell(inputs[:, t], h)
            backward_outputs.insert(0, h)

        # Stack and concatenate outputs
        forward_output = torch.stack(forward_outputs, dim=1)    # [batch_size, seq_len, hidden_size]
        backward_output = torch.stack(backward_outputs, dim=1)  # [batch_size, seq_len, hidden_size]
        output = torch.cat((forward_output, backward_output), dim=2)  # [batch_size, seq_len, 2*hidden_size]

        # Final hidden state: last step of the forward direction and the backward
        # direction's state after processing the whole sequence (i.e. at t = 0),
        # stacked as [2, batch_size, hidden_size] to mirror PyTorch's h_n layout.
        final_hidden = torch.stack([forward_outputs[-1], backward_outputs[0]], dim=0)

        return output, final_hidden


class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.W1 = nn.Linear(hidden_size, hidden_size)
        self.W2 = nn.Linear(hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        # hidden: [batch_size, hidden_size]
        # encoder_outputs: [batch_size, sequence_len, hidden_size]
        sequence_len = encoder_outputs.shape[1]

        # Repeat the query so it can be scored against every encoder position
        hidden = hidden.unsqueeze(1).repeat(1, sequence_len, 1)

        # Additive (Bahdanau-style) scoring
        energy = torch.tanh(self.W1(encoder_outputs) + self.W2(hidden))
        attention = self.v(energy).squeeze(2)  # [batch_size, sequence_len]
        attention_weights = torch.softmax(attention, dim=1)

        # Apply attention weights to encoder outputs to get the context vector
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1)
        return context, attention_weights
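
# An illustrative check that RecurrentLayer and Attention fit together; the
# sizes are placeholders, and the call pattern mirrors what
# SimpleRecurrentNetworkWithAttention does below.
def _demo_recurrent_layer_with_attention():
    batch_size, seq_len, input_size, hidden_size = 4, 7, 8, 16
    layer = RecurrentLayer(input_size, hidden_size, cell_type='custom_GRU')
    attention = Attention(hidden_size * 2)  # bidirectional outputs are 2 * hidden_size wide

    x = torch.randn(batch_size, seq_len, input_size)
    output, final_hidden = layer(x)  # [B, T, 2H], [2, B, H]
    query = torch.cat((final_hidden[-2], final_hidden[-1]), dim=1)  # [B, 2H]
    context, weights = attention(query, output)
    assert context.shape == (batch_size, hidden_size * 2)
    assert torch.allclose(weights.sum(dim=1), torch.ones(batch_size))
    return context, weights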

class SimpleRecurrentNetworkWithAttention(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, cell_type='RNN'):
        super(SimpleRecurrentNetworkWithAttention, self).__init__()
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.attention = Attention(hidden_size * 2)  # hidden_size * 2 because the recurrent layer is bidirectional
        self.cell_type = cell_type

        if cell_type == 'RNN':
            self.cell = nn.RNN(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        elif cell_type == 'custom_RNN':
            self.cell = RecurrentLayer(hidden_size, hidden_size, cell_type="custom_RNN")
        elif cell_type == 'GRU':
            self.cell = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        elif cell_type == 'custom_GRU':
            self.cell = RecurrentLayer(hidden_size, hidden_size, cell_type="custom_GRU")
        elif cell_type == 'LSTM':
            self.cell = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        elif cell_type == 'custom_LSTM':
            self.cell = RecurrentLayer(hidden_size, hidden_size, cell_type="custom_LSTM")
        else:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'custom_RNN', 'GRU', "
                             "'custom_GRU', 'LSTM' or 'custom_LSTM'.")

        self.fc = nn.Linear(hidden_size * 2, output_size)  # hidden_size * 2 for bidirectional

    def forward(self, inputs):
        embedded = self.embedding(inputs)
        rnn_output, hidden = self.cell(embedded)

        if isinstance(hidden, tuple):  # nn.LSTM returns (hidden_state, cell_state)
            hidden = hidden[0]

        # Since it's bidirectional, concatenate the last layer's forward and backward hidden states
        hidden = torch.cat((hidden[-2], hidden[-1]), dim=1)

        # Apply attention over the sequence outputs, using the concatenated hidden state as the query
        context, attention_weights = self.attention(hidden, rnn_output)

        # Pass the context vector to the fully connected layer
        output = self.fc(context)
        return output, attention_weights
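
# A hedged end-to-end sketch: build the model on random token ids and run one
# forward pass. The vocabulary size, sequence length, and class count are
# invented for illustration and are not taken from the original experiments.
def _demo_model_forward():
    vocab_size, hidden_size, num_classes = 100, 16, 3
    batch_size, seq_len = 4, 10
    model = SimpleRecurrentNetworkWithAttention(vocab_size, hidden_size, num_classes,
                                                cell_type='custom_LSTM')
    token_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    logits, attention_weights = model(token_ids)
    assert logits.shape == (batch_size, num_classes)
    assert attention_weights.shape == (batch_size, seq_len)
    return logits, attention_weights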