File size: 2,070 Bytes
44db343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import MBartForConditionalGeneration, AutoTokenizer
from params import DEVICE
from models.tokenizer import TokenAligner
from dataset.vocab import Vocab

class TransformerWithTR(nn.Module):
    """Wrapper around an mBART seq2seq model.

    ``forward`` returns a plain dict with the loss, logits and per-position
    argmax predictions.

    ``resize_token_embeddings`` extends the tokenizer with character tokens
    taken from a ``TokenAligner``'s vocab and grows the model's embeddings
    to match.

    ``inference`` runs beam-search generation and decodes the result to
    strings.
    """

    def __init__(self, bart_model, padding_index) -> None:
        """
        Args:
            bart_model: the wrapped model (annotated as
                ``MBartForConditionalGeneration``).
            padding_index: pad token id; label positions equal to it are
                excluded from the loss in ``forward``.
        """
        super().__init__()
        self.bart: MBartForConditionalGeneration = bart_model
        self.pad_token_id = padding_index

    def forward(self, src_ids, attn_masks, labels=None):
        """Run the wrapped model and collect loss, logits and predictions.

        Args:
            src_ids: encoder input ids.
            attn_masks: encoder attention mask.
            labels: optional target ids. Positions equal to
                ``self.pad_token_id`` are replaced by -100 (ignored by the
                loss) on a copy, so the caller's tensor is not modified.
                When ``None``, no loss is computed.

        Returns:
            dict with:
              ``'loss'``: the model loss, or ``None`` when ``labels`` is None.
              ``'logits'``: the model logits.
              ``'preds'``: numpy array of argmax ids over the last dimension.
        """
        src_ids = src_ids.to(DEVICE)
        attn_masks = attn_masks.to(DEVICE)
        if labels is not None:
            # masked_fill returns a new tensor, so the caller's labels are
            # left untouched (the old in-place assignment mutated them).
            labels = labels.masked_fill(labels == self.pad_token_id, -100)
            labels = labels.to(DEVICE)

        output = self.bart(input_ids=src_ids, attention_mask=attn_masks,
                           labels=labels)
        # Attribute access instead of output['loss']: string-key lookup on a
        # ModelOutput skips None fields and would raise KeyError when no
        # labels were given.
        logits = output.logits
        # softmax is monotonic, so argmax over the raw logits yields the same
        # ids without materialising a probability tensor.
        preds = torch.argmax(logits, dim=-1)
        return {
            'loss': output.loss,
            'logits': logits,
            'preds': preds.detach().cpu().numpy(),
        }

    def resize_token_embeddings(self, tokenAligner: TokenAligner):
        """Add character tokens to the tokenizer and resize the embeddings.

        Every key of ``tokenAligner.vocab.chartoken2idx`` except the first
        four is added twice: once as-is and once with an ``"@@"`` suffix.
        The first four are presumably special tokens; TODO confirm against
        Vocab. The wrapped model's token embeddings are then resized to the
        tokenizer's new vocabulary size.

        Args:
            tokenAligner: provides ``.vocab`` (with ``chartoken2idx``) and
                ``.tokenizer``. The tokenizer is mutated in place.
        """
        vocab: Vocab = tokenAligner.vocab
        tokenizer: AutoTokenizer = tokenAligner.tokenizer
        char_vocab = []
        # Skip the first four entries (in dict insertion order).
        for key in list(vocab.chartoken2idx)[4:]:
            char_vocab += [key, key + "@@"]
        tokenizer.add_tokens(char_vocab)
        self.bart.resize_token_embeddings(len(tokenizer.get_vocab()))
        print("Resized token embeddings!")
        return

    def inference(self, src_ids, num_beams=2, tokenAligner: TokenAligner = None,
                  attn_masks=None, max_new_tokens=256):
        """Generate with beam search and decode the outputs to strings.

        Args:
            src_ids: encoder input ids.
            num_beams: beam width passed to ``generate``.
            tokenAligner: required; its ``.tokenizer`` decodes the output ids.
            attn_masks: optional encoder attention mask. It is forwarded to
                ``generate`` only when given; otherwise ``generate`` handles
                the missing mask itself, as before.
            max_new_tokens: generation length cap (previously hard-coded
                to 256).

        Returns:
            list of decoded strings. Special tokens are skipped, and
            tokenization spaces are not cleaned up.

        Raises:
            ValueError: if ``tokenAligner`` is not provided.
        """
        # Explicit check rather than assert: asserts are stripped under -O.
        if tokenAligner is None:
            raise ValueError("inference() requires a tokenAligner to decode outputs")
        src_ids = src_ids.to(DEVICE)
        gen_kwargs = {}
        if attn_masks is not None:
            gen_kwargs['attention_mask'] = attn_masks.to(DEVICE)
        output = self.bart.generate(src_ids, num_beams=num_beams,
                                    max_new_tokens=max_new_tokens, **gen_kwargs)
        predict_text = tokenAligner.tokenizer.batch_decode(
            output,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False)
        return predict_text