{
  "dataset_reader": {
    "class_name": "basic_classification_reader",
    "x": "text",
    "y": "sentiment",
    "data_path": "/content/drive/MyDrive/BERT/train/",
    "train": "train.csv",
    "valid": "valid.csv"
  },
  "dataset_iterator": {
    "class_name": "basic_classification_iterator",
    "seed": 42
  },
  "chainer": {
    "in": ["x"],
    "in_y": ["y"],
    "pipe": [
      {
        "id": "classes_vocab",
        "class_name": "simple_vocab",
        "fit_on": ["y"],
        "save_path": "/content/drive/MyDrive/BERT/sentiment_bert_model/classes.dict",
        "load_path": "/content/drive/MyDrive/BERT/sentiment_bert_model/classes.dict",
        "in": "y",
        "out": "y_ids"
      },
      {
        "class_name": "torch_transformers_preprocessor",
        "vocab_file": "/content/drive/MyDrive/BERT/rubert-base-cased-sentiment/",
        "do_lower_case": false,
        "max_seq_length": 512,
        "in": ["x"],
        "out": ["bert_features"]
      },
      {
        "in": "y_ids",
        "out": "y_onehot",
        "class_name": "one_hotter",
        "depth": "#classes_vocab.len",
        "single_vector": true
      },
      {
        "class_name": "torch_transformers_classifier",
        "n_classes": 3,
        "return_probas": true,
        "pretrained_bert": "/content/drive/MyDrive/BERT/rubert-base-cased-sentiment/",
        "save_path": "/content/drive/MyDrive/BERT/sentiment_bert_model/model",
        "load_path": "/content/drive/MyDrive/BERT/sentiment_bert_model/model",
        "optimizer": "AdamW",
        "optimizer_parameters": {
          "lr": 1e-05
        },
        "learning_rate_drop_patience": 5,
        "learning_rate_drop_div": 2.0,
        "in": ["bert_features"],
        "in_y": ["y_ids"],
        "out": ["y_pred_probas"]
      },
      {
        "in": "y_pred_probas",
        "out": "y_pred_ids",
        "class_name": "proba2labels",
        "max_proba": true
      },
      {
        "in": "y_pred_ids",
        "out": "y_pred_labels",
        "ref": "classes_vocab"
      }
    ],
    "out": ["y_pred_labels"]
  },
  "train": {
    "epochs": 5,
    "batch_size": 8,
    "metrics": [
      "accuracy",
      "f1_macro",
      "f1_weighted",
      {
        "name": "roc_auc",
        "inputs": ["y_onehot", "y_pred_probas"]
      }
    ],
    "validation_patience": 2,
    "val_every_n_epochs": 1,
    "log_every_n_epochs": 1,
    "show_examples": false,
    "evaluation_targets": ["train", "valid"],
    "class_name": "nn_trainer"
  }
}
|