# schema_filter/utils/classifier_model.py
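"""Schema item classifier.

Scores every table and column in a database schema with a 2-way classification
head, using an XLM-RoBERTa-XL encoder, BiLSTM pooling over each schema item's
tokens, and table-column cross-attention.
"""
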
import torch
import torch.nn as nn
from transformers import AutoConfig, XLMRobertaXLModel


class SchemaItemClassifier(nn.Module):
def __init__(self, model_name_or_path, mode):
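        """
        Args:
            model_name_or_path: path or hub id of an XLM-RoBERTa-XL checkpoint.
            mode: "train" loads the pretrained encoder weights;
                "eval"/"test" loads only the config and leaves the encoder
                randomly initialized.
        """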
super(SchemaItemClassifier, self).__init__()
if mode in ["eval", "test"]:
            # load only the config and randomly initialize the encoder's
            # parameters from it
            config = AutoConfig.from_pretrained(model_name_or_path)
            self.plm_encoder = XLMRobertaXLModel(config)
elif mode == "train":
self.plm_encoder = XLMRobertaXLModel.from_pretrained(model_name_or_path)
else:
            raise ValueError(f"unsupported mode: {mode!r}; expected 'train', 'eval', or 'test'")
self.plm_hidden_size = self.plm_encoder.config.hidden_size
# column cls head
self.column_info_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256)
self.column_info_cls_head_linear2 = nn.Linear(256, 2)
# column bi-lstm layer
self.column_info_bilstm = nn.LSTM(
input_size = self.plm_hidden_size,
hidden_size = int(self.plm_hidden_size/2),
num_layers = 2,
dropout = 0,
bidirectional = True
)
# linear layer after column bi-lstm layer
self.column_info_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size)
# table cls head
self.table_name_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256)
self.table_name_cls_head_linear2 = nn.Linear(256, 2)
# table bi-lstm pooling layer
self.table_name_bilstm = nn.LSTM(
input_size = self.plm_hidden_size,
hidden_size = int(self.plm_hidden_size/2),
num_layers = 2,
dropout = 0,
bidirectional = True
)
# linear layer after table bi-lstm layer
self.table_name_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size)
# activation function
self.leakyrelu = nn.LeakyReLU()
self.tanh = nn.Tanh()
# table-column cross-attention layer
self.table_column_cross_attention_layer = nn.MultiheadAttention(embed_dim = self.plm_hidden_size, num_heads = 8)
        # dropout layer: during training, randomly zeroes 20% of the elements
self.dropout = nn.Dropout(p = 0.2)

    def table_column_cross_attention(
self,
table_name_embeddings_in_one_db,
column_info_embeddings_in_one_db,
column_number_in_each_table
):
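        """Fuse column information into table-name embeddings.

        For each table, the table-name embedding (query) attends over the
        embeddings of that table's own columns (keys/values); the attention
        output is added back through a residual connection and the result is
        row-wise L2-normalized.
        """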
table_num = table_name_embeddings_in_one_db.shape[0]
table_name_embedding_attn_list = []
for table_id in range(table_num):
table_name_embedding = table_name_embeddings_in_one_db[[table_id], :]
column_info_embeddings_in_one_table = column_info_embeddings_in_one_db[
sum(column_number_in_each_table[:table_id]) : sum(column_number_in_each_table[:table_id+1]), :]
table_name_embedding_attn, _ = self.table_column_cross_attention_layer(
table_name_embedding,
column_info_embeddings_in_one_table,
column_info_embeddings_in_one_table
)
table_name_embedding_attn_list.append(table_name_embedding_attn)
# residual connection
table_name_embeddings_in_one_db = table_name_embeddings_in_one_db + torch.cat(table_name_embedding_attn_list, dim = 0)
# row-wise L2 norm
table_name_embeddings_in_one_db = torch.nn.functional.normalize(table_name_embeddings_in_one_db, p=2.0, dim=1)
return table_name_embeddings_in_one_db

    def table_column_cls(
self,
encoder_input_ids,
encoder_input_attention_mask,
batch_aligned_column_info_ids,
batch_aligned_table_name_ids,
batch_column_number_in_each_table
):
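        """Encode the batch and compute per-table and per-column logits.

        For every example, the token embeddings of each table name and each
        column are pooled with a BiLSTM and projected through a non-linear
        layer; table-name embeddings are then enriched with table-column
        cross-attention, and both are scored by two-layer classification
        heads. Returns two lists (one entry per example) holding logits of
        shape (num_tables, 2) and (num_columns, 2).
        """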
batch_size = encoder_input_ids.shape[0]
encoder_output = self.plm_encoder(
input_ids = encoder_input_ids,
attention_mask = encoder_input_attention_mask,
return_dict = True
) # encoder_output["last_hidden_state"].shape = (batch_size x seq_length x hidden_size)
batch_table_name_cls_logits, batch_column_info_cls_logits = [], []
# handle each data in current batch
for batch_id in range(batch_size):
column_number_in_each_table = batch_column_number_in_each_table[batch_id]
sequence_embeddings = encoder_output["last_hidden_state"][batch_id, :, :] # (seq_length x hidden_size)
# obtain table ids for each table
aligned_table_name_ids = batch_aligned_table_name_ids[batch_id]
# obtain column ids for each column
aligned_column_info_ids = batch_aligned_column_info_ids[batch_id]
table_name_embedding_list, column_info_embedding_list = [], []
# obtain table embedding via bi-lstm pooling + a non-linear layer
for table_name_ids in aligned_table_name_ids:
table_name_embeddings = sequence_embeddings[table_name_ids, :]
                # BiLSTM pooling: take the last layer's forward and backward final
                # hidden states and concatenate them into one (1 x hidden_size) vector
                output_t, (hidden_state_t, cell_state_t) = self.table_name_bilstm(table_name_embeddings)
                table_name_embedding = hidden_state_t[-2:, :].view(1, self.plm_hidden_size)
table_name_embedding_list.append(table_name_embedding)
table_name_embeddings_in_one_db = torch.cat(table_name_embedding_list, dim = 0)
# non-linear mlp layer
table_name_embeddings_in_one_db = self.leakyrelu(self.table_name_linear_after_pooling(table_name_embeddings_in_one_db))
# obtain column embedding via bi-lstm pooling + a non-linear layer
for column_info_ids in aligned_column_info_ids:
column_info_embeddings = sequence_embeddings[column_info_ids, :]
                # BiLSTM pooling (same scheme as for table names)
                output_c, (hidden_state_c, cell_state_c) = self.column_info_bilstm(column_info_embeddings)
                column_info_embedding = hidden_state_c[-2:, :].view(1, self.plm_hidden_size)
column_info_embedding_list.append(column_info_embedding)
column_info_embeddings_in_one_db = torch.cat(column_info_embedding_list, dim = 0)
# non-linear mlp layer
column_info_embeddings_in_one_db = self.leakyrelu(self.column_info_linear_after_pooling(column_info_embeddings_in_one_db))
# table-column (tc) cross-attention
table_name_embeddings_in_one_db = self.table_column_cross_attention(
table_name_embeddings_in_one_db,
column_info_embeddings_in_one_db,
column_number_in_each_table
)
# calculate table 0-1 logits
table_name_embeddings_in_one_db = self.table_name_cls_head_linear1(table_name_embeddings_in_one_db)
table_name_embeddings_in_one_db = self.dropout(self.leakyrelu(table_name_embeddings_in_one_db))
table_name_cls_logits = self.table_name_cls_head_linear2(table_name_embeddings_in_one_db)
# calculate column 0-1 logits
column_info_embeddings_in_one_db = self.column_info_cls_head_linear1(column_info_embeddings_in_one_db)
column_info_embeddings_in_one_db = self.dropout(self.leakyrelu(column_info_embeddings_in_one_db))
column_info_cls_logits = self.column_info_cls_head_linear2(column_info_embeddings_in_one_db)
batch_table_name_cls_logits.append(table_name_cls_logits)
batch_column_info_cls_logits.append(column_info_cls_logits)
return batch_table_name_cls_logits, batch_column_info_cls_logits

    def forward(
self,
encoder_input_ids,
encoder_attention_mask,
batch_aligned_column_info_ids,
batch_aligned_table_name_ids,
batch_column_number_in_each_table,
):
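        """Return a dict with per-table and per-column classification logits."""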
batch_table_name_cls_logits, batch_column_info_cls_logits \
= self.table_column_cls(
encoder_input_ids,
encoder_attention_mask,
batch_aligned_column_info_ids,
batch_aligned_table_name_ids,
batch_column_number_in_each_table
)
return {
"batch_table_name_cls_logits" : batch_table_name_cls_logits,
"batch_column_info_cls_logits": batch_column_info_cls_logits
}
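

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). It assumes the
# "facebook/xlm-roberta-xl" config is reachable from the Hugging Face hub; in
# "eval" mode only the config is downloaded and the encoder stays randomly
# initialized, but the XL config still allocates a very large model, so treat
# this purely as an illustration of the expected input format. All tensors
# and schema-item index lists below are made up.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = SchemaItemClassifier("facebook/xlm-roberta-xl", mode="eval")
    model.eval()

    batch_size, seq_length = 1, 16
    encoder_input_ids = torch.randint(0, 1000, (batch_size, seq_length))
    encoder_attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)

    # one example with 2 tables and 2 columns per table; each schema item is
    # identified by the positions of its tokens in the input sequence
    batch_aligned_table_name_ids = [[[1, 2], [8, 9]]]
    batch_aligned_column_info_ids = [[[3, 4], [5, 6], [10, 11], [12, 13]]]
    batch_column_number_in_each_table = [[2, 2]]

    with torch.no_grad():
        outputs = model(
            encoder_input_ids,
            encoder_attention_mask,
            batch_aligned_column_info_ids,
            batch_aligned_table_name_ids,
            batch_column_number_in_each_table,
        )

    # one (num_items, 2) logit tensor per example in the batch
    print(outputs["batch_table_name_cls_logits"][0].shape)   # torch.Size([2, 2])
    print(outputs["batch_column_info_cls_logits"][0].shape)  # torch.Size([4, 2])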