{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "md-overview",
   "metadata": {},
   "source": [
    "# Fine-tuning BERT on BoolQ\n",
    "\n",
    "Fine-tunes `bert-base-uncased` on the `google/boolq` yes/no questions using a single-output sequence-classification head (`num_labels=1`), reports accuracy by thresholding that output at 0.5, then saves the model and pushes it to the Hugging Face Hub.\n",
    "\n",
    "## Load data and tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d090c366-23e5-4221-a868-f290eefcedc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"google/boolq\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6bad310-9514-4468-bdca-673b30dfd473",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "# Single source of truth for the base checkpoint (tokenizer, model and model card).\n",
    "MODEL_CHECKPOINT = \"bert-base-uncased\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "013559ce-c991-4836-922c-5f9201265c66",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38aac997-3d15-4e61-b80c-c1a4fff0b525",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset[\"train\"][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-tokenize",
   "metadata": {},
   "source": [
    "## Tokenize\n",
    "\n",
    "Each example is encoded as a (question, passage) pair. The boolean `answer` becomes a one-element float label so that it lines up with the model's single output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4d214cd-2fef-4778-bc3a-cb4e1c907515",
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_question_context_pairs(example):\n",
    "    \"\"\"Tokenize one BoolQ example as a (question, passage) pair and attach its label.\n",
    "\n",
    "    Args:\n",
    "        example: A dataset row providing the `question`, `passage` and `answer` fields.\n",
    "\n",
    "    Returns:\n",
    "        The tokenizer output with an extra `labels` entry: a one-element list\n",
    "        holding 1.0 when `answer` is truthy and 0.0 otherwise.\n",
    "    \"\"\"\n",
    "    # Pass question and passage as a sentence pair so the tokenizer inserts the\n",
    "    # separator itself and marks the two segments in token_type_ids, instead of\n",
    "    # splicing a literal [SEP] into one string (which leaves every token in segment 0).\n",
    "    inputs = tokenizer(example[\"question\"], example[\"passage\"], truncation=True)\n",
    "    # One float per example, wrapped in a list to match the model's single output (num_labels=1).\n",
    "    inputs[\"labels\"] = [1.0 if example[\"answer\"] else 0.0]\n",
    "    return inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fa2aa41-6286-4a69-ba23-90482d98f494",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = dataset[\"train\"].map(\n",
    "    encode_question_context_pairs,\n",
    "    remove_columns=dataset[\"train\"].column_names,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "309bee55-b698-4c66-990d-beb00ac52746",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the validation split's own raw columns rather than borrowing the train split's list.\n",
    "validation_dataset = dataset[\"validation\"].map(\n",
    "    encode_question_context_pairs,\n",
    "    remove_columns=dataset[\"validation\"].column_names,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30a82635-f956-404d-a95e-db753f7e07b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DataCollatorWithPadding\n",
    "\n",
    "# Examples are tokenized without padding above, so padding is applied per batch here.\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-metric",
   "metadata": {},
   "source": [
    "## Metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22d43e81-1739-443f-95fb-ee98b10a3a0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import evaluate\n",
    "\n",
    "accuracy = evaluate.load(\"accuracy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23fa9362-aa3d-4155-85a5-6caa6635c9f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Compute accuracy after thresholding the single-output predictions at 0.5.\n",
    "\n",
    "    Args:\n",
    "        eval_pred: A `(predictions, labels)` pair as passed by the Trainer.\n",
    "\n",
    "    Returns:\n",
    "        The dictionary produced by `accuracy.compute`.\n",
    "    \"\"\"\n",
    "    predictions, labels = eval_pred\n",
    "    # Labels are stored as one-element lists and the model has a single output, so both\n",
    "    # arrays can carry a trailing axis of size 1; flatten them to 1-D before scoring.\n",
    "    flat_predictions = np.asarray(predictions).reshape(-1)\n",
    "    predicted_classes = np.where(flat_predictions < 0.5, 0, 1)\n",
    "    # The labels are floats (0.0 / 1.0); the accuracy metric works on integer class ids.\n",
    "    references = np.asarray(labels).reshape(-1).astype(int)\n",
    "    return accuracy.compute(predictions=predicted_classes, references=references)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-train",
   "metadata": {},
   "source": [
    "## Model and training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e476c76f-21b6-4844-a6a5-29f18b4f6099",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n",
    "\n",
    "# num_labels=1 gives the model a single output per example, matching the\n",
    "# one-element float labels built in encode_question_context_pairs.\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    MODEL_CHECKPOINT,\n",
    "    num_labels=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a359a0d-7563-4f4e-b4d4-03e6c601fc2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "training_args = TrainingArguments(\n",
    "    output_dir=\"./\",\n",
    "    learning_rate=2e-5,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "    num_train_epochs=4,\n",
    "    weight_decay=0.01,\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    load_best_model_at_end=True,\n",
    "    # 16 examples per device x 4 accumulation steps = 64 examples per optimizer step.\n",
    "    gradient_accumulation_steps=4,\n",
    "    logging_steps=50,\n",
    "    seed=42,\n",
    "    adam_beta1=0.9,\n",
    "    adam_beta2=0.999,\n",
    "    adam_epsilon=1e-08,\n",
    "    report_to=\"tensorboard\",\n",
    "    push_to_hub=True,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=validation_dataset,\n",
    "    tokenizer=tokenizer,\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c96926e2-04c1-4e33-b83f-dc2b9c4d5b08",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bc0fca5-d298-40d3-a80b-035a05fe6e1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save only after training, so the files in output_dir hold the fine-tuned weights\n",
    "# rather than the freshly initialised classification head.\n",
    "model.save_pretrained(training_args.output_dir)\n",
    "tokenizer.save_pretrained(training_args.output_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-publish",
   "metadata": {},
   "source": [
    "## Publish to the Hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75e96eb2-0d8e-4e5f-8844-6abce16bd1cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model-card metadata forwarded to trainer.push_to_hub.\n",
    "kwargs = {\n",
    "    \"dataset_tags\": \"google/boolq\",\n",
    "    \"dataset\": \"boolq\",  # a 'pretty' name for the training dataset\n",
    "    \"language\": \"en\",\n",
    "    \"model_name\": \"Bert Base Uncased Boolean Question Answer model\",  # a 'pretty' name for your model\n",
    "    \"finetuned_from\": MODEL_CHECKPOINT,\n",
    "    \"tasks\": \"text-classification\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba5e73bd-d154-43ce-a869-f0f57045a386",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.push_to_hub(**kwargs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}