Spaces:
Runtime error
Runtime error
File size: 929 Bytes
44db343 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
from models.transformer import TransformerWithTR
from models.collator import *
from transformers import AutoTokenizer
import transformers
from models.tokenizer import TokenAligner
from dataset.vocab import Vocab
class ModelWrapper:
    """Build the tokenizer, model and collator for a named model configuration.

    Only the ``"tfmwtr"`` configuration is currently supported. It loads the
    ``vinai/bartpho-word-base`` tokenizer and
    ``transformers.MBartForConditionalGeneration`` checkpoint, wraps the
    latter in ``TransformerWithTR``, and builds a
    ``DataCollatorForCharacterTransformer`` around a ``TokenAligner``.

    Attributes:
        model_name: The configuration name passed to the constructor.
        tokenizer: Pretrained tokenizer loaded from ``PRETRAINED_NAME``.
        tokenAligner: ``TokenAligner`` built from ``tokenizer`` and the vocab.
        bart: Pretrained MBart model loaded from ``PRETRAINED_NAME``.
        model: ``TransformerWithTR`` wrapping ``bart``.
        collator: ``DataCollatorForCharacterTransformer`` built on
            ``tokenAligner``.
    """

    # Single source for the Hugging Face hub id so the tokenizer and the
    # model weights are always loaded from the same checkpoint.
    PRETRAINED_NAME = "vinai/bartpho-word-base"

    def __init__(self, model: str, vocab: "Vocab") -> None:
        """Load all components for the configuration named ``model``.

        Args:
            model: Configuration name; only ``"tfmwtr"`` is supported.
            vocab: Vocabulary handed to ``TokenAligner`` together with the
                tokenizer.

        Raises:
            NotImplementedError: If ``model`` is not a supported name.
        """
        self.model_name = model
        if model != "tfmwtr":
            # NotImplementedError subclasses Exception, so callers that
            # previously caught the generic Exception still catch this.
            raise NotImplementedError(f"Model {model} isn't implemented!")

        self.tokenizer = AutoTokenizer.from_pretrained(self.PRETRAINED_NAME)
        self.tokenAligner = TokenAligner(self.tokenizer, vocab)
        self.bart = transformers.MBartForConditionalGeneration.from_pretrained(
            self.PRETRAINED_NAME
        )
        self.model = TransformerWithTR(self.bart, self.tokenizer.pad_token_id)
        self.collator = DataCollatorForCharacterTransformer(self.tokenAligner)
        # TODO(review): an embedding resize was previously left disabled here
        # (``self.model.resize_token_embeddings(self.tokenAligner)``). Confirm
        # the expected argument and target model before re-enabling it.
|