metadata
language:
- ja
license: mit
tags:
- bart
- pytorch
datasets:
- wikipedia
bart-large-japanese
This model is converted from the original Japanese BART Pretrained model released by Kyoto University.
Both the encoder and decoder outputs are identical to the original Fairseq model.
How to use the model
The input text should be tokenized by BartJapaneseTokenizer.
Tokenizer requirements:
Simple FillMaskPipeline
from transformers import AutoModelForSeq2SeqLM, pipeline
from tokenization_bart_japanese import BartJapaneseTokenizer
model_name = "Formzu/bart-large-japanese"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = BartJapaneseTokenizer.from_pretrained(model_name)
masked_text = "ๅคฉๆฐใ<mask>ใใๆฃๆญฉใใพใใใใ"
fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)
out = fill_mask(masked_text)
print(out)
# [{'score': 0.03228279948234558, 'token': 2566, 'token_str': 'ใใ', 'sequence': 'ๅคฉๆฐ ใ ใใ ใใ ๆฃๆญฉ ใ ใพใใใ ใ'},
# {'score': 0.023878807201981544, 'token': 27365, 'token_str': 'ๆดใ', 'sequence': 'ๅคฉๆฐ ใ ๆดใ ใใ ๆฃๆญฉ ใ ใพใใใ ใ'},
# {'score': 0.020059829577803612, 'token': 267, 'token_str': 'ๅ', 'sequence': 'ๅคฉๆฐ ใ ๅ ใใ ๆฃๆญฉ ใ ใพใใใ ใ'},
# {'score': 0.013921134173870087, 'token': 17, 'token_str': 'ใช', 'sequence': 'ๅคฉๆฐ ใ ใช ใใ ๆฃๆญฉ ใ ใพใใใ ใ'},
# {'score': 0.013069136068224907, 'token': 1718, 'token_str': 'ใใ', 'sequence': 'ๅคฉๆฐ ใ ใใ ใใ ๆฃๆญฉ ใ ใพใใใ ใ'}]
Text Generation
from transformers import AutoModelForSeq2SeqLM
from tokenization_bart_japanese import BartJapaneseTokenizer
import torch
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_name = "Formzu/bart-large-japanese"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
tokenizer = BartJapaneseTokenizer.from_pretrained(model_name)
masked_text = "ๅคฉๆฐใ<mask>ใใๆฃๆญฉใใพใใใใ"
inp = tokenizer(masked_text, return_tensors='pt').to(device)
out = model.generate(**inp, num_beams=1, min_length=0, max_length=20, early_stopping=True, no_repeat_ngram_size=2)
res = "".join(tokenizer.decode(out.squeeze(0).tolist(), skip_special_tokens=True).split(" "))
print(res)
# ๅคฉๆฐใใใใใๆฃๆญฉใใพใใใใๅคฉๆฐใฎใใใธใใใใใใใใ
Framework versions
- Transformers 4.21.2
- Pytorch 1.12.1+cu116
- Tokenizers 0.12.1