# spelling-correction/models/transformer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import MBartForConditionalGeneration, AutoTokenizer

from params import DEVICE
from models.tokenizer import TokenAligner
from dataset.vocab import Vocab


class TransformerWithTR(nn.Module):
    """mBART wrapper: forward() for training loss, inference() for decoding."""

    def __init__(self, bart_model, padding_index) -> None:
        super().__init__()
        self.bart: MBartForConditionalGeneration = bart_model
        # Kept so padded label positions can be masked out of the loss.
        self.pad_token_id = padding_index

    def forward(self, src_ids, attn_masks, labels=None):
        src_ids = src_ids.to(DEVICE)
        attn_masks = attn_masks.to(DEVICE)
        if labels is not None:
            # Clone so the caller's tensor is not mutated in place, then mask
            # padding with -100 so those positions are ignored by the loss.
            labels = labels.clone().to(DEVICE)
            labels[labels == self.pad_token_id] = -100

        out = dict()
        output = self.bart(input_ids=src_ids, attention_mask=attn_masks,
                           labels=labels)
        logits = output['logits']
        out['loss'] = output['loss']
        out['logits'] = logits
        # Greedy token predictions from the decoder logits.
        probs = F.softmax(logits, dim=-1)
        preds = torch.argmax(probs, dim=-1)
        out['preds'] = preds.detach().cpu().numpy()
        return out
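
    # Training-step sketch (illustrative, not from the original file): the
    # dict returned by forward() plugs into a standard loop, assuming an
    # optimizer has been created for this module's parameters:
    #
    #   out = model(src_ids, attn_masks, labels)
    #   out['loss'].backward()
    #   optimizer.step()
    #   optimizer.zero_grad()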

    def resize_token_embeddings(self, tokenAligner: TokenAligner):
        vocab: Vocab = tokenAligner.vocab
        tokenizer: AutoTokenizer = tokenAligner.tokenizer

        # Collect character tokens, skipping the first 4 entries of the char
        # vocab (reserved special tokens). Each character is added both bare
        # and with an "@@"-suffixed variant.
        char_vocab = []
        for i, key in enumerate(vocab.chartoken2idx.keys()):
            if i < 4:
                continue
            char_vocab.append(key)
            char_vocab.append(key + "@@")

        tokenizer.add_tokens(char_vocab)
        self.bart.resize_token_embeddings(len(tokenizer.get_vocab()))
        print("Resized token embeddings!")

    def inference(self, src_ids, num_beams=2, tokenAligner: TokenAligner = None):
        assert tokenAligner is not None, "A TokenAligner is required for decoding"
        src_ids = src_ids.to(DEVICE)
        # Beam-search decode, then map the generated token ids back to text.
        output = self.bart.generate(src_ids, num_beams=num_beams, max_new_tokens=256)
        predict_text = tokenAligner.tokenizer.batch_decode(
            output,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return predict_text
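

if __name__ == "__main__":
    # Minimal usage sketch (illustrative). "facebook/mbart-large-50" is a
    # placeholder checkpoint, and the TokenAligner construction is assumed
    # to live elsewhere in this repo; neither is confirmed by this file.
    bart = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")
    model = TransformerWithTR(bart, padding_index=bart.config.pad_token_id).to(DEVICE)
    # token_aligner = TokenAligner(...)                # hypothetical setup
    # model.resize_token_embeddings(token_aligner)
    # print(model.inference(src_ids, tokenAligner=token_aligner))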