import torch
import torch.nn as nn

from transformers import AutoConfig, XLMRobertaXLModel


class SchemaItemClassifier(nn.Module):
    def __init__(self, model_name_or_path, mode):
        super(SchemaItemClassifier, self).__init__()
        if mode in ["eval", "test"]:
            # load config
            config = AutoConfig.from_pretrained(model_name_or_path)
            # randomly initialize model's parameters according to the config
            self.plm_encoder = XLMRobertaXLModel(config)
        elif mode == "train":
            self.plm_encoder = XLMRobertaXLModel.from_pretrained(model_name_or_path)
        else:
            raise ValueError("mode must be one of 'train', 'eval', or 'test'")

        self.plm_hidden_size = self.plm_encoder.config.hidden_size

        # column cls head
        self.column_info_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256)
        self.column_info_cls_head_linear2 = nn.Linear(256, 2)

        # column bi-lstm layer
        self.column_info_bilstm = nn.LSTM(
            input_size = self.plm_hidden_size,
            hidden_size = int(self.plm_hidden_size/2),
            num_layers = 2,
            dropout = 0,
            bidirectional = True
        )
        # linear layer after column bi-lstm layer
        self.column_info_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size)

        # table cls head
        self.table_name_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256)
        self.table_name_cls_head_linear2 = nn.Linear(256, 2)

        # table bi-lstm pooling layer
        self.table_name_bilstm = nn.LSTM(
            input_size = self.plm_hidden_size,
            hidden_size = int(self.plm_hidden_size/2),
            num_layers = 2,
            dropout = 0,
            bidirectional = True
        )
        # linear layer after table bi-lstm layer
        self.table_name_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size)

        # activation function
        self.leakyrelu = nn.LeakyReLU()
        self.tanh = nn.Tanh()

        # table-column cross-attention layer
        self.table_column_cross_attention_layer = nn.MultiheadAttention(embed_dim = self.plm_hidden_size, num_heads = 8)

        # dropout function, p=0.2 means randomly set 20% neurons to 0
        self.dropout = nn.Dropout(p = 0.2)

    def table_column_cross_attention(
        self,
        table_name_embeddings_in_one_db,
        column_info_embeddings_in_one_db,
        column_number_in_each_table
    ):
        table_num = table_name_embeddings_in_one_db.shape[0]
        table_name_embedding_attn_list = []
        for table_id in range(table_num):
            table_name_embedding = table_name_embeddings_in_one_db[[table_id], :]
            # slice out the embeddings of the columns that belong to the current table
            column_info_embeddings_in_one_table = column_info_embeddings_in_one_db[
                sum(column_number_in_each_table[:table_id]) : sum(column_number_in_each_table[:table_id+1]), :]

            table_name_embedding_attn, _ = self.table_column_cross_attention_layer(
                table_name_embedding,
                column_info_embeddings_in_one_table,
                column_info_embeddings_in_one_table
            )

            table_name_embedding_attn_list.append(table_name_embedding_attn)

        # residual connection
        table_name_embeddings_in_one_db = table_name_embeddings_in_one_db + torch.cat(table_name_embedding_attn_list, dim = 0)
        # row-wise L2 norm
        table_name_embeddings_in_one_db = torch.nn.functional.normalize(table_name_embeddings_in_one_db, p = 2.0, dim = 1)

        return table_name_embeddings_in_one_db

    def table_column_cls(
        self,
        encoder_input_ids,
        encoder_input_attention_mask,
        batch_aligned_column_info_ids,
        batch_aligned_table_name_ids,
        batch_column_number_in_each_table
    ):
        batch_size = encoder_input_ids.shape[0]

        encoder_output = self.plm_encoder(
            input_ids = encoder_input_ids,
            attention_mask = encoder_input_attention_mask,
            return_dict = True
        )  # encoder_output["last_hidden_state"].shape = (batch_size x seq_length x hidden_size)

        batch_table_name_cls_logits, batch_column_info_cls_logits = [], []

        # handle each data in current batch
        for batch_id in range(batch_size):
            column_number_in_each_table = batch_column_number_in_each_table[batch_id]
            sequence_embeddings = encoder_output["last_hidden_state"][batch_id, :, :]  # (seq_length x hidden_size)

            # obtain table ids for each table
            aligned_table_name_ids = batch_aligned_table_name_ids[batch_id]
            # obtain column ids for each column
            aligned_column_info_ids = batch_aligned_column_info_ids[batch_id]

            table_name_embedding_list, column_info_embedding_list = [], []

            # obtain table embedding via bi-lstm pooling + a non-linear layer
            for table_name_ids in aligned_table_name_ids:
                table_name_embeddings = sequence_embeddings[table_name_ids, :]

                # BiLSTM pooling: concatenate the final forward/backward hidden states
                # of the last LSTM layer -> (1 x hidden_size)
                output_t, (hidden_state_t, cell_state_t) = self.table_name_bilstm(table_name_embeddings)
                table_name_embedding = hidden_state_t[-2:, :].view(1, self.plm_hidden_size)
                table_name_embedding_list.append(table_name_embedding)
            table_name_embeddings_in_one_db = torch.cat(table_name_embedding_list, dim = 0)
            # non-linear mlp layer
            table_name_embeddings_in_one_db = self.leakyrelu(self.table_name_linear_after_pooling(table_name_embeddings_in_one_db))

            # obtain column embedding via bi-lstm pooling + a non-linear layer
            for column_info_ids in aligned_column_info_ids:
                column_info_embeddings = sequence_embeddings[column_info_ids, :]

                # BiLSTM pooling: concatenate the final forward/backward hidden states
                # of the last LSTM layer -> (1 x hidden_size)
                output_c, (hidden_state_c, cell_state_c) = self.column_info_bilstm(column_info_embeddings)
                column_info_embedding = hidden_state_c[-2:, :].view(1, self.plm_hidden_size)
                column_info_embedding_list.append(column_info_embedding)
            column_info_embeddings_in_one_db = torch.cat(column_info_embedding_list, dim = 0)
            # non-linear mlp layer
            column_info_embeddings_in_one_db = self.leakyrelu(self.column_info_linear_after_pooling(column_info_embeddings_in_one_db))

            # table-column (tc) cross-attention
            table_name_embeddings_in_one_db = self.table_column_cross_attention(
                table_name_embeddings_in_one_db,
                column_info_embeddings_in_one_db,
                column_number_in_each_table
            )

            # calculate table 0-1 logits
            table_name_embeddings_in_one_db = self.table_name_cls_head_linear1(table_name_embeddings_in_one_db)
            table_name_embeddings_in_one_db = self.dropout(self.leakyrelu(table_name_embeddings_in_one_db))
            table_name_cls_logits = self.table_name_cls_head_linear2(table_name_embeddings_in_one_db)

            # calculate column 0-1 logits
            column_info_embeddings_in_one_db = self.column_info_cls_head_linear1(column_info_embeddings_in_one_db)
            column_info_embeddings_in_one_db = self.dropout(self.leakyrelu(column_info_embeddings_in_one_db))
            column_info_cls_logits = self.column_info_cls_head_linear2(column_info_embeddings_in_one_db)

            batch_table_name_cls_logits.append(table_name_cls_logits)
            batch_column_info_cls_logits.append(column_info_cls_logits)

        return batch_table_name_cls_logits, batch_column_info_cls_logits

    def forward(
        self,
        encoder_input_ids,
        encoder_attention_mask,
        batch_aligned_column_info_ids,
        batch_aligned_table_name_ids,
        batch_column_number_in_each_table,
    ):
        batch_table_name_cls_logits, batch_column_info_cls_logits \
            = self.table_column_cls(
                encoder_input_ids,
                encoder_attention_mask,
                batch_aligned_column_info_ids,
                batch_aligned_table_name_ids,
                batch_column_number_in_each_table
            )

        return {
            "batch_table_name_cls_logits" : batch_table_name_cls_logits,
            "batch_column_info_cls_logits": batch_column_info_cls_logits
        }
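

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): a minimal
# smoke test of the classifier's input/output structure with dummy inputs.
# Assumptions: "facebook/xlm-roberta-xl" (or a local checkpoint directory) is
# reachable for loading the config, and the alignment lists below are
# hypothetical token positions; in practice they come from the preprocessing
# pipeline that records where each table/column name sits in the encoded
# sequence. Note that xlm-roberta-xl is large, so this is for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model_name_or_path = "facebook/xlm-roberta-xl"  # assumption: hub id or local path

    # "eval" mode only loads the config and randomly initializes the encoder;
    # fine-tuned weights are expected to be restored later via load_state_dict
    model = SchemaItemClassifier(model_name_or_path, mode = "eval")
    model.eval()  # disable dropout for a deterministic forward pass

    # dummy encoder inputs: batch of 1, sequence length 32
    seq_length = 32
    encoder_input_ids = torch.randint(low = 5, high = 1000, size = (1, seq_length))
    encoder_attention_mask = torch.ones(1, seq_length, dtype = torch.long)

    # one database with one table and two columns; the inner lists hold the
    # (hypothetical) token positions of each schema item in the sequence
    batch_aligned_table_name_ids = [[[10, 11]]]
    batch_aligned_column_info_ids = [[[13, 14], [16]]]
    batch_column_number_in_each_table = [[2]]

    with torch.no_grad():
        outputs = model(
            encoder_input_ids,
            encoder_attention_mask,
            batch_aligned_column_info_ids,
            batch_aligned_table_name_ids,
            batch_column_number_in_each_table
        )

    # one (num_tables x 2) and one (num_columns x 2) logits tensor per db
    print(outputs["batch_table_name_cls_logits"][0].shape)   # torch.Size([1, 2])
    print(outputs["batch_column_info_cls_logits"][0].shape)  # torch.Size([2, 2])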