# Spaces: Runtime error
from models.transformer import TransformerWithTR | |
from models.collator import * | |
from transformers import AutoTokenizer | |
import transformers | |
from models.tokenizer import TokenAligner | |
from dataset.vocab import Vocab | |
class ModelWrapper:
    """Bundle the tokenizer, token aligner, model and collator for a named model.

    The only model name handled here is ``"tfmwtr"``; any other name makes
    construction fail.
    """

    def __init__(self, model, vocab: Vocab):
        """Build every component required by the requested model.

        Args:
            model: Name of the model to set up. Only ``"tfmwtr"`` is handled.
            vocab: Vocabulary passed to ``TokenAligner`` together with the
                tokenizer.

        Raises:
            Exception: If ``model`` is anything other than ``"tfmwtr"``.
        """
        # The name is recorded before validation, as the rest of the setup
        # only runs for a recognised model.
        self.model_name = model
        if model != "tfmwtr":
            raise Exception(f"Model {model} isn't implemented!")

        # Tokenizer and BART weights are loaded from the same checkpoint.
        checkpoint = "vinai/bartpho-word-base"
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.tokenAligner = TokenAligner(self.tokenizer, vocab)
        self.bart = transformers.MBartForConditionalGeneration.from_pretrained(
            checkpoint
        )
        self.model = TransformerWithTR(self.bart, self.tokenizer.pad_token_id)
        self.collator = DataCollatorForCharacterTransformer(self.tokenAligner)
        # NOTE: embedding resizing is intentionally left disabled here:
        # self.model.resize_token_embeddings(self.tokenAligner)