roberta-base-suicide-prediction-phr
This model is a fine-tuned version of roberta-base on this dataset sourced from Reddit. It achieves the following results on the evaluation/validation set:
- Loss: 0.1543
- Accuracy: 0.9652972367116438
- Recall: 0.966571403827834
- Precision: 0.9638169257340242
- F1: 0.9651921995935487
It achieves the following results on the validation partition of this updated dataset:
- Loss: 0.08761
- Accuracy: 0.97065
- Recall: 0.96652
- Precision: 0.97732
- F1: 0.97189
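F1 is the harmonic mean of precision and recall, F1 = 2PR / (P + R). The short check below recomputes F1 from the precision and recall reported above for both datasets. It uses only those published figures, and it shows that each set of results is internally consistent.

```python
import math


def f1_from_precision_recall(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


# Figures copied from the two result lists above.
original_f1 = f1_from_precision_recall(0.9638169257340242, 0.966571403827834)
updated_f1 = f1_from_precision_recall(0.97732, 0.96652)

print(round(original_f1, 5), round(updated_f1, 5))
```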
Model description
This model is a fine-tuned version of roberta-base for detecting suicidal tendencies in a given text, framed as binary classification (suicide vs. non-suicide).
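The following is a minimal usage sketch built on the `transformers` `pipeline` API. It makes two assumptions: the checkpoint is available on the Hugging Face Hub under the repository id `vibhorag101/roberta-base-suicide-prediction-phr`, and network access is available. The label strings in the output come from the checkpoint's own config and are not assumed here. The model was fine-tuned on cleaned text, so cleaning inputs in a similar way may bring results closer to the reported metrics. That is an untested assumption.

```python
from functools import lru_cache

MODEL_ID = "vibhorag101/roberta-base-suicide-prediction-phr"


@lru_cache(maxsize=1)
def get_classifier():
    """Build the text-classification pipeline once and reuse it."""
    # Imported lazily so this module loads even without transformers installed.
    from transformers import pipeline

    return pipeline("text-classification", model=MODEL_ID)


def predict(texts):
    """Return one {"label": ..., "score": ...} dict per input string.

    Label strings come from the checkpoint's config (id2label).
    """
    # truncation=True clips inputs longer than RoBERTa's 512-token limit.
    return get_classifier()(list(texts), truncation=True)


# Usage (downloads the checkpoint from the Hub on the first call):
# print(predict(["i have been feeling hopeless lately"]))
```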
Training and evaluation data
- The dataset is sourced from Reddit and is available on Kaggle.
- The dataset contains text with binary labels for suicide or non-suicide.
- The dataset was cleaned, and the following steps were applied:
  - Converted to lowercase.
  - Removed numbers and special characters.
  - Removed URLs, emojis and accented characters.
  - Removed any word contractions.
  - Removed extra whitespace, collapsing multiple spaces into a single space.
  - Removed any consecutive characters repeated more than 3 times.
  - Tokenized the text, lemmatized it, and then removed stopwords (excluding "not").
- The cleaned dataset can be found here.
- The evaluation set had ~23,000 samples, while the training set had ~186k samples, i.e. an 80:10:10 (train:test:val) split.
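The cleaning steps above can be sketched in plain Python. This is a minimal sketch under several assumptions, not the author's preprocessing script:

- The stopword and contraction lists are tiny hypothetical stand-ins.
- Lemmatization is left out. NLTK's `WordNetLemmatizer` is one option for that step.
- The steps are reordered so that URLs and contractions are handled before punctuation is stripped.
- "Removed any consecutive characters repeated more than 3 times" is interpreted as collapsing a run of four or more identical characters to one.
- Contractions are expanded rather than deleted.

```python
import re
import unicodedata

# Hypothetical, deliberately tiny resources; a real pipeline would use complete lists.
CONTRACTIONS = {"can't": "cannot", "don't": "do not", "i'm": "i am",
                "it's": "it is", "won't": "will not"}
# "not" is deliberately absent so that negation survives stopword removal.
STOPWORDS = {"a", "am", "an", "and", "are", "do", "i", "in", "is", "it",
             "me", "my", "of", "so", "that", "the", "this", "to", "was"}

URL_RE = re.compile(r"https?://\S+|www\.\S+")
CONTRACTION_RE = re.compile(r"\b(?:" + "|".join(map(re.escape, CONTRACTIONS)) + r")\b")
REPEAT_RE = re.compile(r"(.)\1{3,}")    # a character followed by 3+ copies of itself
NON_ALPHA_RE = re.compile(r"[^a-z\s]")  # digits, punctuation and other symbols
SPACE_RE = re.compile(r"\s+")


def clean_text(text: str) -> str:
    text = text.lower()
    text = URL_RE.sub(" ", text)
    # Expand contractions before apostrophes are stripped as special characters.
    text = CONTRACTION_RE.sub(lambda m: CONTRACTIONS[m.group(0)], text)
    # NFKD splits "é" into "e" plus a combining accent; the ASCII encode then drops
    # the accent along with emojis and any other non-ASCII symbol.
    text = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode("ascii")
    text = REPEAT_RE.sub(r"\1", text)
    text = NON_ALPHA_RE.sub(" ", text)
    text = SPACE_RE.sub(" ", text).strip()
    tokens = text.split()  # whitespace tokenization; lemmatization would go here
    return " ".join(t for t in tokens if t not in STOPWORDS)


print(clean_text("I'm SOOOOO tired!!! Visit https://example.com café 123 😢 don't"))
```

With these toy lists the example prints `tired visit cafe not`. Swapping in full stopword and contraction resources changes the output but not the structure of the pipeline.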
Training procedure
- The model was trained on an NVIDIA RTX A5000 GPU.
Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 2e-05
- train_batch_size: 16
- eval_batch_size: 16
- seed: 42
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- num_epochs: 3
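With `lr_scheduler_type: linear`, the learning rate decays linearly from 2e-05 to zero over all training steps. The sketch below estimates the step count from the approximate training-set size (~186k samples) and the batch size of 16, then evaluates the schedule. It assumes zero warmup steps and one optimizer step per batch, since neither warmup nor gradient accumulation is listed. The resulting estimate of ~11.6k steps per epoch is roughly consistent with the Training results table, which logs epoch 2.93 at step 34000.

```python
import math

# Hyperparameters from the list above; the sample count is the approximate
# "~186k" figure from the data section, so derived step counts are estimates.
LEARNING_RATE = 2e-5
TRAIN_BATCH_SIZE = 16
NUM_EPOCHS = 3
APPROX_TRAIN_SAMPLES = 186_000
WARMUP_STEPS = 0  # assumption: no warmup is listed

steps_per_epoch = math.ceil(APPROX_TRAIN_SAMPLES / TRAIN_BATCH_SIZE)
total_steps = steps_per_epoch * NUM_EPOCHS


def linear_lr(step: int) -> float:
    """Learning rate after `step` optimizer steps: linear warmup, then linear decay to 0."""
    if step < WARMUP_STEPS:
        return LEARNING_RATE * step / max(1, WARMUP_STEPS)
    remaining = max(0, total_steps - step)
    return LEARNING_RATE * remaining / max(1, total_steps - WARMUP_STEPS)


print(steps_per_epoch, total_steps)           # 11625 34875
print(round(34_000 / steps_per_epoch, 2))     # 2.92
print(linear_lr(0), linear_lr(total_steps // 2), linear_lr(total_steps))
```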
Training results
Training Loss | Epoch | Step | Validation Loss | Accuracy | Recall | Precision | F1
---|---|---|---|---|---|---|---
0.2023 | 0.09 | 1000 | 0.1868 | 0.9415010561710566 | 0.9389451805663809 | 0.943274752044545 | 0.9411049867627274
0.1792 | 0.17 | 2000 | 0.1465 | 0.9528387291460103 | 0.9615484541439335 | 0.9446949714966392 | 0.9530472103004292
0.1596 | 0.26 | 3000 | 0.1871 | 0.9523645298961072 | 0.9399844115354637 | 0.9634297887448962 | 0.9515627054749485
0.1534 | 0.34 | 4000 | 0.1563 | 0.9518041126007674 | 0.974971854161254 | 0.9314139157772814 | 0.9526952695269527
0.1553 | 0.43 | 5000 | 0.1691 | 0.9513730223735828 | 0.93141075604053 | 0.9697051663510955 | 0.950172276702889
0.1537 | 0.52 | 6000 | 0.1347 | 0.9568478682588266 | 0.9644063393089114 | 0.9496844618795839 | 0.9569887852876723
0.1515 | 0.6 | 7000 | 0.1276 | 0.9565461050997974 | 0.9426690915389279 | 0.9691924138545098 | 0.9557467732022126
0.1453 | 0.69 | 8000 | 0.1351 | 0.960210372030866 | 0.9589503767212263 | 0.961031070994619 | 0.959989596428107
0.1526 | 0.78 | 9000 | 0.1423 | 0.9610725524852352 | 0.9612020438209059 | 0.9606196988056085 | 0.9609107830829834
0.1437 | 0.86 | 10000 | 0.1365 | 0.9599948269172738 | 0.9625010825322594 | 0.9573606684468946 | 0.9599239937813093
0.1317 | 0.95 | 11000 | 0.1275 | 0.9616760788032935 | 0.9653589676972374 | 0.9579752492265383 | 0.9616529353405513
0.125 | 1.03 | 12000 | 0.1428 | 0.9608138983489244 | 0.9522819780029445 | 0.9684692619341201 | 0.9603074101567617
0.1135 | 1.12 | 13000 | 0.1627 | 0.960770789326206 | 0.9544470425218672 | 0.966330556773345 | 0.9603520390379923
0.1096 | 1.21 | 14000 | 0.1240 | 0.9624520412122257 | 0.9566987096215467 | 0.9675074443860571 | 0.962072719355541
0.1213 | 1.29 | 15000 | 0.1502 | 0.9616760788032935 | 0.9659651857625358 | 0.9574248927038627 | 0.9616760788032936
0.1166 | 1.38 | 16000 | 0.1574 | 0.958873992326594 | 0.9438815276695246 | 0.9726907630522088 | 0.9580696202531646
0.1214 | 1.47 | 17000 | 0.1626 | 0.9562443419407682 | 0.9773101238416905 | 0.9374480810765908 | 0.9569641721433114
0.1064 | 1.55 | 18000 | 0.1653 | 0.9624089321895073 | 0.9622412747899888 | 0.9622412747899888 | 0.9622412747899888
0.1046 | 1.64 | 19000 | 0.1608 | 0.9640039660300901 | 0.9697756993158396 | 0.9584046559397467 | 0.9640566484438896
0.1043 | 1.72 | 20000 | 0.1556 | 0.960770789326206 | 0.9493374902572097 | 0.9712058119961017 | 0.9601471489883507
0.0995 | 1.81 | 21000 | 0.1646 | 0.9602534810535845 | 0.9752316619035247 | 0.9465411448264268 | 0.9606722402320423
0.1065 | 1.9 | 22000 | 0.1721 | 0.9627106953485365 | 0.9710747380271932 | 0.9547854223433242 | 0.9628611910179897
0.1204 | 1.98 | 23000 | 0.1214 | 0.9629693494848471 | 0.961028838659392 | 0.9644533286980705 | 0.9627380384331756
0.0852 | 2.07 | 24000 | 0.1583 | 0.9643919472345562 | 0.9624144799515025 | 0.9659278574532811 | 0.9641679680721846
0.0812 | 2.16 | 25000 | 0.1594 | 0.9635728758029055 | 0.9572183251060882 | 0.9692213258505787 | 0.9631824321380331
0.0803 | 2.24 | 26000 | 0.1629 | 0.9639177479846532 | 0.9608556334978783 | 0.9664634146341463 | 0.963651365787988
0.0832 | 2.33 | 27000 | 0.1570 | 0.9631417855757209 | 0.9658785831817788 | 0.9603065266058206 | 0.9630844954881052
0.0887 | 2.41 | 28000 | 0.1551 | 0.9623227141440703 | 0.9669178141508616 | 0.9577936004117698 | 0.9623340803309774
0.084 | 2.5 | 29000 | 0.1585 | 0.9644350562572747 | 0.9613752489824197 | 0.96698606271777 | 0.9641724931602031
0.0807 | 2.59 | 30000 | 0.1601 | 0.9639177479846532 | 0.9699489044773534 | 0.9580838323353293 | 0.9639798597065025
0.079 | 2.67 | 31000 | 0.1645 | 0.9628400224166919 | 0.9558326838139777 | 0.9690929844586882 | 0.9624171607952564
0.0913 | 2.76 | 32000 | 0.1560 | 0.9642626201664009 | 0.964752749631939 | 0.9635011243729459 | 0.9641265307888701
0.0927 | 2.85 | 33000 | 0.1491 | 0.9649523645298961 | 0.9659651857625358 | 0.9637117677553136 | 0.9648371610224472
0.0882 | 2.93 | 34000 | 0.1543 | 0.9652972367116438 | 0.966571403827834 | 0.9638169257340242 | 0.9651921995935487
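Each evaluation row above reports accuracy, recall, precision and F1 on the validation split. The card does not include the original metric code, so what follows is a minimal NumPy sketch of a `compute_metrics` function that produces these four values from a `(logits, labels)` pair. This is the form the `transformers` `Trainer` passes during evaluation. Treating label index 1 as the positive (suicide) class is a hypothetical mapping; check the checkpoint's `id2label` before relying on it.

```python
import numpy as np


def compute_metrics(eval_pred):
    """Binary accuracy/recall/precision/F1 from a (logits, labels) pair.

    Assumes (hypothetically) that class index 1 is the positive class.
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    labels = np.asarray(labels)

    tp = int(np.sum((preds == 1) & (labels == 1)))
    fp = int(np.sum((preds == 1) & (labels == 0)))
    fn = int(np.sum((preds == 0) & (labels == 1)))

    accuracy = float(np.mean(preds == labels))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "recall": recall, "precision": precision, "f1": f1}


# Toy batch: predictions [1, 0, 1, 0] against labels [1, 0, 0, 1].
toy_logits = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]])
print(compute_metrics((toy_logits, [1, 0, 0, 1])))
```

Computing the counts directly keeps the sketch dependency-light. In practice, `sklearn.metrics` or the `evaluate` library give the same values for the binary case.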
Framework versions
- Transformers 4.31.0
- Pytorch 2.1.0+cu121
- Datasets 2.14.5
- Tokenizers 0.13.3
Model tree for vibhorag101/roberta-base-suicide-prediction-phr
- Base model: FacebookAI/roberta-base
Evaluation results
- accuracy on Suicide Prediction Dataset (test set, self-reported): 0.965
- f1 on Suicide Prediction Dataset (test set, self-reported): 0.965
- recall on Suicide Prediction Dataset (test set, self-reported): 0.967
- precision on Suicide Prediction Dataset (test set, self-reported): 0.964