import torch
import torch.nn as nn
from transformers import AutoConfig, XLMRobertaXLModel

class SchemaItemClassifier(nn.Module):
    def __init__(self, model_name_or_path, mode):
        super(SchemaItemClassifier, self).__init__()
        if mode in ["eval", "test"]:
            # load config
            config = AutoConfig.from_pretrained(model_name_or_path)
            # randomly initialize model's parameters according to the config
            self.plm_encoder = XLMRobertaXLModel(config)
        elif mode == "train":
            self.plm_encoder = XLMRobertaXLModel.from_pretrained(model_name_or_path)
        else:
            raise ValueError("mode must be one of 'train', 'eval', or 'test'")

        self.plm_hidden_size = self.plm_encoder.config.hidden_size

        # column cls head
        self.column_info_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256)
        self.column_info_cls_head_linear2 = nn.Linear(256, 2)

        # column bi-lstm layer
        self.column_info_bilstm = nn.LSTM(
            input_size = self.plm_hidden_size,
            hidden_size = int(self.plm_hidden_size/2),
            num_layers = 2,
            dropout = 0,
            bidirectional = True
        )
        # linear layer after column bi-lstm layer
        self.column_info_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size)

        # table cls head
        self.table_name_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256)
        self.table_name_cls_head_linear2 = nn.Linear(256, 2)

        # table bi-lstm pooling layer
        self.table_name_bilstm = nn.LSTM(
            input_size = self.plm_hidden_size,
            hidden_size = int(self.plm_hidden_size/2),
            num_layers = 2,
            dropout = 0,
            bidirectional = True
        )
        # linear layer after table bi-lstm layer
        self.table_name_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size)

        # activation functions
        self.leakyrelu = nn.LeakyReLU()
        self.tanh = nn.Tanh()

        # table-column cross-attention layer
        self.table_column_cross_attention_layer = nn.MultiheadAttention(embed_dim = self.plm_hidden_size, num_heads = 8)

        # dropout layer: randomly zeroes 20% of the activations during training
        self.dropout = nn.Dropout(p = 0.2)
    def table_column_cross_attention(
        self,
        table_name_embeddings_in_one_db,
        column_info_embeddings_in_one_db,
        column_number_in_each_table
    ):
        table_num = table_name_embeddings_in_one_db.shape[0]
        table_name_embedding_attn_list = []
        for table_id in range(table_num):
            # embedding of the current table name (kept 2-dimensional)
            table_name_embedding = table_name_embeddings_in_one_db[[table_id], :]
            # embeddings of the columns that belong to the current table
            column_info_embeddings_in_one_table = column_info_embeddings_in_one_db[
                sum(column_number_in_each_table[:table_id]) : sum(column_number_in_each_table[:table_id+1]), :]
            # the table name attends to its own columns
            table_name_embedding_attn, _ = self.table_column_cross_attention_layer(
                table_name_embedding,
                column_info_embeddings_in_one_table,
                column_info_embeddings_in_one_table
            )
            table_name_embedding_attn_list.append(table_name_embedding_attn)

        # residual connection
        table_name_embeddings_in_one_db = table_name_embeddings_in_one_db + torch.cat(table_name_embedding_attn_list, dim = 0)
        # row-wise L2 norm
        table_name_embeddings_in_one_db = torch.nn.functional.normalize(table_name_embeddings_in_one_db, p = 2.0, dim = 1)

        return table_name_embeddings_in_one_db
    def table_column_cls(
        self,
        encoder_input_ids,
        encoder_input_attention_mask,
        batch_aligned_column_info_ids,
        batch_aligned_table_name_ids,
        batch_column_number_in_each_table
    ):
        batch_size = encoder_input_ids.shape[0]

        encoder_output = self.plm_encoder(
            input_ids = encoder_input_ids,
            attention_mask = encoder_input_attention_mask,
            return_dict = True
        )  # encoder_output["last_hidden_state"].shape = (batch_size x seq_length x hidden_size)

        batch_table_name_cls_logits, batch_column_info_cls_logits = [], []

        # handle each example in the current batch
        for batch_id in range(batch_size):
            column_number_in_each_table = batch_column_number_in_each_table[batch_id]
            sequence_embeddings = encoder_output["last_hidden_state"][batch_id, :, :]  # (seq_length x hidden_size)

            # token positions of each table name in the input sequence
            aligned_table_name_ids = batch_aligned_table_name_ids[batch_id]
            # token positions of each column info span in the input sequence
            aligned_column_info_ids = batch_aligned_column_info_ids[batch_id]

            table_name_embedding_list, column_info_embedding_list = [], []

            # obtain table embedding via bi-lstm pooling + a non-linear layer
            for table_name_ids in aligned_table_name_ids:
                table_name_embeddings = sequence_embeddings[table_name_ids, :]
                # BiLSTM pooling
                output_t, (hidden_state_t, cell_state_t) = self.table_name_bilstm(table_name_embeddings)
                # concatenate the final forward and backward hidden states of the top LSTM layer
                table_name_embedding = hidden_state_t[-2:, :].view(1, self.plm_hidden_size)
                table_name_embedding_list.append(table_name_embedding)
            table_name_embeddings_in_one_db = torch.cat(table_name_embedding_list, dim = 0)
            # non-linear mlp layer
            table_name_embeddings_in_one_db = self.leakyrelu(self.table_name_linear_after_pooling(table_name_embeddings_in_one_db))

            # obtain column embedding via bi-lstm pooling + a non-linear layer
            for column_info_ids in aligned_column_info_ids:
                column_info_embeddings = sequence_embeddings[column_info_ids, :]
                # BiLSTM pooling
                output_c, (hidden_state_c, cell_state_c) = self.column_info_bilstm(column_info_embeddings)
                # concatenate the final forward and backward hidden states of the top LSTM layer
                column_info_embedding = hidden_state_c[-2:, :].view(1, self.plm_hidden_size)
                column_info_embedding_list.append(column_info_embedding)
            column_info_embeddings_in_one_db = torch.cat(column_info_embedding_list, dim = 0)
            # non-linear mlp layer
            column_info_embeddings_in_one_db = self.leakyrelu(self.column_info_linear_after_pooling(column_info_embeddings_in_one_db))

            # table-column (tc) cross-attention
            table_name_embeddings_in_one_db = self.table_column_cross_attention(
                table_name_embeddings_in_one_db,
                column_info_embeddings_in_one_db,
                column_number_in_each_table
            )

            # calculate table 0-1 logits
            table_name_embeddings_in_one_db = self.table_name_cls_head_linear1(table_name_embeddings_in_one_db)
            table_name_embeddings_in_one_db = self.dropout(self.leakyrelu(table_name_embeddings_in_one_db))
            table_name_cls_logits = self.table_name_cls_head_linear2(table_name_embeddings_in_one_db)

            # calculate column 0-1 logits
            column_info_embeddings_in_one_db = self.column_info_cls_head_linear1(column_info_embeddings_in_one_db)
            column_info_embeddings_in_one_db = self.dropout(self.leakyrelu(column_info_embeddings_in_one_db))
            column_info_cls_logits = self.column_info_cls_head_linear2(column_info_embeddings_in_one_db)

            batch_table_name_cls_logits.append(table_name_cls_logits)
            batch_column_info_cls_logits.append(column_info_cls_logits)

        return batch_table_name_cls_logits, batch_column_info_cls_logits
    def forward(
        self,
        encoder_input_ids,
        encoder_attention_mask,
        batch_aligned_column_info_ids,
        batch_aligned_table_name_ids,
        batch_column_number_in_each_table,
    ):
        batch_table_name_cls_logits, batch_column_info_cls_logits \
            = self.table_column_cls(
                encoder_input_ids,
                encoder_attention_mask,
                batch_aligned_column_info_ids,
                batch_aligned_table_name_ids,
                batch_column_number_in_each_table
            )

        return {
            "batch_table_name_cls_logits" : batch_table_name_cls_logits,
            "batch_column_info_cls_logits": batch_column_info_cls_logits
        }
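

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). Everything below is
# an assumption for illustration only: the checkpoint path "./sic_ckpt" is
# hypothetical, and the aligned token positions are toy values. In the real
# pipeline the input ids, attention mask, and the aligned table/column token
# positions come from the project's own tokenization and preprocessing code.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_name_or_path = "./sic_ckpt"  # hypothetical local XLM-RoBERTa-XL checkpoint
    model = SchemaItemClassifier(model_name_or_path, mode = "eval")
    # in "eval"/"test" mode the encoder is randomly initialized from the config;
    # a trained state dict would normally be loaded here, e.g.
    # model.load_state_dict(torch.load("path/to/state_dict.pt", map_location = "cpu"))
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    encoding = tokenizer(
        "list all singers | singer : singer.name , singer.age",
        return_tensors = "pt"
    )

    # toy alignment for a single example with one table ("singer") and two columns:
    # nested lists of token positions covering each table-name / column-info span;
    # the positions must be valid indices into the tokenized sequence above
    batch_aligned_table_name_ids = [[[5, 6]]]
    batch_aligned_column_info_ids = [[[8, 9], [11, 12]]]
    batch_column_number_in_each_table = [[2]]

    with torch.no_grad():
        outputs = model(
            encoding["input_ids"],
            encoding["attention_mask"],
            batch_aligned_column_info_ids,
            batch_aligned_table_name_ids,
            batch_column_number_in_each_table
        )

    # one (num_tables x 2) and one (num_columns x 2) logits tensor per example
    print(outputs["batch_table_name_cls_logits"][0].shape)   # -> torch.Size([1, 2])
    print(outputs["batch_column_info_cls_logits"][0].shape)  # -> torch.Size([2, 2])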