# File size: 3,965 bytes (revision 933ca80)
from transformers import BertModel, BertConfig
import torch
import torch.nn as nn
class AttentionPool(nn.Module):
    """Attention-based pooling layer.

    Scores every position of a sequence with a single learned vector,
    normalises the scores with a softmax along dimension 1 and returns the
    score-weighted sum of the hidden states along that same dimension.

    Args:
        hidden_size: Size of the last dimension of the hidden states; the
            learned scoring parameter has shape ``(hidden_size, 1)``.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # The randn values are immediately overwritten by the Xavier init
        # below; randn is kept so seeded runs consume the same RNG stream
        # as before.
        self.attention_weights = nn.Parameter(torch.randn(hidden_size, 1))
        nn.init.xavier_uniform_(self.attention_weights)

    def forward(self, hidden_states, attention_mask=None):
        """Pool ``hidden_states`` along dimension 1.

        Args:
            hidden_states: Tensor whose last dimension equals ``hidden_size``
                and whose dimension 1 is the one being pooled
                (assumed ``(batch, seq_len, hidden_size)`` -- TODO confirm
                against callers).
            attention_mask: Optional tensor broadcastable to
                ``hidden_states.shape[:-1]``; positions where it is zero are
                excluded from the softmax. ``None`` (the default) attends to
                every position, which is the historical behaviour.

        Returns:
            Tensor with dimension 1 reduced away, e.g. ``(batch, hidden_size)``
            for a three-dimensional input.
        """
        attention_scores = torch.matmul(hidden_states, self.attention_weights)
        if attention_mask is not None:
            masked_out = ~attention_mask.to(torch.bool).unsqueeze(-1)
            # Use the dtype's most negative finite value rather than -inf so a
            # fully masked row degrades to uniform weights instead of NaN.
            attention_scores = attention_scores.masked_fill(
                masked_out, torch.finfo(attention_scores.dtype).min
            )
        attention_scores = torch.softmax(attention_scores, dim=1)
        pooled_output = torch.sum(hidden_states * attention_scores, dim=1)
        return pooled_output
class GeneformerMultiTask(nn.Module):
    """BERT encoder shared by several classification heads, one per task.

    The encoder output is pooled to one vector per sequence (either the
    first position's hidden state or an ``AttentionPool``), passed through
    dropout, and fed to an independent ``nn.Linear`` head for every task.

    Args:
        pretrained_path: Passed to ``BertConfig.from_pretrained`` to obtain
            the encoder configuration.
        num_labels_list: Number of classes for each task; one head is built
            per entry.
        dropout_rate: Dropout probability applied to the pooled vector.
        use_task_weights: When true, each task loss is multiplied by the
            matching entry of ``task_weights``.
        task_weights: Per-task loss multipliers; required (and must have one
            entry per task) when ``use_task_weights`` is true.
        max_layers_to_freeze: The first ``max_layers_to_freeze`` encoder
            layers get ``requires_grad = False``.
        use_attention_pooling: Pool with ``AttentionPool`` instead of taking
            the first position's hidden state.

    Raises:
        ValueError: If ``use_task_weights`` is true and ``task_weights`` is
            missing or its length differs from ``num_labels_list``.
    """

    def __init__(
        self,
        pretrained_path,
        num_labels_list,
        dropout_rate=0.1,
        use_task_weights=False,
        task_weights=None,
        max_layers_to_freeze=0,
        use_attention_pooling=False,
    ):
        super().__init__()

        # Validate cheap arguments before building the encoder.
        if use_task_weights and (
            task_weights is None or len(task_weights) != len(num_labels_list)
        ):
            raise ValueError("Task weights must be defined and match the number of tasks when 'use_task_weights' is True.")

        self.config = BertConfig.from_pretrained(pretrained_path)
        # NOTE(review): BertModel(config) builds the encoder from the
        # configuration only; no checkpoint weights are loaded on this line.
        # Confirm weights are restored elsewhere (e.g. load_state_dict) or
        # consider BertModel.from_pretrained(pretrained_path).
        self.bert = BertModel(self.config)
        self.num_labels_list = num_labels_list
        self.use_task_weights = use_task_weights
        self.dropout = nn.Dropout(dropout_rate)
        self.use_attention_pooling = use_attention_pooling
        self.task_weights = (
            task_weights if use_task_weights else [1.0] * len(num_labels_list)
        )

        # Freeze the specified initial layers
        for layer in self.bert.encoder.layer[:max_layers_to_freeze]:
            for param in layer.parameters():
                param.requires_grad = False

        self.attention_pool = (
            AttentionPool(self.config.hidden_size) if use_attention_pooling else None
        )

        self.classification_heads = nn.ModuleList(
            [
                nn.Linear(self.config.hidden_size, num_labels)
                for num_labels in num_labels_list
            ]
        )
        # Initialization of the classification heads.
        # NOTE(review): the loop body was absent from the reviewed copy;
        # Xavier-uniform weights with zero bias is a reconstruction -- confirm.
        for head in self.classification_heads:
            nn.init.xavier_uniform_(head.weight)
            nn.init.zeros_(head.bias)

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the encoder and every task head; optionally compute losses.

        Args:
            input_ids: Forwarded to the BERT encoder as ``input_ids``.
            attention_mask: Forwarded to the BERT encoder as
                ``attention_mask``.
            labels: Optional sequence indexed by task id; ``labels[i]`` holds
                the target class indices for task ``i``. When ``None`` no
                loss is computed.

        Returns:
            Always a 3-tuple ``(total_loss, logits, third)``:

            * ``total_loss`` -- sum of the (optionally weighted) per-task
              cross-entropy losses, or the integer ``0`` without labels.
            * ``logits`` -- list with one logits tensor per task.
            * ``third`` -- list of per-task loss values when ``labels`` is
              given, otherwise the same ``logits`` list again.

        Raises:
            RuntimeError: If the encoder, pooling/dropout, a classification
                head, or a loss computation fails; the original exception is
                attached as ``__cause__``.
        """
        try:
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        except Exception as e:
            raise RuntimeError(f"Error during BERT forward pass: {e}") from e

        sequence_output = outputs.last_hidden_state

        try:
            if self.use_attention_pooling:
                pooled_output = self.attention_pool(sequence_output)
            else:
                # Hidden state of the first position of every sequence.
                pooled_output = sequence_output[:, 0, :]
            pooled_output = self.dropout(pooled_output)
        except Exception as e:
            raise RuntimeError(f"Error during pooling and dropout: {e}") from e

        total_loss = 0
        logits = []
        losses = []

        for task_id, (head, num_labels) in enumerate(
            zip(self.classification_heads, self.num_labels_list)
        ):
            try:
                task_logits = head(pooled_output)
            except Exception as e:
                raise RuntimeError(
                    f"Error during forward pass of classification head {task_id}: {e}"
                ) from e

            logits.append(task_logits)

            if labels is None:
                continue

            try:
                loss_fct = nn.CrossEntropyLoss()
                task_loss = loss_fct(
                    task_logits.view(-1, num_labels), labels[task_id].view(-1)
                )
                if self.use_task_weights:
                    task_loss = task_loss * self.task_weights[task_id]
                total_loss += task_loss
                # NOTE(review): recorded as a Python float for reporting;
                # confirm callers do not need the tensor.
                losses.append(task_loss.item())
            except Exception as e:
                raise RuntimeError(
                    f"Error during loss computation for task {task_id}: {e}"
                ) from e

        # The conditional applies to the third element only, so the result is
        # a 3-tuple either way. Parenthesised to make that explicit; the
        # shape is preserved because callers may unpack three values.
        return total_loss, logits, (losses if labels is not None else logits)