File size: 3,965 Bytes
933ca80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from transformers import BertModel, BertConfig
import torch
import torch.nn as nn

class AttentionPool(nn.Module):
    """Attention-based pooling layer.

    Collapses the sequence axis (dim 1) of ``hidden_states`` into a single
    vector per example. Each position is scored by a dot product with a
    learned ``(hidden_size, 1)`` vector, the scores are softmax-normalised
    along dim 1, and the positions are summed with those weights.

    Args:
        hidden_size: Size of the last dimension of the tensors to be pooled.
    """
    def __init__(self, hidden_size):
        super(AttentionPool, self).__init__()
        self.attention_weights = nn.Parameter(torch.randn(hidden_size, 1))
        # The random-normal values above are overwritten by Xavier-uniform init.
        nn.init.xavier_uniform_(self.attention_weights)  # https://pytorch.org/docs/stable/nn.init.html

    def forward(self, hidden_states, attention_mask=None):
        """Pool ``hidden_states`` over dim 1.

        Args:
            hidden_states: Tensor whose dim 1 is the sequence axis and whose
                last dim equals ``hidden_size``.
            attention_mask: Optional tensor matching ``hidden_states`` without
                its last dimension. Positions where it is zero/False get
                (numerically) zero weight. When omitted, every position takes
                part in the softmax, which is the original behaviour.

        Returns:
            ``hidden_states`` with dim 1 reduced away.
        """
        attention_scores = torch.matmul(hidden_states, self.attention_weights)
        if attention_mask is not None:
            keep = attention_mask.to(torch.bool).unsqueeze(-1)
            # Use the most negative finite value rather than -inf so that a row
            # with every position masked yields uniform weights instead of NaN.
            fill_value = torch.finfo(attention_scores.dtype).min
            attention_scores = attention_scores.masked_fill(~keep, fill_value)
        attention_scores = torch.softmax(attention_scores, dim=1)
        pooled_output = torch.sum(hidden_states * attention_scores, dim=1)
        return pooled_output

