from transformers import AutoTokenizer

# Load the tokenizer of the pretrained Bangla GPT-2 model.
tokenizer = AutoTokenizer.from_pretrained("ghosh-r/bangla-gpt2")

train_path = 'train.txt'
test_path = 'test.txt'

from transformers import TextDataset, DataCollatorForLanguageModeling


def load_dataset(train_path, test_path, tokenizer):
    # TextDataset chunks each raw text file into contiguous blocks of
    # 128 tokens for language-model training.
    train_dataset = TextDataset(
        tokenizer=tokenizer,
        file_path=train_path,
        block_size=128)

    test_dataset = TextDataset(
        tokenizer=tokenizer,
        file_path=test_path,
        block_size=128)

    # mlm=False selects causal (GPT-2-style) language modeling rather than
    # BERT-style masked language modeling; the collator copies input_ids
    # into labels for the loss.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False,
    )
    return train_dataset, test_dataset, data_collator


train_dataset, test_dataset, data_collator = load_dataset(train_path, test_path, tokenizer)
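
# Note: TextDataset is deprecated in recent transformers releases in favor of
# the datasets library. A rough sketch of the equivalent preprocessing
# (requires `pip install datasets`; hf_load_dataset and block_size are
# illustrative names, not part of the original script):
#
#     from datasets import load_dataset as hf_load_dataset
#
#     block_size = 128
#     raw = hf_load_dataset("text", data_files={"train": train_path,
#                                               "test": test_path})
#     tokenized = raw.map(lambda batch: tokenizer(batch["text"]),
#                         batched=True, remove_columns=["text"])
#
#     def group_texts(examples):
#         # Concatenate all token lists and split into block_size chunks,
#         # dropping the tail that does not fill a complete block.
#         concatenated = {k: sum(examples[k], []) for k in examples}
#         total = (len(concatenated["input_ids"]) // block_size) * block_size
#         return {k: [t[i:i + block_size] for i in range(0, total, block_size)]
#                 for k, t in concatenated.items()}
#
#     lm_datasets = tokenized.map(group_texts, batched=True)
#     # lm_datasets["train"] / lm_datasets["test"] would replace the
#     # TextDataset objects above.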

from transformers import Trainer, TrainingArguments, AutoModelForCausalLM

# AutoModelWithLMHead is deprecated; AutoModelForCausalLM is the current
# class for decoder-only models such as GPT-2.
model = AutoModelForCausalLM.from_pretrained("ghosh-r/bangla-gpt2")

training_args = TrainingArguments(
    output_dir="./bn-poets",
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    evaluation_strategy="steps",  # without this, eval_steps is ignored
    eval_steps=400,
    save_steps=800,
    warmup_steps=500,
    prediction_loss_only=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

trainer.train()
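
# Optional: report held-out perplexity, the standard metric for language
# models; for a causal LM it is exp(eval_loss).
import math

eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")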

trainer.save_model()
tokenizer.save_pretrained('./bn-poets')
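
# Optional smoke test: generate from the fine-tuned checkpoint with the
# text-generation pipeline. The prompt is only a placeholder seed line;
# substitute any Bangla text.
from transformers import pipeline

generator = pipeline("text-generation", model="./bn-poets", tokenizer="./bn-poets")
print(generator("আমার সোনার বাংলা", max_length=50, num_return_sequences=1))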