# bn-poets/train.py
# Author: Ritobrata Ghosh
# Fine-tunes the ghosh-r/bangla-gpt2 checkpoint on plain-text corpora
# (train.txt / test.txt) and saves the result to ./bn-poets.
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, TextDataset,
                          Trainer, TrainingArguments)

# Reuse the tokenizer of the pretrained Bangla GPT-2 checkpoint.
tokenizer = AutoTokenizer.from_pretrained("ghosh-r/bangla-gpt2")

# Plain-text training and evaluation corpora.
train_path = "train.txt"
test_path = "test.txt"
def load_dataset(train_path, test_path, tokenizer):
    """Build fixed-length (128-token) language-modelling datasets and a collator.

    Note: TextDataset is deprecated in recent transformers releases; it still
    works here but emits a deprecation warning (a `datasets`-based alternative
    is sketched after the call below).
    """
    train_dataset = TextDataset(
        tokenizer=tokenizer,
        file_path=train_path,
        block_size=128)
    test_dataset = TextDataset(
        tokenizer=tokenizer,
        file_path=test_path,
        block_size=128)
    # Causal LM objective: mlm=False makes the collator copy input_ids to labels.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False,
    )
    return train_dataset, test_dataset, data_collator
train_dataset, test_dataset, data_collator = load_dataset(train_path, test_path, tokenizer)
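# Alternative sketch (not part of the original script): TextDataset is deprecated
# in recent transformers releases, and the same fixed 128-token blocks can be built
# with the `datasets` library. Names below (hf_load_dataset, group_texts, lm_datasets)
# are illustrative assumptions. Uncomment to try it:
#
#   from itertools import chain
#   from datasets import load_dataset as hf_load_dataset
#
#   raw = hf_load_dataset("text", data_files={"train": train_path, "test": test_path})
#   tokenized = raw.map(lambda batch: tokenizer(batch["text"]),
#                       batched=True, remove_columns=["text"])
#
#   def group_texts(examples, block_size=128):
#       # Concatenate all token ids, then cut them into fixed-size blocks.
#       ids = list(chain(*examples["input_ids"]))
#       total = (len(ids) // block_size) * block_size
#       return {"input_ids": [ids[i:i + block_size] for i in range(0, total, block_size)]}
#
#   lm_datasets = tokenized.map(group_texts, batched=True,
#                               remove_columns=tokenized["train"].column_names)
#
#   # lm_datasets["train"] / lm_datasets["test"] could then replace train_dataset /
#   # test_dataset below. If the tokenizer has no pad token (typical for GPT-2),
#   # set tokenizer.pad_token = tokenizer.eos_token before using the collator.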
# AutoModelWithLMHead is deprecated; AutoModelForCausalLM is the causal-LM
# equivalent for GPT-2-style checkpoints.
model = AutoModelForCausalLM.from_pretrained("ghosh-r/bangla-gpt2")
training_args = TrainingArguments(
    output_dir="./bn-poets",
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    evaluation_strategy="steps",  # needed for eval_steps to take effect
                                  # (renamed eval_strategy in newer releases)
    eval_steps=400,
    save_steps=800,
    warmup_steps=500,
    prediction_loss_only=True,    # only track the eval loss, not logits
)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)
trainer.train()

# Save the fine-tuned model and its tokenizer to the same output directory.
trainer.save_model()
tokenizer.save_pretrained("./bn-poets")
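# Optional sanity check (sketch, assumptions: training above completed and the
# model was saved to ./bn-poets; the Bengali prompt is only a placeholder).
# Uncomment to sample a short continuation from the fine-tuned model:
#
#   from transformers import pipeline
#
#   generator = pipeline("text-generation", model="./bn-poets", tokenizer="./bn-poets")
#   sample = generator("আমার", max_length=50, do_sample=True, top_p=0.95)
#   print(sample[0]["generated_text"])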