class GeneformerMultiTask(nn.Module):
    """BERT encoder with one linear classification head per task.

    A single pooled representation of the encoder output is passed through
    dropout and then fed to every head. When labels are supplied, a
    cross-entropy loss is computed per task, optionally scaled by a per-task
    weight, and summed into a total loss.
    """

    def __init__(self, pretrained_path, num_labels_list, dropout_rate=0.1, use_task_weights=False, task_weights=None, max_layers_to_freeze=0, use_attention_pooling=False):
        """Build the encoder, pooling layer and classification heads.

        Args:
            pretrained_path: Passed to ``BertConfig.from_pretrained`` to obtain
                the encoder configuration.
            num_labels_list: One entry per task; entry ``i`` is the output size
                of classification head ``i``.
            dropout_rate: Dropout probability applied to the pooled output.
            use_task_weights: If True, each task loss is multiplied by the
                corresponding entry of ``task_weights``.
            task_weights: Per-task loss multipliers. Required, and must have
                the same length as ``num_labels_list``, when
                ``use_task_weights`` is True; ignored otherwise.
            max_layers_to_freeze: The first this-many encoder layers get
                ``requires_grad = False``.
            use_attention_pooling: If True, pool with ``AttentionPool``;
                otherwise use the hidden state at sequence position 0.

        Raises:
            ValueError: If ``use_task_weights`` is True and ``task_weights`` is
                missing or its length differs from ``num_labels_list``.
        """
        super(GeneformerMultiTask, self).__init__()
        self.config = BertConfig.from_pretrained(pretrained_path)
        # NOTE(review): constructing BertModel from a config creates randomly
        # initialised weights; nothing in this class loads weights from
        # `pretrained_path`. If pretrained weights are expected here (rather
        # than restored later via a state dict), this should presumably be
        # `BertModel.from_pretrained(pretrained_path)` - confirm with callers.
        self.bert = BertModel(self.config)
        self.num_labels_list = num_labels_list
        self.use_task_weights = use_task_weights
        self.dropout = nn.Dropout(dropout_rate)
        self.use_attention_pooling = use_attention_pooling

        if use_task_weights and (task_weights is None or len(task_weights) != len(num_labels_list)):
            raise ValueError("Task weights must be defined and match the number of tasks when 'use_task_weights' is True.")
        self.task_weights = task_weights if use_task_weights else [1.0] * len(num_labels_list)

        # Freeze the specified initial layers
        for layer in self.bert.encoder.layer[:max_layers_to_freeze]:
            for param in layer.parameters():
                param.requires_grad = False

        self.attention_pool = AttentionPool(self.config.hidden_size) if use_attention_pooling else None

        self.classification_heads = nn.ModuleList([
            nn.Linear(self.config.hidden_size, num_labels) for num_labels in num_labels_list
        ])
        # initialization of the classification heads: https://pytorch.org/docs/stable/nn.init.html
        for head in self.classification_heads:
            nn.init.xavier_uniform_(head.weight)
            nn.init.zeros_(head.bias)

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the encoder and every classification head.

        Args:
            input_ids: Forwarded to the encoder as ``input_ids``.
            attention_mask: Forwarded to the encoder as ``attention_mask``.
            labels: Optional container indexed by task id; ``labels[i]`` is
                flattened with ``.view(-1)`` and used as the cross-entropy
                target for task ``i``.

        Returns:
            Always a 3-tuple ``(total_loss, logits, third)``:

            * ``total_loss`` - sum of the (optionally weighted) task losses,
              or the integer ``0`` when ``labels`` is None.
            * ``logits`` - list with one logits tensor per task.
            * ``third`` - list of per-task loss values as Python floats when
              ``labels`` is given; otherwise the same ``logits`` list again.

        Raises:
            RuntimeError: If the encoder, pooling/dropout, a classification
                head, or a loss computation fails. The original exception is
                attached as ``__cause__``.
        """
        try:
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        except Exception as e:
            raise RuntimeError(f"Error during BERT forward pass: {e}") from e

        sequence_output = outputs.last_hidden_state

        try:
            if self.use_attention_pooling:
                # NOTE(review): the attention mask is not forwarded to the
                # pooling layer, so padded positions take part in the pooled
                # representation. Left as-is to keep outputs unchanged.
                pooled_output = self.attention_pool(sequence_output)
            else:
                pooled_output = sequence_output[:, 0, :]
            pooled_output = self.dropout(pooled_output)
        except Exception as e:
            raise RuntimeError(f"Error during pooling and dropout: {e}") from e

        total_loss = 0
        logits = []
        losses = []
        # The criterion is stateless, so build it once instead of per task.
        loss_fct = nn.CrossEntropyLoss() if labels is not None else None

        for task_id, (head, num_labels) in enumerate(zip(self.classification_heads, self.num_labels_list)):
            try:
                task_logits = head(pooled_output)
            except Exception as e:
                raise RuntimeError(f"Error during forward pass of classification head {task_id}: {e}") from e

            logits.append(task_logits)

            if labels is not None:
                try:
                    task_loss = loss_fct(task_logits.view(-1, num_labels), labels[task_id].view(-1))
                    if self.use_task_weights:
                        # Out-of-place multiply: avoids mutating a tensor that
                        # is part of the autograd graph.
                        task_loss = task_loss * self.task_weights[task_id]
                    total_loss = total_loss + task_loss
                    # The recorded value is the weighted loss when weights are on.
                    losses.append(task_loss.item())
                except Exception as e:
                    raise RuntimeError(f"Error during loss computation for task {task_id}: {e}") from e

        # The conditional applies to the third element only, so the result is
        # a 3-tuple in both cases. Kept that way because callers may unpack
        # three values even when no labels are passed.
        third = losses if labels is not None else logits
        return total_loss, logits, third