Upload 61 files
#8 by shayekh - opened
This view is limited to 50 files because it contains too many changes. See the raw diff here.
- src/datasets/__init__.py +7 -0
- src/datasets/__pycache__/__init__.cpython-38.pyc +0 -0
- src/datasets/__pycache__/toxic_spans_crf_3cls_tokens.cpython-38.pyc +0 -0
- src/datasets/__pycache__/toxic_spans_crf_tokens.cpython-38.pyc +0 -0
- src/datasets/__pycache__/toxic_spans_multi_spans.cpython-38.pyc +0 -0
- src/datasets/__pycache__/toxic_spans_spans.cpython-38.pyc +0 -0
- src/datasets/__pycache__/toxic_spans_tokens.cpython-38.pyc +0 -0
- src/datasets/__pycache__/toxic_spans_tokens_3cls.cpython-38.pyc +0 -0
- src/datasets/__pycache__/toxic_spans_tokens_spans.cpython-38.pyc +0 -0
- src/datasets/toxic_spans_crf_3cls_tokens.py +132 -0
- src/datasets/toxic_spans_crf_tokens.py +111 -0
- src/datasets/toxic_spans_multi_spans.py +237 -0
- src/datasets/toxic_spans_spans.py +238 -0
- src/datasets/toxic_spans_tokens.py +81 -0
- src/datasets/toxic_spans_tokens_3cls.py +102 -0
- src/datasets/toxic_spans_tokens_spans.py +269 -0
- src/models/__init__.py +7 -0
- src/models/__pycache__/__init__.cpython-38.pyc +0 -0
- src/models/__pycache__/auto_models.cpython-38.pyc +0 -0
- src/models/__pycache__/bert_crf_token.cpython-38.pyc +0 -0
- src/models/__pycache__/bert_multi_spans.cpython-38.pyc +0 -0
- src/models/__pycache__/bert_token_spans.cpython-38.pyc +0 -0
- src/models/__pycache__/roberta_crf_token.cpython-38.pyc +0 -0
- src/models/__pycache__/roberta_multi_spans.cpython-38.pyc +0 -0
- src/models/__pycache__/roberta_token_spans.cpython-38.pyc +0 -0
- src/models/auto_models.py +6 -0
- src/models/bert_crf_token.py +72 -0
- src/models/bert_multi_spans.py +84 -0
- src/models/bert_token_spans.py +100 -0
- src/models/roberta_crf_token.py +66 -0
- src/models/roberta_multi_spans.py +82 -0
- src/models/roberta_token_spans.py +97 -0
- src/models/two_layer_nn.py +46 -0
- src/modules/__init__.py +0 -0
- src/modules/__pycache__/__init__.cpython-38.pyc +0 -0
- src/modules/__pycache__/embeddings.cpython-38.pyc +0 -0
- src/modules/__pycache__/preprocessors.cpython-38.pyc +0 -0
- src/modules/__pycache__/tokenizers.cpython-38.pyc +0 -0
- src/modules/activations.py +6 -0
- src/modules/embeddings.py +37 -0
- src/modules/losses.py +6 -0
- src/modules/metrics.py +17 -0
- src/modules/optimizers.py +7 -0
- src/modules/preprocessors.py +112 -0
- src/modules/schedulers.py +14 -0
- src/modules/tokenizers.py +107 -0
- src/trainers/__init__.py +0 -0
- src/trainers/base_trainer.py +563 -0
- src/utils/__init__.py +0 -0
- src/utils/__pycache__/__init__.cpython-38.pyc +0 -0
src/datasets/__init__.py
ADDED
@@ -0,0 +1,7 @@
from src.datasets.toxic_spans_tokens import *
from src.datasets.toxic_spans_tokens_3cls import *
from src.datasets.toxic_spans_spans import *
from src.datasets.toxic_spans_tokens_spans import *
from src.datasets.toxic_spans_multi_spans import *
from src.datasets.toxic_spans_crf_tokens import *
from src.datasets.toxic_spans_crf_3cls_tokens import *
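These star-imports exist so that importing src.datasets runs every @configmapper.map("datasets", ...) decorator in the modules below and registers each dataset class under the name used in the experiment configs. The following is a minimal, self-contained sketch of that decorator-registry pattern, not the project's actual configmapper implementation; any method beyond map (such as get) is an assumption for illustration.

# Minimal sketch of a decorator-based registry (illustrative stand-in only).
class ConfigMapper:
    def __init__(self):
        self._store = {}

    def map(self, kind, name):
        # Register the decorated object under (kind, name) and return it unchanged.
        def decorator(obj):
            self._store.setdefault(kind, {})[name] = obj
            return obj
        return decorator

    def get(self, kind, name):  # assumed accessor, for demonstration
        return self._store[kind][name]

configmapper = ConfigMapper()

@configmapper.map("datasets", "toxic_spans_tokens")
class DummyDataset:
    pass

assert configmapper.get("datasets", "toxic_spans_tokens") is DummyDataset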
src/datasets/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (503 Bytes).
src/datasets/__pycache__/toxic_spans_crf_3cls_tokens.cpython-38.pyc
ADDED
Binary file (2.99 kB).
src/datasets/__pycache__/toxic_spans_crf_tokens.cpython-38.pyc
ADDED
Binary file (2.76 kB).
src/datasets/__pycache__/toxic_spans_multi_spans.cpython-38.pyc
ADDED
Binary file (5.55 kB).
src/datasets/__pycache__/toxic_spans_spans.cpython-38.pyc
ADDED
Binary file (5.2 kB).
src/datasets/__pycache__/toxic_spans_tokens.cpython-38.pyc
ADDED
Binary file (2.35 kB).
src/datasets/__pycache__/toxic_spans_tokens_3cls.cpython-38.pyc
ADDED
Binary file (2.59 kB).
src/datasets/__pycache__/toxic_spans_tokens_spans.cpython-38.pyc
ADDED
Binary file (5.97 kB).
src/datasets/toxic_spans_crf_3cls_tokens.py
ADDED
@@ -0,0 +1,132 @@
from src.utils.mapper import configmapper
from transformers import AutoTokenizer
from datasets import load_dataset
import numpy as np


@configmapper.map("datasets", "toxic_spans_crf_3cls_tokens")
class ToxicSpansCRF3ClsTokenDataset:
    def __init__(self, config):
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_checkpoint_name
        )
        self.dataset = load_dataset("csv", data_files=dict(self.config.train_files))
        self.test_dataset = load_dataset("csv", data_files=dict(self.config.eval_files))

        self.tokenized_inputs = self.dataset.map(
            self.tokenize_and_align_labels_for_train, batched=True
        )
        self.test_tokenized_inputs = self.test_dataset.map(
            self.tokenize_for_test, batched=True
        )

    def tokenize_and_align_labels_for_train(self, examples):
        tokenized_inputs = self.tokenizer(
            examples["text"], **self.config.tokenizer_params
        )

        # tokenized_inputs["text"] = examples["text"]
        example_spans = []
        labels = []
        prediction_mask = np.zeros_like(np.array(tokenized_inputs["input_ids"]))
        offsets_mapping = tokenized_inputs["offset_mapping"]

        ## Wrong Code
        # for i, offset_mapping in enumerate(offsets_mapping):
        #     j = 0
        #     while j < len(offset_mapping):  # [tok1, tok2, tok3] [(0,5),(1,4),(5,7)]
        #         if tokenized_inputs["input_ids"][i][j] in [
        #             self.tokenizer.sep_token_id,
        #             self.tokenizer.pad_token_id,
        #             self.tokenizer.cls_token_id,
        #         ]:
        #             j = j + 1
        #             continue
        #         else:
        #             k = j + 1
        #             while self.tokenizer.convert_ids_to_tokens(
        #                 tokenized_inputs["input_ids"][i][k]
        #             ).startswith("##"):
        #                 offset_mapping[i][j][1] = offset_mapping[i][k][1]
        #             j = k

        for i, offset_mapping in enumerate(offsets_mapping):
            labels.append([])

            spans = eval(examples["spans"][i])
            Bs = eval(examples["Bs"][i])
            Is = eval(examples["Is"][i])

            example_spans.append(spans)
            # cls_label = 2 ## DUMMY LABEL
            cls_label = 3  ## DUMMY LABEL
            for j, offsets in enumerate(offset_mapping):
                if tokenized_inputs["input_ids"][i][j] in [
                    self.tokenizer.sep_token_id,
                    self.tokenizer.pad_token_id,
                ]:
                    tokenized_inputs["attention_mask"][i][j] = 0

                if tokenized_inputs["input_ids"][i][j] == self.tokenizer.cls_token_id:
                    labels[-1].append(cls_label)
                    prediction_mask[i][j] = 1

                elif offsets[0] == offsets[1] and offsets[0] == 0:
                    # labels[-1].append(2) ## DUMMY
                    labels[-1].append(cls_label)  ## DUMMY

                else:
                    # toxic_offsets = [x in spans for x in range(offsets[0], offsets[1])]
                    # ## If any part of the the token is in span, mark it as Toxic
                    # if (
                    #     len(toxic_offsets) > 0
                    #     and sum(toxic_offsets) / len(toxic_offsets) > 0.0
                    # ):
                    #     labels[-1].append(1)
                    # else:
                    #     labels[-1].append(0)
                    # prediction_mask[i][j] = 1

                    b_off = [x in Bs for x in range(offsets[0], offsets[1])]
                    b_off = sum(b_off)
                    i_off = [x in Is for x in range(offsets[0], offsets[1])]
                    i_off = sum(i_off)
                    # if len(b_off) == len(i_off) and len(i_off) == 0:
                    if b_off == 0 and i_off == 0:
                        labels[-1].append(0)
                    # elif len(b_off) >= len(i_off) == 1:
                    elif b_off >= i_off:
                        labels[-1].append(1)
                        # print(b_off)
                        # print(i_off)
                        # print(j)
                    else:
                        labels[-1].append(2)

        tokenized_inputs["labels"] = labels
        tokenized_inputs["prediction_mask"] = prediction_mask
        return tokenized_inputs

    def tokenize_for_test(self, examples):
        tokenized_inputs = self.tokenizer(
            examples["text"], **self.config.tokenizer_params
        )
        prediction_mask = np.zeros_like(np.array(tokenized_inputs["input_ids"]))
        labels = np.zeros_like(np.array(tokenized_inputs["input_ids"]))

        offsets_mapping = tokenized_inputs["offset_mapping"]

        for i, offset_mapping in enumerate(offsets_mapping):
            for j, offsets in enumerate(offset_mapping):
                if tokenized_inputs["input_ids"][i][j] in [
                    self.tokenizer.sep_token_id,
                    self.tokenizer.pad_token_id,
                ]:
                    tokenized_inputs["attention_mask"][i][j] = 0
                else:
                    prediction_mask[i][j] = 1

        tokenized_inputs["prediction_mask"] = prediction_mask
        tokenized_inputs["labels"] = labels
        return tokenized_inputs
src/datasets/toxic_spans_crf_tokens.py
ADDED
@@ -0,0 +1,111 @@
from src.utils.mapper import configmapper
from transformers import AutoTokenizer
from datasets import load_dataset
import numpy as np


@configmapper.map("datasets", "toxic_spans_crf_tokens")
class ToxicSpansCRFTokenDataset:
    def __init__(self, config):
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_checkpoint_name
        )
        self.dataset = load_dataset("csv", data_files=dict(self.config.train_files))
        self.test_dataset = load_dataset("csv", data_files=dict(self.config.eval_files))

        self.tokenized_inputs = self.dataset.map(
            self.tokenize_and_align_labels_for_train, batched=True
        )
        self.test_tokenized_inputs = self.test_dataset.map(
            self.tokenize_for_test, batched=True
        )

    def tokenize_and_align_labels_for_train(self, examples):
        tokenized_inputs = self.tokenizer(
            examples["text"], **self.config.tokenizer_params
        )

        # tokenized_inputs["text"] = examples["text"]
        example_spans = []
        labels = []
        prediction_mask = np.zeros_like(np.array(tokenized_inputs["input_ids"]))
        offsets_mapping = tokenized_inputs["offset_mapping"]

        ## Wrong Code
        # for i, offset_mapping in enumerate(offsets_mapping):
        #     j = 0
        #     while j < len(offset_mapping):  # [tok1, tok2, tok3] [(0,5),(1,4),(5,7)]
        #         if tokenized_inputs["input_ids"][i][j] in [
        #             self.tokenizer.sep_token_id,
        #             self.tokenizer.pad_token_id,
        #             self.tokenizer.cls_token_id,
        #         ]:
        #             j = j + 1
        #             continue
        #         else:
        #             k = j + 1
        #             while self.tokenizer.convert_ids_to_tokens(
        #                 tokenized_inputs["input_ids"][i][k]
        #             ).startswith("##"):
        #                 offset_mapping[i][j][1] = offset_mapping[i][k][1]
        #             j = k

        for i, offset_mapping in enumerate(offsets_mapping):
            labels.append([])

            spans = eval(examples["spans"][i])
            example_spans.append(spans)
            cls_label = 2  ## DUMMY LABEL
            for j, offsets in enumerate(offset_mapping):
                if tokenized_inputs["input_ids"][i][j] in [
                    self.tokenizer.sep_token_id,
                    self.tokenizer.pad_token_id,
                ]:
                    tokenized_inputs["attention_mask"][i][j] = 0

                if tokenized_inputs["input_ids"][i][j] == self.tokenizer.cls_token_id:
                    labels[-1].append(cls_label)
                    prediction_mask[i][j] = 1

                elif offsets[0] == offsets[1] and offsets[0] == 0:
                    labels[-1].append(2)  ## DUMMY

                else:
                    toxic_offsets = [x in spans for x in range(offsets[0], offsets[1])]
                    ## If any part of the the token is in span, mark it as Toxic
                    if (
                        len(toxic_offsets) > 0
                        and sum(toxic_offsets) / len(toxic_offsets) > 0.0
                    ):
                        labels[-1].append(1)
                    else:
                        labels[-1].append(0)
                    prediction_mask[i][j] = 1

        tokenized_inputs["labels"] = labels
        tokenized_inputs["prediction_mask"] = prediction_mask
        return tokenized_inputs

    def tokenize_for_test(self, examples):
        tokenized_inputs = self.tokenizer(
            examples["text"], **self.config.tokenizer_params
        )
        prediction_mask = np.zeros_like(np.array(tokenized_inputs["input_ids"]))
        labels = np.zeros_like(np.array(tokenized_inputs["input_ids"]))

        offsets_mapping = tokenized_inputs["offset_mapping"]

        for i, offset_mapping in enumerate(offsets_mapping):
            for j, offsets in enumerate(offset_mapping):
                if tokenized_inputs["input_ids"][i][j] in [
                    self.tokenizer.sep_token_id,
                    self.tokenizer.pad_token_id,
                ]:
                    tokenized_inputs["attention_mask"][i][j] = 0
                else:
                    prediction_mask[i][j] = 1

        tokenized_inputs["prediction_mask"] = prediction_mask
        tokenized_inputs["labels"] = labels
        return tokenized_inputs
src/datasets/toxic_spans_multi_spans.py
ADDED
@@ -0,0 +1,237 @@
from src.utils.mapper import configmapper
from transformers import AutoTokenizer
import pandas as pd
from datasets import load_dataset, Dataset
from evaluation.fix_spans import _contiguous_ranges


@configmapper.map("datasets", "toxic_spans_multi_spans")
class ToxicSpansMultiSpansDataset:
    def __init__(self, config):
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_checkpoint_name
        )

        self.dataset = load_dataset("csv", data_files=dict(self.config.train_files))
        self.test_dataset = load_dataset("csv", data_files=dict(self.config.eval_files))

        temp_key_train = list(self.dataset.keys())[0]
        self.intermediate_dataset = self.dataset.map(
            self.create_train_features,
            batched=True,
            batch_size=1000000,  ## Unusually Large Batch Size ## Needed For Correct ID mapping
            remove_columns=self.dataset[temp_key_train].column_names,
        )

        temp_key_test = list(self.test_dataset.keys())[0]
        self.intermediate_test_dataset = self.test_dataset.map(
            self.create_test_features,
            batched=True,
            batch_size=1000000,  ## Unusually Large Batch Size ## Needed For Correct ID mapping
            remove_columns=self.test_dataset[temp_key_test].column_names,
        )

        self.tokenized_inputs = self.intermediate_dataset.map(
            self.prepare_train_features,
            batched=True,
            remove_columns=self.intermediate_dataset[temp_key_train].column_names,
        )
        self.test_tokenized_inputs = self.intermediate_test_dataset.map(
            self.prepare_test_features,
            batched=True,
            remove_columns=self.intermediate_test_dataset[temp_key_test].column_names,
        )

    def create_train_features(self, examples):
        features = {
            "context": [],
            "id": [],
            "question": [],
            "title": [],
            "start_positions": [],
            "end_positions": [],
        }
        id = 0
        # print(examples)
        for row_number in range(len(examples["text"])):
            context = examples["text"][row_number]
            question = "offense"
            title = context.split(" ")[0]
            start_positions = []
            end_positions = []
            span = eval(examples["spans"][row_number])
            contiguous_spans = _contiguous_ranges(span)
            for lst in contiguous_spans:
                lst = list(lst)
                dict_to_write = {}

                start_positions.append(lst[0])
                end_positions.append(lst[1])

            features["context"].append(context)
            features["id"].append(str(id))
            features["question"].append(question)
            features["title"].append(title)
            features["start_positions"].append(start_positions)
            features["end_positions"].append(end_positions)
            id += 1

        return features

    def create_test_features(self, examples):
        features = {"context": [], "id": [], "question": [], "title": []}
        id = 0
        for row_number in range(len(examples["text"])):
            context = examples["text"][row_number]
            question = "offense"
            title = context.split(" ")[0]
            features["context"].append(context)
            features["id"].append(str(id))
            features["question"].append(question)
            features["title"].append(title)
            id += 1
        return features

    def prepare_train_features(self, examples):
        """Generate tokenized features from examples.

        Args:
            examples (dict): The examples to be tokenized.

        Returns:
            transformers.tokenization_utils_base.BatchEncoding:
                The tokenized features/examples after processing.
        """
        # Tokenize our examples with truncation and padding, but keep the
        # overflows using a stride. This results in one example possible
        # giving several features when a context is long, each of those
        # features having a context that overlaps a bit the context
        # of the previous feature.
        pad_on_right = self.tokenizer.padding_side == "right"
        print("### Batch Tokenizing Examples ###")
        tokenized_examples = self.tokenizer(
            examples["question" if pad_on_right else "context"],
            examples["context" if pad_on_right else "question"],
            **dict(self.config.tokenizer_params),
        )

        # Since one example might give us several features if it has
        # a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        # The offset mappings will give us a map from token to
        # character position in the original context. This will
        # help us compute the start_positions and end_positions.
        offset_mapping = tokenized_examples.pop("offset_mapping")

        # Let's label those examples!
        tokenized_examples["start_positions"] = []
        tokenized_examples["end_positions"] = []

        for i, offsets in enumerate(offset_mapping):
            # We will label impossible answers with the index of the CLS token.
            input_ids = tokenized_examples["input_ids"][i]

            # Grab the sequence corresponding to that example
            # (to know what is the context and what is the question).
            sequence_ids = tokenized_examples.sequence_ids(i)

            # One example can give several spans, this is the index of
            # the example containing this span of text.
            sample_index = sample_mapping[i]
            start_positions = examples["start_positions"][sample_index]
            end_positions = examples["end_positions"][sample_index]

            start_positions_token_wise = [0 for x in range(len(input_ids))]
            end_positions_token_wise = [0 for x in range(len(input_ids))]
            # If no answers are given, set the cls_index as answer.
            if len(start_positions) != 0:
                for position in range(len(start_positions)):
                    start_char = start_positions[position]
                    end_char = end_positions[position] + 1

                    # Start token index of the current span in the text.
                    token_start_index = 0
                    while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                        token_start_index += 1

                    # End token index of the current span in the text.
                    token_end_index = len(input_ids) - 1
                    while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                        token_end_index -= 1

                    # Detect if the answer is out of the span (in which case we continue).
                    if not (
                        offsets[token_start_index][0] <= start_char
                        and offsets[token_end_index][1] >= end_char
                    ):
                        continue
                    else:
                        # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                        # Note: we could go after the last offset if the answer is the last word (edge case).
                        while (
                            token_start_index < len(offsets)
                            and offsets[token_start_index][0] <= start_char
                        ):
                            token_start_index += 1
                        start_positions_token_wise[token_start_index - 1] = 1
                        while offsets[token_end_index][1] >= end_char:
                            token_end_index -= 1
                        end_positions_token_wise[token_end_index + 1] = 1
            tokenized_examples["start_positions"].append(start_positions_token_wise)
            tokenized_examples["end_positions"].append(end_positions_token_wise)
        return tokenized_examples

    def prepare_test_features(self, examples):

        """Generate tokenized validation features from examples.

        Args:
            examples (dict): The validation examples to be tokenized.

        Returns:
            transformers.tokenization_utils_base.BatchEncoding:
                The tokenized features/examples for validation set after processing.
        """

        # Tokenize our examples with truncation and maybe
        # padding, but keep the overflows using a stride.
        # This results in one example possible giving several features
        # when a context is long, each of those features having a
        # context that overlaps a bit the context of the previous feature.
        print("### Tokenizing Validation Examples")
        pad_on_right = self.tokenizer.padding_side == "right"
        tokenized_examples = self.tokenizer(
            examples["question" if pad_on_right else "context"],
            examples["context" if pad_on_right else "question"],
            **dict(self.config.tokenizer_params),
        )

        # Since one example might give us several features if it has a long context,
        # we need a map from a feature to its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

        # We keep the example_id that gave us this feature and we will store the offset mappings.
        tokenized_examples["example_id"] = []

        for i in range(len(tokenized_examples["input_ids"])):
            # Grab the sequence corresponding to that example
            # (to know what is the context and what is the question).
            sequence_ids = tokenized_examples.sequence_ids(i)
            context_index = 1 if pad_on_right else 0

            # One example can give several spans,
            # this is the index of the example containing this span of text.
            sample_index = sample_mapping[i]
            tokenized_examples["example_id"].append(str(examples["id"][sample_index]))

            # Set to None the offset_mapping that are not part
            # of the context so it's easy to determine if a token
            # position is part of the context or not.
            tokenized_examples["offset_mapping"][i] = [
                (o if sequence_ids[k] == context_index else None)
                for k, o in enumerate(tokenized_examples["offset_mapping"][i])
            ]

        return tokenized_examples
src/datasets/toxic_spans_spans.py
ADDED
@@ -0,0 +1,238 @@
from src.utils.mapper import configmapper
from transformers import AutoTokenizer
import pandas as pd
from datasets import load_dataset, Dataset
from evaluation.fix_spans import _contiguous_ranges


@configmapper.map("datasets", "toxic_spans_spans")
class ToxicSpansSpansDataset:
    def __init__(self, config):
        # print("### ToxicSpansSpansDataset ###"); exit()
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_checkpoint_name
        )

        self.dataset = load_dataset("csv", data_files=dict(self.config.train_files))
        self.test_dataset = load_dataset("csv", data_files=dict(self.config.eval_files))

        temp_key_train = list(self.dataset.keys())[0]
        self.intermediate_dataset = self.dataset.map(
            self.create_train_features,
            batched=True,
            batch_size=1000000,  ## Unusually Large Batch Size ## Needed For Correct ID mapping
            remove_columns=self.dataset[temp_key_train].column_names,
        )

        temp_key_test = list(self.test_dataset.keys())[0]
        self.intermediate_test_dataset = self.test_dataset.map(
            self.create_test_features,
            batched=True,
            batch_size=1000000,  ## Unusually Large Batch Size ## Needed For Correct ID mapping
            remove_columns=self.test_dataset[temp_key_test].column_names,
        )

        self.tokenized_inputs = self.intermediate_dataset.map(
            self.prepare_train_features,
            batched=True,
            remove_columns=self.intermediate_dataset[temp_key_train].column_names,
        )
        self.test_tokenized_inputs = self.intermediate_test_dataset.map(
            self.prepare_test_features,
            batched=True,
            remove_columns=self.intermediate_test_dataset[temp_key_test].column_names,
        )

    def create_train_features(self, examples):
        features = {"context": [], "id": [], "question": [], "title": []}
        id = 0
        # print(examples)
        for row_number in range(len(examples["text"])):
            context = examples["text"][row_number]
            # question = "offense"
            question = "ভুল"
            title = context.split(" ")[0]
            span = eval(examples["spans"][row_number])
            contiguous_spans = _contiguous_ranges(span)
            for lst in contiguous_spans:
                lst = list(lst)
                dict_to_write = {}

                dict_to_write["answer_start"] = [lst[0]]
                dict_to_write["text"] = [context[lst[0] : lst[-1] + 1]]
                # print(dict_to_write)
                if "answers" in features.keys():
                    features["answers"].append(dict_to_write)
                else:
                    features["answers"] = [
                        dict_to_write,
                    ]
                features["context"].append(context)
                features["id"].append(str(id))
                features["question"].append(question)
                features["title"].append(title)
                id += 1

        return features

    def create_test_features(self, examples):
        features = {"context": [], "id": [], "question": [], "title": []}
        id = 0
        for row_number in range(len(examples["text"])):
            context = examples["text"][row_number]
            # question = "offense"
            question = "ভুল"
            title = context.split(" ")[0]
            features["context"].append(context)
            features["id"].append(str(id))
            features["question"].append(question)
            features["title"].append(title)
            id += 1
        return features

    def prepare_train_features(self, examples):
        """Generate tokenized features from examples.

        Args:
            examples (dict): The examples to be tokenized.

        Returns:
            transformers.tokenization_utils_base.BatchEncoding:
                The tokenized features/examples after processing.
        """
        # Tokenize our examples with truncation and padding, but keep the
        # overflows using a stride. This results in one example possible
        # giving several features when a context is long, each of those
        # features having a context that overlaps a bit the context
        # of the previous feature.
        pad_on_right = self.tokenizer.padding_side == "right"
        print("### Batch Tokenizing Examples ###")
        tokenized_examples = self.tokenizer(
            examples["question" if pad_on_right else "context"],
            examples["context" if pad_on_right else "question"],
            **dict(self.config.tokenizer_params),
        )

        # Since one example might give us several features if it has
        # a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        # The offset mappings will give us a map from token to
        # character position in the original context. This will
        # help us compute the start_positions and end_positions.
        offset_mapping = tokenized_examples.pop("offset_mapping")

        # Let's label those examples!
        tokenized_examples["start_positions"] = []
        tokenized_examples["end_positions"] = []

        for i, offsets in enumerate(offset_mapping):
            # We will label impossible answers with the index of the CLS token.
            input_ids = tokenized_examples["input_ids"][i]
            cls_index = input_ids.index(self.tokenizer.cls_token_id)

            # Grab the sequence corresponding to that example
            # (to know what is the context and what is the question).
            sequence_ids = tokenized_examples.sequence_ids(i)

            # One example can give several spans, this is the index of
            # the example containing this span of text.
            sample_index = sample_mapping[i]
            answers = examples["answers"][sample_index]
            # If no answers are given, set the cls_index as answer.
            if len(answers["answer_start"]) == 0:
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Start/end character index of the answer in the text.
                start_char = answers["answer_start"][0]
                end_char = start_char + len(answers["text"][0])

                # Start token index of the current span in the text.
                token_start_index = 0
                while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                    token_start_index += 1

                # End token index of the current span in the text.
                token_end_index = len(input_ids) - 1
                while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                    token_end_index -= 1

                # Detect if the answer is out of the span
                # (in which case this feature is labeled with the CLS index).
                if not (
                    offsets[token_start_index][0] <= start_char
                    and offsets[token_end_index][1] >= end_char
                ):
                    tokenized_examples["start_positions"].append(cls_index)
                    tokenized_examples["end_positions"].append(cls_index)
                else:
                    # Otherwise move the token_start_index and
                    # stoken_end_index to the two ends of the answer.
                    # Note: we could go after the last offset
                    # if the answer is the last word (edge case).
                    while (
                        token_start_index < len(offsets)
                        and offsets[token_start_index][0] <= start_char
                    ):
                        token_start_index += 1
                    tokenized_examples["start_positions"].append(token_start_index - 1)
                    while offsets[token_end_index][1] >= end_char:
                        token_end_index -= 1
                    tokenized_examples["end_positions"].append(token_end_index + 1)

        return tokenized_examples

    def prepare_test_features(self, examples):

        """Generate tokenized validation features from examples.

        Args:
            examples (dict): The validation examples to be tokenized.

        Returns:
            transformers.tokenization_utils_base.BatchEncoding:
                The tokenized features/examples for validation set after processing.
        """

        # Tokenize our examples with truncation and maybe
        # padding, but keep the overflows using a stride.
        # This results in one example possible giving several features
        # when a context is long, each of those features having a
        # context that overlaps a bit the context of the previous feature.
        print("### Tokenizing Validation Examples")
        pad_on_right = self.tokenizer.padding_side == "right"
        tokenized_examples = self.tokenizer(
            examples["question" if pad_on_right else "context"],
            examples["context" if pad_on_right else "question"],
            **dict(self.config.tokenizer_params),
        )

        # Since one example might give us several features if it has a long context,
        # we need a map from a feature to its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

        # We keep the example_id that gave us this feature and we will store the offset mappings.
        tokenized_examples["example_id"] = []

        for i in range(len(tokenized_examples["input_ids"])):
            # Grab the sequence corresponding to that example
            # (to know what is the context and what is the question).
            sequence_ids = tokenized_examples.sequence_ids(i)
            context_index = 1 if pad_on_right else 0

            # One example can give several spans,
            # this is the index of the example containing this span of text.
            sample_index = sample_mapping[i]
            tokenized_examples["example_id"].append(str(examples["id"][sample_index]))

            # Set to None the offset_mapping that are not part
            # of the context so it's easy to determine if a token
            # position is part of the context or not.
            tokenized_examples["offset_mapping"][i] = [
                (o if sequence_ids[k] == context_index else None)
                for k, o in enumerate(tokenized_examples["offset_mapping"][i])
            ]

        return tokenized_examples
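The create_train_features method above turns each contiguous run of toxic character offsets into one SQuAD-style answer (an answer_start offset plus the answer text sliced out of the comment). Below is a small, self-contained illustration of that mapping; the contiguous_ranges helper here is only a stand-in for what evaluation.fix_spans._contiguous_ranges is assumed to do, not the project's implementation.

# Stand-in for _contiguous_ranges (assumed behaviour: group sorted character
# offsets into inclusive (start, end) runs).
def contiguous_ranges(span):
    ranges, start, prev = [], None, None
    for pos in sorted(span):
        if start is None:
            start = prev = pos
        elif pos == prev + 1:
            prev = pos
        else:
            ranges.append((start, prev))
            start = prev = pos
    if start is not None:
        ranges.append((start, prev))
    return ranges

text = "you are a fool and a liar"
spans = list(range(10, 14)) + list(range(21, 25))  # characters of "fool" and "liar"
for start, end in contiguous_ranges(spans):
    print({"answer_start": [start], "text": [text[start : end + 1]]})
# {'answer_start': [10], 'text': ['fool']}
# {'answer_start': [21], 'text': ['liar']}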
src/datasets/toxic_spans_tokens.py
ADDED
@@ -0,0 +1,81 @@
from src.utils.mapper import configmapper
from transformers import AutoTokenizer
from datasets import load_dataset

# import pdb

@configmapper.map("datasets", "toxic_spans_tokens")
class ToxicSpansTokenDataset:
    def __init__(self, config):
        # print("### ToxicSpansTokenDataset ###"); exit()
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_checkpoint_name
        )
        # if self.config.model_checkpoint_name == "sberbank-ai/mGPT":
        #     self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.dataset = load_dataset("csv", data_files=dict(self.config.train_files))
        self.test_dataset = load_dataset("csv", data_files=dict(self.config.eval_files))

        self.tokenized_inputs = self.dataset.map(
            self.tokenize_and_align_labels_for_train, batched=True
        )
        self.test_tokenized_inputs = self.test_dataset.map(
            self.tokenize_for_test, batched=True
        )

    def tokenize_and_align_labels_for_train(self, examples):

        tokenized_inputs = self.tokenizer(
            examples["text"], **self.config.tokenizer_params
        )

        # tokenized_inputs["text"] = examples["text"]
        example_spans = []
        labels = []

        offsets_mapping = tokenized_inputs["offset_mapping"]
        # pdb.set_trace()
        for i, offset_mapping in enumerate(offsets_mapping):
            labels.append([])

            spans = eval(examples["spans"][i])
            example_spans.append(spans)
            if self.config.label_cls:
                cls_label = (
                    1
                    if (
                        len(examples["text"][i]) > 0
                        and len(spans) / len(examples["text"][i])
                        > self.config.cls_threshold
                    )
                    else 0
                )  ## Make class label based on threshold
            else:
                cls_label = -100
            for j, offsets in enumerate(offset_mapping):
                if tokenized_inputs["input_ids"][i][j] == self.tokenizer.cls_token_id:
                    labels[-1].append(cls_label)
                elif offsets[0] == offsets[1] and offsets[0] == 0:  # All zero
                    labels[-1].append(-100)  ## SPECIAL TOKEN
                else:
                    toxic_offsets = [x in spans for x in range(offsets[0], offsets[1])]
                    ## If any part of the the token is in span, mark it as Toxic
                    if (
                        len(toxic_offsets) > 0
                        and sum(toxic_offsets) / len(toxic_offsets)
                        > self.config.token_threshold
                    ):
                        labels[-1].append(1)
                    else:
                        labels[-1].append(0)

        tokenized_inputs["labels"] = labels
        # print("tokenized_inputs", tokenized_inputs); exit()
        return tokenized_inputs

    def tokenize_for_test(self, examples):
        tokenized_inputs = self.tokenizer(
            examples["text"], **self.config.tokenizer_params
        )
        return tokenized_inputs
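For orientation, here is a minimal sketch of how ToxicSpansTokenDataset might be instantiated. The attribute names mirror what the class above reads from config; every concrete value (checkpoint name, CSV paths, tokenizer settings, thresholds) is an assumption rather than something this upload specifies. Note that tokenizer_params has to request offset mappings, since the labelling loop reads tokenized_inputs["offset_mapping"].

# Hypothetical configuration; values are placeholders, field names come from the class above.
from types import SimpleNamespace

config = SimpleNamespace(
    model_checkpoint_name="bert-base-cased",      # assumed checkpoint
    train_files={"train": "data/tsd_train.csv"},  # assumed CSVs with "text" and "spans" columns
    eval_files={"test": "data/tsd_test.csv"},
    tokenizer_params={
        "truncation": True,
        "padding": "max_length",
        "return_offsets_mapping": True,           # required: the code reads offset_mapping
    },
    label_cls=False,        # when True, the CLS token gets a sentence-level label
    cls_threshold=0.5,      # fraction of toxic characters needed for CLS label 1
    token_threshold=0.0,    # token labelled 1 if its toxic-character fraction exceeds this
)

dataset = ToxicSpansTokenDataset(config)
train_split = dataset.tokenized_inputs       # datasets.DatasetDict with token-level "labels"
test_split = dataset.test_tokenized_inputs   # tokenized test set without labels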
src/datasets/toxic_spans_tokens_3cls.py
ADDED
@@ -0,0 +1,102 @@
from src.utils.mapper import configmapper
from transformers import AutoTokenizer
from datasets import load_dataset

import pdb

@configmapper.map("datasets", "toxic_spans_tokens_3cls")
class ToxicSpansToken3CLSDataset:
    def __init__(self, config):
        # print("### ToxicSpansTokenDataset ###"); exit()
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_checkpoint_name
        )
        # if self.config.model_checkpoint_name == "sberbank-ai/mGPT":
        #     self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.dataset = load_dataset("csv", data_files=dict(self.config.train_files))
        self.test_dataset = load_dataset("csv", data_files=dict(self.config.eval_files))

        self.tokenized_inputs = self.dataset.map(
            self.tokenize_and_align_labels_for_train, batched=True
        )
        self.test_tokenized_inputs = self.test_dataset.map(
            self.tokenize_for_test, batched=True
        )

    def tokenize_and_align_labels_for_train(self, examples):

        tokenized_inputs = self.tokenizer(
            examples["text"], **self.config.tokenizer_params
        )

        # tokenized_inputs["text"] = examples["text"]
        example_spans = []
        labels = []

        offsets_mapping = tokenized_inputs["offset_mapping"]
        # pdb.set_trace()
        for i, offset_mapping in enumerate(offsets_mapping):
            labels.append([])

            spans = eval(examples["spans"][i])
            Bs = eval(examples["Bs"][i])
            Is = eval(examples["Is"][i])
            example_spans.append(spans)
            if self.config.label_cls:
                cls_label = (
                    1
                    if (
                        len(examples["text"][i]) > 0
                        and len(spans) / len(examples["text"][i])
                        > self.config.cls_threshold
                    )
                    else 0
                )  ## Make class label based on threshold
            else:
                cls_label = -100
            for j, offsets in enumerate(offset_mapping):
                if tokenized_inputs["input_ids"][i][j] == self.tokenizer.cls_token_id:
                    labels[-1].append(cls_label)
                elif offsets[0] == offsets[1] and offsets[0] == 0:  # All zero
                    labels[-1].append(-100)  ## SPECIAL TOKEN
                else:
                    # toxic_offsets = [x in spans for x in range(offsets[0], offsets[1])]
                    ## If any part of the the token is in span, mark it as Toxic
                    # if (
                    #     len(toxic_offsets) > 0
                    #     and sum(toxic_offsets) / len(toxic_offsets)
                    #     > self.config.token_threshold
                    # ):
                    #     labels[-1].append(1)
                    # else:
                    #     labels[-1].append(0)
                    b_off = [x in Bs for x in range(offsets[0], offsets[1])]
                    b_off = sum(b_off)
                    i_off = [x in Is for x in range(offsets[0], offsets[1])]
                    i_off = sum(i_off)
                    # if len(b_off) == len(i_off) and len(i_off) == 0:
                    if b_off == 0 and i_off == 0:
                        labels[-1].append(0)
                    # elif len(b_off) >= len(i_off) == 1:
                    elif b_off >= i_off:
                        labels[-1].append(1)
                        # print(b_off)
                        # print(i_off)
                        # print(j)
                    else:
                        labels[-1].append(2)

        # pdb.set_trace()

        tokenized_inputs["labels"] = labels
        # print("tokenized_inputs", tokenized_inputs); exit()
        return tokenized_inputs

    def tokenize_for_test(self, examples):
        tokenized_inputs = self.tokenizer(
            examples["text"], **self.config.tokenizer_params
        )
        return tokenized_inputs
src/datasets/toxic_spans_tokens_spans.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.utils.mapper import configmapper
|
2 |
+
from transformers import AutoTokenizer
|
3 |
+
import pandas as pd
|
4 |
+
from datasets import load_dataset, Dataset
|
5 |
+
from evaluation.fix_spans import _contiguous_ranges
|
6 |
+
|
7 |
+
|
8 |
+
@configmapper.map("datasets", "toxic_spans_tokens_spans")
|
9 |
+
class ToxicSpansTokensSpansDataset:
|
10 |
+
def __init__(self, config):
|
11 |
+
self.config = config
|
12 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
13 |
+
self.config.model_checkpoint_name
|
14 |
+
)
|
15 |
+
|
16 |
+
self.dataset = load_dataset("csv", data_files=dict(self.config.train_files))
|
17 |
+
self.test_dataset = load_dataset("csv", data_files=dict(self.config.eval_files))
|
18 |
+
|
19 |
+
temp_key_train = list(self.dataset.keys())[0]
|
20 |
+
self.intermediate_dataset = self.dataset.map(
|
21 |
+
self.create_train_features,
|
22 |
+
batched=True,
|
23 |
+
batch_size=1000000, ##Unusually Large Batch Size ## Needed For Correct ID mapping
|
24 |
+
remove_columns=self.dataset[temp_key_train].column_names,
|
25 |
+
)
|
26 |
+
|
27 |
+
temp_key_test = list(self.test_dataset.keys())[0]
|
28 |
+
self.intermediate_test_dataset = self.test_dataset.map(
|
29 |
+
self.create_test_features,
|
30 |
+
batched=True,
|
31 |
+
batch_size=1000000, ##Unusually Large Batch Size ## Needed For Correct ID mapping
|
32 |
+
remove_columns=self.test_dataset[temp_key_test].column_names,
|
33 |
+
)
|
34 |
+
|
35 |
+
self.tokenized_inputs = self.intermediate_dataset.map(
|
36 |
+
self.prepare_train_features,
|
37 |
+
batched=True,
|
38 |
+
remove_columns=self.intermediate_dataset[temp_key_train].column_names,
|
39 |
+
)
|
40 |
+
self.test_tokenized_inputs = self.intermediate_test_dataset.map(
|
41 |
+
self.prepare_test_features,
|
42 |
+
batched=True,
|
43 |
+
remove_columns=self.intermediate_test_dataset[temp_key_test].column_names,
|
44 |
+
)
|
45 |
+
|
46 |
+
def create_train_features(self, examples):
|
47 |
+
features = {"context": [], "id": [], "question": [], "title": [], "spans": []}
|
48 |
+
id = 0
|
49 |
+
# print(examples)
|
50 |
+
for row_number in range(len(examples["text"])):
|
51 |
+
context = examples["text"][row_number]
|
52 |
+
question = "offense"
|
53 |
+
title = context.split(" ")[0]
|
54 |
+
span = eval(examples["spans"][row_number])
|
55 |
+
contiguous_spans = _contiguous_ranges(span)
|
56 |
+
for lst in contiguous_spans:
|
57 |
+
lst = list(lst)
|
58 |
+
dict_to_write = {}
|
59 |
+
|
60 |
+
dict_to_write["answer_start"] = [lst[0]]
|
61 |
+
dict_to_write["text"] = [context[lst[0] : lst[-1] + 1]]
|
62 |
+
# print(dict_to_write)
|
63 |
+
if "answers" in features.keys():
|
64 |
+
features["answers"].append(dict_to_write)
|
65 |
+
else:
|
66 |
+
features["answers"] = [
|
67 |
+
dict_to_write,
|
68 |
+
]
|
69 |
+
features["context"].append(context)
|
70 |
+
features["id"].append(str(id))
|
71 |
+
features["question"].append(question)
|
72 |
+
features["title"].append(title)
|
73 |
+
features["spans"].append(span)
|
74 |
+
id += 1
|
75 |
+
|
76 |
+
return features
|
77 |
+
|
78 |
+
def create_test_features(self, examples):
|
79 |
+
features = {"context": [], "id": [], "question": [], "title": []}
|
80 |
+
id = 0
|
81 |
+
for row_number in range(len(examples["text"])):
|
82 |
+
context = examples["text"][row_number]
|
83 |
+
question = "offense"
|
84 |
+
title = context.split(" ")[0]
|
85 |
+
features["context"].append(context)
|
86 |
+
features["id"].append(str(id))
|
87 |
+
features["question"].append(question)
|
88 |
+
features["title"].append(title)
|
89 |
+
id += 1
|
90 |
+
return features
|
91 |
+
|
92 |
+
def prepare_train_features(self, examples):
|
93 |
+
"""Generate tokenized features from examples.
|
94 |
+
|
95 |
+
Args:
|
96 |
+
examples (dict): The examples to be tokenized.
|
97 |
+
|
98 |
+
Returns:
|
99 |
+
transformers.tokenization_utils_base.BatchEncoding:
|
100 |
+
The tokenized features/examples after processing.
|
101 |
+
"""
|
102 |
+
# Tokenize our examples with truncation and padding, but keep the
|
103 |
+
# overflows using a stride. This results in one example possible
|
104 |
+
# giving several features when a context is long, each of those
|
105 |
+
# features having a context that overlaps a bit the context
|
106 |
+
# of the previous feature.
|
107 |
+
pad_on_right = self.tokenizer.padding_side == "right"
|
108 |
+
print("### Batch Tokenizing Examples ###")
|
109 |
+
tokenized_examples = self.tokenizer(
|
110 |
+
examples["question" if pad_on_right else "context"],
|
111 |
+
examples["context" if pad_on_right else "question"],
|
112 |
+
**dict(self.config.tokenizer_params),
|
113 |
+
)
|
114 |
+
|
115 |
+
# Since one example might give us several features if it has
|
116 |
+
# a long context, we need a map from a feature to
|
117 |
+
# its corresponding example. This key gives us just that.
|
118 |
+
sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
|
119 |
+
# The offset mappings will give us a map from token to
|
120 |
+
# character position in the original context. This will
|
121 |
+
# help us compute the start_positions and end_positions.
|
122 |
+
offset_mapping = tokenized_examples.pop("offset_mapping")
|
123 |
+
|
124 |
+
# Let's label those examples!
|
125 |
+
token_labels = []
|
126 |
+
tokenized_examples["start_positions"] = []
|
127 |
+
tokenized_examples["end_positions"] = []
|
128 |
+
|
129 |
+
for i, offsets in enumerate(offset_mapping):
|
130 |
+
# We will label impossible answers with the index of the CLS token.
|
131 |
+
|
132 |
+
token_labels.append([])
|
133 |
+
input_ids = tokenized_examples["input_ids"][i]
|
134 |
+
spans = examples["spans"][i]
|
135 |
+
if self.config.label_cls:
|
136 |
+
cls_label = (
|
137 |
+
1
|
138 |
+
if (
|
139 |
+
len(examples["context"][i]) > 0
|
140 |
+
and len(spans) / len(examples["context"][i])
|
141 |
+
> self.config.cls_threshold
|
142 |
+
)
|
143 |
+
else 0
|
144 |
+
) ## Make class label based on threshold
|
145 |
+
else:
|
146 |
+
cls_label = -100
|
147 |
+
for j, offset in enumerate(offsets):
|
148 |
+
if tokenized_examples["input_ids"][i][j] == self.tokenizer.cls_token_id:
|
149 |
+
token_labels[-1].append(cls_label)
|
150 |
+
elif offset[0] == offset[1] and offset[0] == 0:
|
151 |
+
token_labels[-1].append(-100) ## SPECIAL TOKEN
|
152 |
+
else:
|
153 |
+
toxic_offsets = [x in spans for x in range(offset[0], offset[1])]
|
154 |
+
## If any part of the the token is in span, mark it as Toxic
|
155 |
+
if (
|
156 |
+
len(toxic_offsets) > 0
|
157 |
+
and sum(toxic_offsets) / len(toxic_offsets)
|
158 |
+
> self.config.token_threshold
|
159 |
+
):
|
160 |
+
token_labels[-1].append(1)
|
161 |
+
else:
|
162 |
+
token_labels[-1].append(0)
|
163 |
+
|
164 |
+
cls_index = input_ids.index(self.tokenizer.cls_token_id)
|
165 |
+
|
166 |
+
# Grab the sequence corresponding to that example
|
167 |
+
# (to know what is the context and what is the question).
|
168 |
+
sequence_ids = tokenized_examples.sequence_ids(i)
|
169 |
+
|
170 |
+
# One example can give several spans, this is the index of
|
171 |
+
# the example containing this span of text.
|
172 |
+
sample_index = sample_mapping[i]
|
173 |
+
            answers = examples["answers"][sample_index]
            # If no answers are given, set the cls_index as answer.
            if len(answers["answer_start"]) == 0:
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Start/end character index of the answer in the text.
                start_char = answers["answer_start"][0]
                end_char = start_char + len(answers["text"][0])

                # Start token index of the current span in the text.
                token_start_index = 0
                while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                    token_start_index += 1

                # End token index of the current span in the text.
                token_end_index = len(input_ids) - 1
                while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                    token_end_index -= 1

                # Detect if the answer is out of the span
                # (in which case this feature is labeled with the CLS index).
                if not (
                    offsets[token_start_index][0] <= start_char
                    and offsets[token_end_index][1] >= end_char
                ):
                    tokenized_examples["start_positions"].append(cls_index)
                    tokenized_examples["end_positions"].append(cls_index)
                else:
                    # Otherwise move the token_start_index and
                    # token_end_index to the two ends of the answer.
                    # Note: we could go after the last offset
                    # if the answer is the last word (edge case).
                    while (
                        token_start_index < len(offsets)
                        and offsets[token_start_index][0] <= start_char
                    ):
                        token_start_index += 1
                    tokenized_examples["start_positions"].append(token_start_index - 1)
                    while offsets[token_end_index][1] >= end_char:
                        token_end_index -= 1
                    tokenized_examples["end_positions"].append(token_end_index + 1)
        tokenized_examples["labels"] = token_labels
        return tokenized_examples

    def prepare_test_features(self, examples):
        """Generate tokenized validation features from examples.

        Args:
            examples (dict): The validation examples to be tokenized.

        Returns:
            transformers.tokenization_utils_base.BatchEncoding:
                The tokenized features/examples for validation set after processing.
        """
        # Tokenize our examples with truncation and maybe
        # padding, but keep the overflows using a stride.
        # This results in one example possibly giving several features
        # when a context is long, each of those features having a
        # context that overlaps a bit the context of the previous feature.
        print("### Tokenizing Validation Examples")
        pad_on_right = self.tokenizer.padding_side == "right"
        tokenized_examples = self.tokenizer(
            examples["question" if pad_on_right else "context"],
            examples["context" if pad_on_right else "question"],
            **dict(self.config.tokenizer_params),
        )

        # Since one example might give us several features if it has a long context,
        # we need a map from a feature to its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

        # We keep the example_id that gave us this feature and we will store the offset mappings.
        tokenized_examples["example_id"] = []

        for i in range(len(tokenized_examples["input_ids"])):
            # Grab the sequence corresponding to that example
            # (to know what is the context and what is the question).
            sequence_ids = tokenized_examples.sequence_ids(i)
            context_index = 1 if pad_on_right else 0

            # One example can give several spans,
            # this is the index of the example containing this span of text.
            sample_index = sample_mapping[i]
            tokenized_examples["example_id"].append(str(examples["id"][sample_index]))

            # Set to None the offset_mapping that are not part
            # of the context so it's easy to determine if a token
            # position is part of the context or not.
            tokenized_examples["offset_mapping"][i] = [
                (o if sequence_ids[k] == context_index else None)
                for k, o in enumerate(tokenized_examples["offset_mapping"][i])
            ]

        return tokenized_examples
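The two feature builders above follow the standard Hugging Face question-answering preprocessing recipe, so they are typically applied with `datasets.Dataset.map` in batched mode. A hypothetical usage sketch, not part of the uploaded files: `spans_data` stands for an instance of the dataset class above, and `raw_train`/`raw_val` for `datasets.Dataset` splits carrying the "id", "question", "context" and "answers" columns these methods read.

# Hypothetical sketch; `spans_data`, `raw_train` and `raw_val` are assumptions.
train_features = raw_train.map(
    spans_data.prepare_train_features,
    batched=True,
    remove_columns=raw_train.column_names,
)
val_features = raw_val.map(
    spans_data.prepare_test_features,
    batched=True,
    remove_columns=raw_val.column_names,
)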
src/models/__init__.py
ADDED
@@ -0,0 +1,7 @@
from src.models.auto_models import *
from src.models.bert_token_spans import *
from src.models.roberta_token_spans import *
from src.models.bert_multi_spans import *
from src.models.roberta_multi_spans import *
from src.models.bert_crf_token import *
from src.models.roberta_crf_token import *
src/models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (443 Bytes). View file
src/models/__pycache__/auto_models.cpython-38.pyc
ADDED
Binary file (436 Bytes). View file
src/models/__pycache__/bert_crf_token.cpython-38.pyc
ADDED
Binary file (1.68 kB). View file
src/models/__pycache__/bert_multi_spans.cpython-38.pyc
ADDED
Binary file (1.72 kB). View file
src/models/__pycache__/bert_token_spans.cpython-38.pyc
ADDED
Binary file (2.34 kB). View file
src/models/__pycache__/roberta_crf_token.cpython-38.pyc
ADDED
Binary file (1.69 kB). View file
src/models/__pycache__/roberta_multi_spans.cpython-38.pyc
ADDED
Binary file (1.79 kB). View file
src/models/__pycache__/roberta_token_spans.cpython-38.pyc
ADDED
Binary file (2.42 kB). View file
src/models/auto_models.py
ADDED
@@ -0,0 +1,6 @@
from transformers import AutoModelForTokenClassification, AutoModelForQuestionAnswering
from src.utils.mapper import configmapper

configmapper.map("models", "autotoken")(AutoModelForTokenClassification)
configmapper.map("models", "autotoken_3cls")(AutoModelForTokenClassification)
configmapper.map("models", "autospans")(AutoModelForQuestionAnswering)
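These mappings register stock Hugging Face classes under string keys so that other parts of the repo can resolve a model purely from configuration, the same way the preprocessors and trainer below call configmapper.get_object. A hypothetical lookup sketch; the checkpoint name and label count are assumptions, not taken from the upload.

# Hypothetical sketch of resolving a registered model class.
from src.utils.mapper import configmapper

model_cls = configmapper.get_object("models", "autotoken")  # AutoModelForTokenClassification
model = model_cls.from_pretrained("google/electra-base-discriminator", num_labels=2)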
src/models/bert_crf_token.py
ADDED
@@ -0,0 +1,72 @@
import torch

# from transformers import BertForTokenClassification
from transformers import ElectraForTokenClassification
from torchcrf import CRF
from src.utils.mapper import configmapper

# import pdb


@configmapper.map("models", "bert_crf_token")
# class BertLSTMCRF(BertForTokenClassification):
class BertLSTMCRF(ElectraForTokenClassification):
    def __init__(self, config, lstm_hidden_size, lstm_layers):
        super().__init__(config)
        # ipdb.set_trace()
        self.lstm = torch.nn.LSTM(
            input_size=config.hidden_size,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_layers,
            dropout=0.2,
            batch_first=True,
            bidirectional=True,
        )
        self.crf = CRF(config.num_labels, batch_first=True)

        del self.classifier
        self.classifier = torch.nn.Linear(2 * lstm_hidden_size, config.num_labels)

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        prediction_mask=None,
    ):
        # pdb.set_trace()

        # outputs = self.bert(
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            output_hidden_states=True,
            return_dict=False,
        )
        # seq_output, all_hidden_states, all_self_attentions, all_cross_attentions

        sequence_output = outputs[0]  # outputs[1] is pooled output which is none.

        sequence_output = self.dropout(sequence_output)

        lstm_out, *_ = self.lstm(sequence_output)
        sequence_output = self.dropout(lstm_out)

        logits = self.classifier(sequence_output)

        ## CRF
        mask = prediction_mask
        mask = mask[:, : logits.size(1)].contiguous()

        # print(logits)

        if labels is not None:
            labels = labels[:, : logits.size(1)].contiguous()
            loss = -self.crf(logits, labels, mask=mask.bool(), reduction="token_mean")

        tags = self.crf.decode(logits, mask.bool())
        # print(tags)
        if labels is not None:
            return (loss, logits, tags)
        else:
            return (logits, tags)
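For readers unfamiliar with pytorch-crf: the negated self.crf(...) call above is the negative log-likelihood used as the training loss, and crf.decode returns one variable-length list of label ids per sequence, covering only the positions allowed by the mask. A minimal, self-contained shape sketch with random emissions, not tied to the model above; all sizes are made up.

# Self-contained sketch of the torchcrf calls used above (requires pytorch-crf).
import torch
from torchcrf import CRF

num_labels, batch_size, seq_len = 2, 2, 6
crf = CRF(num_labels, batch_first=True)
logits = torch.randn(batch_size, seq_len, num_labels)   # stand-in for classifier output
labels = torch.randint(0, num_labels, (batch_size, seq_len))
mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
mask[1, 4:] = False  # pretend the second sequence is padded after 4 tokens

loss = -crf(logits, labels, mask=mask, reduction="token_mean")  # negative log-likelihood
tags = crf.decode(logits, mask)  # list of per-sequence label-id lists (lengths 6 and 4 here)
print(loss.item(), tags)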
src/models/bert_multi_spans.py
ADDED
@@ -0,0 +1,84 @@
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss

# from transformers import BertModel, BertPreTrainedModel
from transformers import ElectraPreTrainedModel, ElectraModel
from src.utils.mapper import configmapper


@configmapper.map("models", "bert_multi_spans")
# class BertForMultiSpans(BertPreTrainedModel):
class BertForMultiSpans(ElectraPreTrainedModel):
    def __init__(self, config):
        super(BertForMultiSpans, self).__init__(config)
        # self.bert = BertModel(config)
        self.bert = ElectraModel(config)
        self.num_labels = config.num_labels

        # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
        # self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=None,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)  # batch_size
        # print(start_logits.shape, end_logits.shape, start_positions.shape, end_positions.shape)

        total_loss = None
        if (
            start_positions is not None and end_positions is not None
        ):  # [batch_size/seq_length]
            # # If we are on multi-GPU, split add a dimension
            # if len(start_positions.size()) > 1:
            #     start_positions = start_positions.squeeze(-1)
            # if len(end_positions.size()) > 1:
            #     end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            # ignored_index = start_logits.size(1)
            # start_positions.clamp_(0, ignored_index)
            # end_positions.clamp_(0, ignored_index)

            # start_positions = start_logits.view()

            loss_fct = BCEWithLogitsLoss()

            start_loss = loss_fct(
                start_logits,
                start_positions.float(),
            )
            end_loss = loss_fct(
                end_logits,
                end_positions.float(),
            )
            total_loss = (start_loss + end_loss) / 2

        output = (start_logits, end_logits) + outputs[2:]
        return ((total_loss,) + output) if total_loss is not None else output
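Because the start/end heads here are trained with BCEWithLogitsLoss over the full sequence, the targets are multi-hot vectors (one bit per span boundary) rather than the single indices used by standard SQuAD heads. A self-contained sketch of that convention; the sizes and span offsets below are made up for illustration.

# Self-contained sketch of the multi-hot start/end target convention.
import torch
from torch.nn import BCEWithLogitsLoss

batch_size, seq_len = 2, 16
start_logits = torch.randn(batch_size, seq_len)  # stand-ins for the model's outputs
end_logits = torch.randn(batch_size, seq_len)

start_positions = torch.zeros(batch_size, seq_len)
end_positions = torch.zeros(batch_size, seq_len)
start_positions[0, [3, 9]] = 1.0  # example 0: two spans starting at tokens 3 and 9
end_positions[0, [5, 11]] = 1.0
start_positions[1, 7] = 1.0       # example 1: a single span
end_positions[1, 8] = 1.0

loss_fct = BCEWithLogitsLoss()
total_loss = (
    loss_fct(start_logits, start_positions) + loss_fct(end_logits, end_positions)
) / 2
print(total_loss.item())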
src/models/bert_token_spans.py
ADDED
@@ -0,0 +1,100 @@
import torch.nn as nn
import torch
from torch.nn import CrossEntropyLoss

# from transformers import BertPreTrainedModel, BertModel
from transformers import ElectraPreTrainedModel, ElectraModel
from src.utils.mapper import configmapper


@configmapper.map("models", "bert_token_spans")
# class BertModelForTokenAndSpans(BertPreTrainedModel):
class BertModelForTokenAndSpans(ElectraPreTrainedModel):
    def __init__(self, config, num_token_labels=2, num_qa_labels=2):
        super(BertModelForTokenAndSpans, self).__init__(config)
        # self.bert = BertModel(config)
        self.bert = ElectraModel(config)
        self.num_token_labels = num_token_labels
        self.num_qa_labels = num_qa_labels
        # print("Number of Token Labels: ", num_token_labels); exit()

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_token_labels)
        self.qa_outputs = nn.Linear(config.hidden_size, num_qa_labels)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        labels=None,  # Token Wise Labels
        output_attentions=None,
        output_hidden_states=None,
    ):

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=None,
        )

        sequence_output = outputs[0]

        qa_logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = qa_logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        sequence_output = self.dropout(sequence_output)
        token_logits = self.classifier(sequence_output)

        total_loss = None
        if (
            start_positions is not None
            and end_positions is not None
            and labels is not None
        ):
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)

            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)

            loss_fct = CrossEntropyLoss()
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = token_logits.view(-1, self.num_token_labels)
                active_labels = torch.where(
                    active_loss,
                    labels.view(-1),
                    torch.tensor(loss_fct.ignore_index).type_as(labels),
                )
                token_loss = loss_fct(active_logits, active_labels)
            else:
                token_loss = loss_fct(
                    token_logits.view(-1, self.num_token_labels), labels.view(-1)
                )

            total_loss = (start_loss + end_loss) / 2 + token_loss

        output = (start_logits, end_logits, token_logits) + outputs[2:]
        return ((total_loss,) + output) if total_loss is not None else output
src/models/roberta_crf_token.py
ADDED
@@ -0,0 +1,66 @@
import torch
from transformers import RobertaForTokenClassification
from torchcrf import CRF
from src.utils.mapper import configmapper


@configmapper.map("models", "roberta_crf_token")
class RobertaLSTMCRF(RobertaForTokenClassification):
    def __init__(self, config, lstm_hidden_size, lstm_layers):
        super().__init__(config)
        self.lstm = torch.nn.LSTM(
            input_size=config.hidden_size,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_layers,
            dropout=0.2,
            batch_first=True,
            bidirectional=True,
        )
        self.crf = CRF(config.num_labels, batch_first=True)

        del self.classifier
        self.classifier = torch.nn.Linear(2 * lstm_hidden_size, config.num_labels)

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        prediction_mask=None,
    ):

        outputs = self.roberta(
            input_ids,
            attention_mask,
            token_type_ids,
            output_hidden_states=True,
            return_dict=False,
        )
        # seq_output, all_hidden_states, all_self_attentions, all_cross_attentions

        sequence_output = outputs[0]  # outputs[1] is pooled output which is none.

        sequence_output = self.dropout(sequence_output)

        lstm_out, *_ = self.lstm(sequence_output)
        sequence_output = self.dropout(lstm_out)

        logits = self.classifier(sequence_output)

        ## CRF
        mask = prediction_mask
        mask = mask[:, : logits.size(1)].contiguous()

        # print(logits)

        if labels is not None:
            labels = labels[:, : logits.size(1)].contiguous()
            loss = -self.crf(logits, labels, mask=mask.bool(), reduction="token_mean")

        tags = self.crf.decode(logits, mask.bool())
        # print(tags)
        if labels is not None:
            return (loss, logits, tags)
        else:
            return (logits, tags)
src/models/roberta_multi_spans.py
ADDED
@@ -0,0 +1,82 @@
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss
from transformers import RobertaModel
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel
from src.utils.mapper import configmapper


@configmapper.map("models", "roberta_multi_spans")
class RobertaForMultiSpans(RobertaPreTrainedModel):
    def __init__(self, config):
        super(RobertaForMultiSpans, self).__init__(config)
        self.roberta = RobertaModel(config)
        self.num_labels = config.num_labels

        # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
        # self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=None,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)  # batch_size
        # print(start_logits.shape, end_logits.shape, start_positions.shape, end_positions.shape)

        total_loss = None
        if (
            start_positions is not None and end_positions is not None
        ):  # [batch_size/seq_length]
            # # If we are on multi-GPU, split add a dimension
            # if len(start_positions.size()) > 1:
            #     start_positions = start_positions.squeeze(-1)
            # if len(end_positions.size()) > 1:
            #     end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            # ignored_index = start_logits.size(1)
            # start_positions.clamp_(0, ignored_index)
            # end_positions.clamp_(0, ignored_index)

            # start_positions = start_logits.view()

            loss_fct = BCEWithLogitsLoss()

            start_loss = loss_fct(
                start_logits,
                start_positions.float(),
            )
            end_loss = loss_fct(
                end_logits,
                end_positions.float(),
            )
            total_loss = (start_loss + end_loss) / 2

        output = (start_logits, end_logits) + outputs[2:]
        return ((total_loss,) + output) if total_loss is not None else output
src/models/roberta_token_spans.py
ADDED
@@ -0,0 +1,97 @@
import torch.nn as nn
import torch
from torch.nn import CrossEntropyLoss
from transformers import RobertaModel
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel
from src.utils.mapper import configmapper


@configmapper.map("models", "roberta_token_spans")
class RobertaModelForTokenAndSpans(RobertaPreTrainedModel):
    def __init__(self, config, num_token_labels=2, num_qa_labels=2):
        super(RobertaModelForTokenAndSpans, self).__init__(config)
        self.roberta = RobertaModel(config)
        self.num_token_labels = num_token_labels
        self.num_qa_labels = num_qa_labels

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_token_labels)
        self.qa_outputs = nn.Linear(config.hidden_size, num_qa_labels)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        labels=None,  # Token Wise Labels
        output_attentions=None,
        output_hidden_states=None,
    ):

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=None,
        )

        sequence_output = outputs[0]

        qa_logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = qa_logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        sequence_output = self.dropout(sequence_output)
        token_logits = self.classifier(sequence_output)

        total_loss = None
        if (
            start_positions is not None
            and end_positions is not None
            and labels is not None
        ):
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)

            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)

            loss_fct = CrossEntropyLoss()
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = token_logits.view(-1, self.num_token_labels)
                active_labels = torch.where(
                    active_loss,
                    labels.view(-1),
                    torch.tensor(loss_fct.ignore_index).type_as(labels),
                )
                token_loss = loss_fct(active_logits, active_labels)
            else:
                token_loss = loss_fct(
                    token_logits.view(-1, self.num_token_labels), labels.view(-1)
                )

            total_loss = (start_loss + end_loss) / 2 + token_loss

        output = (start_logits, end_logits, token_logits) + outputs[2:]
        return ((total_loss,) + output) if total_loss is not None else output
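The two token-and-spans models combine a token-classification head with a SQuAD-style span head and add the two losses. A hypothetical construction/usage sketch, not taken from the upload: the checkpoint, label counts and input tensors are assumptions, and the extra constructor arguments are expected to be forwarded by from_pretrained to __init__.

# Hypothetical sketch; input_ids, attention_mask and the targets are assumed
# to come from the tokenization code earlier in this upload.
model = RobertaModelForTokenAndSpans.from_pretrained(
    "roberta-base", num_token_labels=2, num_qa_labels=2
)
outputs = model(
    input_ids=input_ids,              # (batch, seq_len) tensors from the tokenizer
    attention_mask=attention_mask,
    start_positions=start_positions,  # providing all three targets triggers the combined loss
    end_positions=end_positions,
    labels=token_labels,
)
total_loss, start_logits, end_logits, token_logits = outputs[:4]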
src/models/two_layer_nn.py
ADDED
@@ -0,0 +1,46 @@
"""Implements a two layer Neural Network."""

from torch.nn import Module, Linear, ReLU
from src.utils.mapper import configmapper


@configmapper.map("models", "two_layer_nn")
class TwoLayerNN(Module):
    """Implements two layer neural network.

    Methods:
        forward(x_input): Returns the output of the neural network.
    """

    def __init__(self, embedding, dims):
        """Construct the two layer Neural Network.

        This method is used to initialize the two layer neural network,
        with a given embedding type and corresponding arguments.

        Args:
            embedding (torch.nn.Module): The embedding layer for the model.
            dims (list): List of dimensions for the neural network, input to output.
        """
        super(TwoLayerNN, self).__init__()

        self.embedding = embedding
        self.linear1 = Linear(dims[0], dims[1])
        self.relu = ReLU()
        self.linear2 = Linear(dims[1], dims[2])

    def forward(self, x_input):
        """Return the output of the neural network for an input.

        Args:
            x_input (torch.Tensor): The input tensor to the neural network.

        Returns:
            x_output (torch.Tensor): The output tensor for the neural network.
        """
        output = self.embedding(x_input)
        output = self.linear1(output)
        output = self.relu(output)
        x_output = self.linear2(output)
        return x_output
src/modules/__init__.py
ADDED
File without changes
src/modules/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (166 Bytes). View file
src/modules/__pycache__/embeddings.cpython-38.pyc
ADDED
Binary file (1.67 kB). View file
src/modules/__pycache__/preprocessors.cpython-38.pyc
ADDED
Binary file (3.42 kB). View file
src/modules/__pycache__/tokenizers.cpython-38.pyc
ADDED
Binary file (4.87 kB). View file
src/modules/activations.py
ADDED
@@ -0,0 +1,6 @@
import torch.nn as nn
from src.utils.mapper import configmapper

configmapper.map("activations", "relu")(nn.ReLU)
configmapper.map("activations", "logsoftmax")(nn.LogSoftmax)
configmapper.map("activations", "softmax")(nn.Softmax)
src/modules/embeddings.py
ADDED
@@ -0,0 +1,37 @@
"""Contains various kinds of embeddings like Glove, BERT, etc."""

from torch.nn import Module, Embedding, Flatten
from src.utils.mapper import configmapper


@configmapper.map("embeddings", "glove")
class GloveEmbedding(Module):
    """Implement Glove based Word Embedding."""

    def __init__(self, embedding_matrix, padding_idx, static=True):
        """Construct GloveEmbedding.

        Args:
            embedding_matrix (torch.Tensor): The matrix containing the embedding weights.
            padding_idx (int): The padding index in the tokenizer.
            static (bool): Whether or not to freeze embeddings.
        """
        super(GloveEmbedding, self).__init__()
        self.embedding = Embedding.from_pretrained(embedding_matrix)
        self.embedding.padding_idx = padding_idx
        if static:
            self.embedding.weight.requires_grad = False
        self.flatten = Flatten(start_dim=1)

    def forward(self, x_input):
        """Pass the input through the embedding.

        Args:
            x_input (torch.Tensor): The numericalized tokenized input

        Returns:
            x_output (torch.Tensor): The output from the embedding
        """
        x_output = self.embedding(x_input)
        x_output = self.flatten(x_output)
        return x_output
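GloveEmbedding flattens the looked-up vectors, which matches the dims[0] = fix_length * embedding_dim input that TwoLayerNN (two_layer_nn.py above) expects. A self-contained sketch assuming both classes are imported from the modules above; the tiny random matrix stands in for real GloVe vectors and all sizes are made up.

# Sketch wiring GloveEmbedding into TwoLayerNN with a random embedding matrix.
import torch

vocab_size, emb_dim, fix_length = 10, 8, 4
embedding_matrix = torch.randn(vocab_size, emb_dim)
embedding = GloveEmbedding(embedding_matrix, padding_idx=0)

model = TwoLayerNN(embedding, dims=[fix_length * emb_dim, 16, 1])
token_ids = torch.randint(0, vocab_size, (3, fix_length))  # batch of 3 padded sequences
print(model(token_ids).shape)  # torch.Size([3, 1])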
src/modules/losses.py
ADDED
@@ -0,0 +1,6 @@
"""All criterion functions."""
from torch.nn import MSELoss, CrossEntropyLoss
from src.utils.mapper import configmapper

configmapper.map("losses", "mse")(MSELoss)
configmapper.map("losses", "CrossEntropyLoss")(CrossEntropyLoss)
src/modules/metrics.py
ADDED
@@ -0,0 +1,17 @@
"""Metrics."""
from sklearn.metrics import (
    mean_squared_error,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
    accuracy_score,
)
from src.utils.mapper import configmapper

configmapper.map("metrics", "sklearn_f1")(f1_score)
configmapper.map("metrics", "sklearn_p")(precision_score)
configmapper.map("metrics", "sklearn_r")(recall_score)
configmapper.map("metrics", "sklearn_roc")(roc_auc_score)
configmapper.map("metrics", "sklearn_acc")(accuracy_score)
configmapper.map("metrics", "sklearn_mse")(mean_squared_error)
src/modules/optimizers.py
ADDED
@@ -0,0 +1,7 @@
"""All optimizer functions."""
from torch.optim import Adam, AdamW, SGD
from src.utils.mapper import configmapper

configmapper.map("optimizers", "adam")(Adam)
configmapper.map("optimizers", "adam_w")(AdamW)
configmapper.map("optimizers", "sgd")(SGD)
src/modules/preprocessors.py
ADDED
@@ -0,0 +1,112 @@
from src.modules.tokenizers import *
from src.modules.embeddings import *
from src.utils.mapper import configmapper


class Preprocessor:
    def preprocess(self):
        pass


@configmapper.map("preprocessors", "glove")
class GlovePreprocessor(Preprocessor):
    """GlovePreprocessor."""

    def __init__(self, config):
        """
        Args:
            config (src.utils.module.Config): configuration for preprocessor
        """
        super(GlovePreprocessor, self).__init__()
        self.config = config
        self.tokenizer = configmapper.get_object(
            "tokenizers", self.config.main.preprocessor.tokenizer.name
        )(**self.config.main.preprocessor.tokenizer.init_params.as_dict())
        self.tokenizer_params = (
            self.config.main.preprocessor.tokenizer.init_vector_params.as_dict()
        )

        self.tokenizer.initialize_vectors(**self.tokenizer_params)
        self.embeddings = configmapper.get_object(
            "embeddings", self.config.main.preprocessor.embedding.name
        )(
            self.tokenizer.text_field.vocab.vectors,
            self.tokenizer.text_field.vocab.stoi[self.tokenizer.text_field.pad_token],
        )

    def preprocess(self, model_config, data_config):
        train_dataset = configmapper.get_object("datasets", data_config.main.name)(
            data_config.train, self.tokenizer
        )
        val_dataset = configmapper.get_object("datasets", data_config.main.name)(
            data_config.val, self.tokenizer
        )
        model = configmapper.get_object("models", model_config.name)(
            self.embeddings, **model_config.params.as_dict()
        )

        return model, train_dataset, val_dataset


@configmapper.map("preprocessors", "clozePreprocessor")
class ClozePreprocessor(Preprocessor):
    """ClozePreprocessor."""

    def __init__(self, config):
        """
        Args:
            config (src.utils.module.Config): configuration for preprocessor
        """
        super(ClozePreprocessor, self).__init__()
        self.config = config
        self.tokenizer = configmapper.get_object(
            "tokenizers", self.config.main.preprocessor.tokenizer.name
        ).from_pretrained(
            **self.config.main.preprocessor.tokenizer.init_params.as_dict()
        )

    def preprocess(self, model_config, data_config):
        train_dataset = configmapper.get_object("datasets", data_config.main.name)(
            data_config.train, self.tokenizer
        )
        val_dataset = configmapper.get_object("datasets", data_config.main.name)(
            data_config.val, self.tokenizer
        )
        model = configmapper.get_object("models", model_config.name).from_pretrained(
            **model_config.params.as_dict()
        )

        return model, train_dataset, val_dataset


@configmapper.map("preprocessors", "transformersConcretenessPreprocessor")
class TransformersConcretenessPreprocessor(Preprocessor):
    """TransformersConcretenessPreprocessor."""

    def __init__(self, config):
        """
        Args:
            config (src.utils.module.Config): configuration for preprocessor
        """
        super(TransformersConcretenessPreprocessor, self).__init__()
        self.config = config
        self.tokenizer = configmapper.get_object(
            "tokenizers", self.config.main.preprocessor.tokenizer.name
        ).from_pretrained(
            **self.config.main.preprocessor.tokenizer.init_params.as_dict()
        )

    def preprocess(self, model_config, data_config):

        train_dataset = configmapper.get_object("datasets", data_config.main.name)(
            data_config.train, self.tokenizer
        )
        val_dataset = configmapper.get_object("datasets", data_config.main.name)(
            data_config.val, self.tokenizer
        )

        model = configmapper.get_object("models", model_config.name)(
            **model_config.params.as_dict()
        )

        return model, train_dataset, val_dataset
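All three preprocessors share the same contract: build a tokenizer from the config, then return (model, train_dataset, val_dataset). A hypothetical wiring sketch; the Config objects would come from the repo's configuration loader (src.utils.configuration) and are assumptions here.

# Hypothetical sketch; `config`, `model_config` and `data_config` are assumed
# to be Config objects loaded elsewhere in the repo.
from src.utils.mapper import configmapper

preprocessor = configmapper.get_object("preprocessors", "glove")(config)
model, train_dataset, val_dataset = preprocessor.preprocess(model_config, data_config)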
src/modules/schedulers.py
ADDED
@@ -0,0 +1,14 @@
from torch.optim.lr_scheduler import (
    StepLR,
    CosineAnnealingLR,
    ReduceLROnPlateau,
    CyclicLR,
    CosineAnnealingWarmRestarts,
)
from src.utils.mapper import configmapper

configmapper.map("schedulers", "step")(StepLR)
configmapper.map("schedulers", "cosineanneal")(CosineAnnealingLR)
configmapper.map("schedulers", "reduceplateau")(ReduceLROnPlateau)
configmapper.map("schedulers", "cyclic")(CyclicLR)
configmapper.map("schedulers", "cosineannealrestart")(CosineAnnealingWarmRestarts)
src/modules/tokenizers.py
ADDED
@@ -0,0 +1,107 @@
"""Contains tokenizers like GloveTokenizers and BERT Tokenizer."""

import torch
from torchtext.vocab import GloVe  # GloVe, Field and TabularDataset are used by GloveTokenizer below
from torchtext.data import Field, TabularDataset
from src.utils.mapper import configmapper
from transformers import AutoTokenizer


class Tokenizer:
    """Abstract Class for Tokenizers."""

    def tokenize(self):
        """Abstract Method for tokenization."""


@configmapper.map("tokenizers", "glove")
class GloveTokenizer(Tokenizer):
    """Implement GloveTokenizer for tokenizing text for Glove Embeddings.

    Attributes:
        embeddings (torchtext.vocab.Vectors): Loaded pre-trained embeddings.
        text_field (torchtext.data.Field): Text_field for vector creation.

    Methods:
        __init__(self, name='840B', dim='300', cache='../embeddings/'): Constructor method
        initialize_vectors(fix_length=4, tokenize='spacy',
            file_path="../data/imperceptibility/Concreteness Ratings/train/forty.csv",
            file_format='tsv', fields=None): Initialize vocab vectors based on data.
        tokenize(x_input, **init_vector__params): Tokenize given input and return the output.
    """

    def __init__(self, name="840B", dim="300", cache="../embeddings/"):
        """Construct GloveTokenizer.

        Args:
            name (str): Name of the GloVe embedding file
            dim (str): Dimensions of the GloVe embedding file
            cache (str): Path to the embeddings directory
        """
        super(GloveTokenizer, self).__init__()
        self.embeddings = GloVe(name=name, dim=dim, cache=cache)
        self.text_field = None

    def initialize_vectors(
        self,
        fix_length=4,
        tokenize="spacy",
        tokenizer_file_paths=None,
        file_format="tsv",
        fields=None,
    ):
        """Initialize words/sequences based on GloVe embedding.

        Args:
            fields (list): The list containing the fields to be taken
                and processed from the file (see documentation for
                torchtext.data.TabularDataset)
            fix_length (int): The length of the tokenized text;
                padding or cropping is done accordingly
            tokenize (function or string): Method to tokenize the data.
                If 'spacy', uses the spacy tokenizer,
                else the specified method.
            tokenizer_file_paths (list of str): The paths of the files containing the data
            file_format (str): The format of the file: 'csv', 'tsv' or 'json'
        """
        text_field = Field(batch_first=True, fix_length=fix_length, tokenize=tokenize)
        tab_dats = [
            TabularDataset(
                i, format=file_format, fields={k: (k, text_field) for k in fields}
            )
            for i in tokenizer_file_paths
        ]
        text_field.build_vocab(*tab_dats)
        text_field.vocab.load_vectors(self.embeddings)
        self.text_field = text_field

    def tokenize(self, x_input, **init_vector__params):
        """Tokenize given input based on initialized vectors.

        Initialize the vectors with given parameters if not already initialized.

        Args:
            x_input (str): Unprocessed input text to be tokenized
            **init_vector__params (keyword arguments): Parameters to initialize vectors

        Returns:
            x_output (torch.Tensor): Processed and tokenized text
        """
        if self.text_field is None:
            self.initialize_vectors(**init_vector__params)
        try:
            x_output = torch.squeeze(
                self.text_field.process([self.text_field.preprocess(x_input)])
            )
        except Exception as e:
            print(x_input)
            print(self.text_field.preprocess(x_input))
            print(e)
        return x_output


@configmapper.map("tokenizers", "AutoTokenizer")
class AutoTokenizer(AutoTokenizer):
    def __init__(self, *args):
        super(AutoTokenizer, self).__init__()
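A hypothetical GloveTokenizer flow, assuming a TSV file with a "text" column; the path and field name are made up, and the GloVe vectors are downloaded into cache by the constructor.

# Hypothetical sketch; "../data/train.tsv" and the "text" field are assumptions.
tokenizer = GloveTokenizer(name="840B", dim="300", cache="../embeddings/")
tokenizer.initialize_vectors(
    fix_length=4,
    tokenize="spacy",
    tokenizer_file_paths=["../data/train.tsv"],
    file_format="tsv",
    fields=["text"],
)
token_ids = tokenizer.tokenize("an example sentence")  # tensor of length fix_length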
src/trainers/__init__.py
ADDED
File without changes
src/trainers/base_trainer.py
ADDED
@@ -0,0 +1,563 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from src.modules.optimizers import *
|
5 |
+
from src.modules.embeddings import *
|
6 |
+
from src.modules.schedulers import *
|
7 |
+
from src.modules.tokenizers import *
|
8 |
+
from src.modules.metrics import *
|
9 |
+
from src.modules.losses import *
|
10 |
+
from src.utils.misc import *
|
11 |
+
from src.utils.logger import Logger
|
12 |
+
from src.utils.mapper import configmapper
|
13 |
+
from src.utils.configuration import Config
|
14 |
+
|
15 |
+
from torch.utils.data import DataLoader
|
16 |
+
from tqdm import tqdm
|
17 |
+
|
18 |
+
|
19 |
+
@configmapper.map("trainers", "base")
|
20 |
+
class BaseTrainer:
|
21 |
+
def __init__(self, config):
|
22 |
+
self._config = config
|
23 |
+
self.metrics = {
|
24 |
+
configmapper.get_object("metrics", metric["type"]): metric["params"]
|
25 |
+
for metric in self._config.main_config.metrics
|
26 |
+
}
|
27 |
+
self.train_config = self._config.train
|
28 |
+
self.val_config = self._config.val
|
29 |
+
self.log_label = self.train_config.log.log_label
|
30 |
+
if self.train_config.log_and_val_interval is not None:
|
31 |
+
self.val_log_together = True
|
32 |
+
print("Logging with label: ", self.log_label)
|
33 |
+
|
34 |
+
def train(self, model, train_dataset, val_dataset=None, logger=None):
|
35 |
+
device = torch.device(self._config.main_config.device.name)
|
36 |
+
model.to(device)
|
37 |
+
optim_params = self.train_config.optimizer.params
|
38 |
+
if optim_params:
|
39 |
+
optimizer = configmapper.get_object(
|
40 |
+
"optimizers", self.train_config.optimizer.type
|
41 |
+
)(model.parameters(), **optim_params.as_dict())
|
42 |
+
else:
|
43 |
+
optimizer = configmapper.get_object(
|
44 |
+
"optimizers", self.train_config.optimizer.type
|
45 |
+
)(model.parameters())
|
46 |
+
|
47 |
+
if self.train_config.scheduler is not None:
|
48 |
+
scheduler_params = self.train_config.scheduler.params
|
49 |
+
if scheduler_params:
|
50 |
+
scheduler = configmapper.get_object(
|
51 |
+
"schedulers", self.train_config.scheduler.type
|
52 |
+
)(optimizer, **scheduler_params.as_dict())
|
53 |
+
else:
|
54 |
+
scheduler = configmapper.get_object(
|
55 |
+
"schedulers", self.train_config.scheduler.type
|
56 |
+
)(optimizer)
|
57 |
+
|
58 |
+
criterion_params = self.train_config.criterion.params
|
59 |
+
if criterion_params:
|
60 |
+
criterion = configmapper.get_object(
|
61 |
+
"losses", self.train_config.criterion.type
|
62 |
+
)(**criterion_params.as_dict())
|
63 |
+
else:
|
64 |
+
criterion = configmapper.get_object(
|
65 |
+
"losses", self.train_config.criterion.type
|
66 |
+
)()
|
67 |
+
if "custom_collate_fn" in dir(train_dataset):
|
68 |
+
train_loader = DataLoader(
|
69 |
+
dataset=train_dataset,
|
70 |
+
collate_fn=train_dataset.custom_collate_fn,
|
71 |
+
**self.train_config.loader_params.as_dict(),
|
72 |
+
)
|
73 |
+
else:
|
74 |
+
train_loader = DataLoader(
|
75 |
+
dataset=train_dataset, **self.train_config.loader_params.as_dict()
|
76 |
+
)
|
77 |
+
# train_logger = Logger(**self.train_config.log.logger_params.as_dict())
|
78 |
+
|
79 |
+
max_epochs = self.train_config.max_epochs
|
80 |
+
batch_size = self.train_config.loader_params.batch_size
|
81 |
+
|
82 |
+
if self.val_log_together:
|
83 |
+
val_interval = self.train_config.log_and_val_interval
|
84 |
+
log_interval = val_interval
|
85 |
+
else:
|
86 |
+
val_interval = self.train_config.val_interval
|
87 |
+
log_interval = self.train_config.log.log_interval
|
88 |
+
|
89 |
+
if logger is None:
|
90 |
+
train_logger = Logger(**self.train_config.log.logger_params.as_dict())
|
91 |
+
else:
|
92 |
+
train_logger = logger
|
93 |
+
|
94 |
+
train_log_values = self.train_config.log.values.as_dict()
|
95 |
+
|
96 |
+
best_score = (
|
97 |
+
-math.inf if self.train_config.save_on.desired == "max" else math.inf
|
98 |
+
)
|
99 |
+
save_on_score = self.train_config.save_on.score
|
100 |
+
best_step = -1
|
101 |
+
best_model = None
|
102 |
+
|
103 |
+
best_hparam_list = None
|
104 |
+
best_hparam_name_list = None
|
105 |
+
best_metrics_list = None
|
106 |
+
best_metrics_name_list = None
|
107 |
+
|
108 |
+
# print("\nTraining\n")
|
109 |
+
# print(max_steps)
|
110 |
+
|
111 |
+
global_step = 0
|
112 |
+
for epoch in range(1, max_epochs + 1):
|
113 |
+
print(
|
114 |
+
"Epoch: {}/{}, Global Step: {}".format(epoch, max_epochs, global_step)
|
115 |
+
)
|
116 |
+
train_loss = 0
|
117 |
+
val_loss = 0
|
118 |
+
|
119 |
+
if(self.train_config.label_type=='float'):
|
120 |
+
all_labels = torch.FloatTensor().to(device)
|
121 |
+
else:
|
122 |
+
all_labels = torch.LongTensor().to(device)
|
123 |
+
|
124 |
+
all_outputs = torch.Tensor().to(device)
|
125 |
+
|
126 |
+
train_scores = None
|
127 |
+
val_scores = None
|
128 |
+
|
129 |
+
pbar = tqdm(total=math.ceil(len(train_dataset) / batch_size))
|
130 |
+
pbar.set_description("Epoch " + str(epoch))
|
131 |
+
|
132 |
+
val_counter = 0
|
133 |
+
|
134 |
+
for step, batch in enumerate(train_loader):
|
135 |
+
model.train()
|
136 |
+
optimizer.zero_grad()
|
137 |
+
inputs, labels = batch
|
138 |
+
|
139 |
+
if(self.train_config.label_type=='float'): ##Specific to Float Type
|
140 |
+
labels = labels.float()
|
141 |
+
|
142 |
+
for key in inputs:
|
143 |
+
inputs[key] = inputs[key].to(device)
|
144 |
+
labels = labels.to(device)
|
145 |
+
outputs = model(inputs)
|
146 |
+
loss = criterion(torch.squeeze(outputs), labels)
|
147 |
+
loss.backward()
|
148 |
+
|
149 |
+
all_labels = torch.cat((all_labels, labels), 0)
|
150 |
+
|
151 |
+
if (self.train_config.label_type=='float'):
|
152 |
+
all_outputs = torch.cat((all_outputs, outputs), 0)
|
153 |
+
else:
|
154 |
+
all_outputs = torch.cat((all_outputs, torch.argmax(outputs, axis=1)), 0)
|
155 |
+
|
156 |
+
|
157 |
+
train_loss += loss.item()
|
158 |
+
optimizer.step()
|
159 |
+
|
160 |
+
if self.train_config.scheduler is not None:
|
161 |
+
if isinstance(scheduler, ReduceLROnPlateau):
|
162 |
+
scheduler.step(train_loss / (step + 1))
|
163 |
+
else:
|
164 |
+
scheduler.step()
|
165 |
+
|
166 |
+
# print(train_loss)
|
167 |
+
# print(step+1)
|
168 |
+
|
169 |
+
pbar.set_postfix_str(f"Train Loss: {train_loss /(step+1)}")
|
170 |
+
pbar.update(1)
|
171 |
+
|
172 |
+
global_step += 1
|
173 |
+
|
174 |
+
# Need to check if we want global_step or local_step
|
175 |
+
|
176 |
+
if val_dataset is not None and (global_step - 1) % val_interval == 0:
|
177 |
+
# print("\nEvaluating\n")
|
178 |
+
val_scores = self.val(
|
179 |
+
model,
|
180 |
+
val_dataset,
|
181 |
+
criterion,
|
182 |
+
device,
|
183 |
+
global_step,
|
184 |
+
train_logger,
|
185 |
+
train_log_values,
|
186 |
+
)
|
187 |
+
|
188 |
+
#save_flag = 0
|
189 |
+
if self.train_config.save_on is not None:
|
190 |
+
|
191 |
+
## BEST SCORES UPDATING
|
192 |
+
|
193 |
+
train_scores = self.get_scores(
|
194 |
+
train_loss,
|
195 |
+
global_step,
|
196 |
+
self.train_config.criterion.type,
|
197 |
+
all_outputs,
|
198 |
+
all_labels,
|
199 |
+
)
|
200 |
+
|
201 |
+
best_score, best_step, save_flag = self.check_best(
|
202 |
+
val_scores, save_on_score, best_score, global_step
|
203 |
+
)
|
204 |
+
|
205 |
+
store_dict = {
|
206 |
+
"model_state_dict": model.state_dict(),
|
207 |
+
"best_step": best_step,
|
208 |
+
"best_score": best_score,
|
209 |
+
"save_on_score": save_on_score,
|
210 |
+
}
|
211 |
+
|
212 |
+
path = self.train_config.save_on.best_path.format(
|
213 |
+
self.log_label
|
214 |
+
)
|
215 |
+
|
216 |
+
self.save(store_dict, path, save_flag)
|
217 |
+
|
218 |
+
if save_flag and train_log_values["hparams"] is not None:
|
219 |
+
(
|
220 |
+
best_hparam_list,
|
221 |
+
best_hparam_name_list,
|
222 |
+
best_metrics_list,
|
223 |
+
best_metrics_name_list,
|
224 |
+
) = self.update_hparams(
|
225 |
+
train_scores, val_scores, desc="best_val"
|
226 |
+
)
|
227 |
+
# pbar.close()
|
228 |
+
if (global_step - 1) % log_interval == 0:
|
229 |
+
# print("\nLogging\n")
|
230 |
+
train_loss_name = self.train_config.criterion.type
|
231 |
+
metric_list = [
|
232 |
+
metric(all_labels.cpu(), all_outputs.detach().cpu(), **self.metrics[metric])
|
233 |
+
for metric in self.metrics
|
234 |
+
]
|
235 |
+
metric_name_list = [
|
236 |
+
metric['type'] for metric in self._config.main_config.metrics
|
237 |
+
]
|
238 |
+
|
239 |
+
train_scores = self.log(
|
240 |
+
train_loss / (step + 1),
|
241 |
+
train_loss_name,
|
242 |
+
metric_list,
|
243 |
+
metric_name_list,
|
244 |
+
train_logger,
|
245 |
+
train_log_values,
|
246 |
+
global_step,
|
247 |
+
append_text=self.train_config.append_text,
|
248 |
+
)
|
249 |
+
pbar.close()
|
250 |
+
if not os.path.exists(self.train_config.checkpoint.checkpoint_dir):
|
251 |
+
os.makedirs(self.train_config.checkpoint.checkpoint_dir)
|
252 |
+
|
253 |
+
if self.train_config.save_after_epoch:
|
254 |
+
store_dict = {
|
255 |
+
"model_state_dict": model.state_dict(),
|
256 |
+
}
|
257 |
+
|
258 |
+
path = f"{self.train_config.checkpoint.checkpoint_dir}_{str(self.train_config.log.log_label)}_{str(epoch)}.pth"
|
259 |
+
|
260 |
+
self.save(store_dict, path, save_flag=1)
|
261 |
+
|
262 |
+
if epoch == max_epochs:
|
263 |
+
# print("\nEvaluating\n")
|
264 |
+
val_scores = self.val(
|
265 |
+
model,
|
266 |
+
val_dataset,
|
267 |
+
criterion,
|
268 |
+
device,
|
269 |
+
global_step,
|
270 |
+
train_logger,
|
271 |
+
train_log_values,
|
272 |
+
)
|
273 |
+
|
274 |
+
# print("\nLogging\n")
|
275 |
+
train_loss_name = self.train_config.criterion.type
|
276 |
+
metric_list = [
|
277 |
+
metric(all_labels.cpu(), all_outputs.detach().cpu(),**self.metrics[metric])
|
278 |
+
for metric in self.metrics
|
279 |
+
]
|
280 |
+
metric_name_list = [metric['type'] for metric in self._config.main_config.metrics]
|
281 |
+
|
282 |
+
train_scores = self.log(
|
283 |
+
train_loss / len(train_loader),
|
284 |
+
train_loss_name,
|
285 |
+
metric_list,
|
286 |
+
metric_name_list,
|
287 |
+
train_logger,
|
288 |
+
train_log_values,
|
289 |
+
global_step,
|
290 |
+
append_text=self.train_config.append_text,
|
291 |
+
)
|
292 |
+
|
293 |
+
if self.train_config.save_on is not None:
|
294 |
+
|
295 |
+
## BEST SCORES UPDATING
|
296 |
+
|
297 |
+
train_scores = self.get_scores(
|
298 |
+
train_loss,
|
299 |
+
len(train_loader),
|
300 |
+
self.train_config.criterion.type,
|
301 |
+
all_outputs,
|
302 |
+
all_labels,
|
303 |
+
)
|
304 |
+
|
305 |
+
best_score, best_step, save_flag = self.check_best(
|
306 |
+
val_scores, save_on_score, best_score, global_step
|
307 |
+
)
|
308 |
+
|
309 |
+
store_dict = {
|
310 |
+
"model_state_dict": model.state_dict(),
|
311 |
+
"best_step": best_step,
|
312 |
+
"best_score": best_score,
|
313 |
+
"save_on_score": save_on_score,
|
314 |
+
}
|
315 |
+
|
316 |
+
path = self.train_config.save_on.best_path.format(self.log_label)
|
317 |
+
|
318 |
+
self.save(store_dict, path, save_flag)
|
319 |
+
|
320 |
+
if save_flag and train_log_values["hparams"] is not None:
|
321 |
+
(
|
322 |
+
best_hparam_list,
|
323 |
+
best_hparam_name_list,
|
324 |
+
best_metrics_list,
|
325 |
+
best_metrics_name_list,
|
326 |
+
) = self.update_hparams(train_scores, val_scores, desc="best_val")
|
327 |
+
|
328 |
+
## FINAL SCORES UPDATING + STORING
|
329 |
+
train_scores = self.get_scores(
|
330 |
+
train_loss,
|
331 |
+
len(train_loader),
|
332 |
+
self.train_config.criterion.type,
|
333 |
+
all_outputs,
|
334 |
+
all_labels,
|
335 |
+
)
|
336 |
+
|
337 |
+
store_dict = {
|
338 |
+
"model_state_dict": model.state_dict(),
|
339 |
+
"final_step": global_step,
|
340 |
+
"final_score": train_scores[save_on_score],
|
341 |
+
"save_on_score": save_on_score,
|
342 |
+
}
|
343 |
+
|
344 |
+
path = self.train_config.save_on.final_path.format(self.log_label)
|
345 |
+
|
346 |
+
self.save(store_dict, path, save_flag=1)
|
347 |
+
if train_log_values["hparams"] is not None:
|
348 |
+
(
|
349 |
+
final_hparam_list,
|
350 |
+
final_hparam_name_list,
|
351 |
+
final_metrics_list,
|
352 |
+
final_metrics_name_list,
|
353 |
+
) = self.update_hparams(train_scores, val_scores, desc="final")
|
354 |
+
train_logger.save_hyperparams(
|
355 |
+
best_hparam_list,
|
356 |
+
best_hparam_name_list,
|
357 |
+
[int(self.log_label),] + best_metrics_list + final_metrics_list,
|
358 |
+
["hparams/log_label",]
|
359 |
+
+ best_metrics_name_list
|
360 |
+
+ final_metrics_name_list,
|
361 |
+
)
|
362 |
+
#
|
363 |
+
|
364 |
+
## Need to check if we want same loggers of different loggers for train and eval
|
365 |
+
## Evaluate
|
366 |
+
|
367 |
+
def get_scores(self, loss, divisor, loss_name, all_outputs, all_labels):
|
368 |
+
|
369 |
+
avg_loss = loss / divisor
|
370 |
+
|
371 |
+
metric_list = [
|
372 |
+
metric(all_labels.cpu(), all_outputs.detach().cpu(), **self.metrics[metric])
|
373 |
+
for metric in self.metrics
|
374 |
+
]
|
375 |
+
metric_name_list = [metric['type'] for metric in self._config.main_config.metrics]
|
376 |
+
|
377 |
+
return dict(zip([loss_name,] + metric_name_list, [avg_loss,] + metric_list,))
|
378 |
+
|
379 |
+
def check_best(self, val_scores, save_on_score, best_score, global_step):
|
380 |
+
save_flag = 0
|
381 |
+
best_step = global_step
|
382 |
+
if self.train_config.save_on.desired == "min":
|
383 |
+
if val_scores[save_on_score] < best_score:
|
384 |
+
save_flag = 1
|
385 |
+
best_score = val_scores[save_on_score]
|
386 |
+
best_step = global_step
|
387 |
+
else:
|
388 |
+
if val_scores[save_on_score] > best_score:
|
389 |
+
save_flag = 1
|
390 |
+
best_score = val_scores[save_on_score]
|
391 |
+
best_step = global_step
|
392 |
+
return best_score, best_step, save_flag
|
393 |
+
|
394 |
+
def update_hparams(self, train_scores, val_scores, desc):
|
395 |
+
hparam_list = []
|
396 |
+
hparam_name_list = []
|
397 |
+
for hparam in self.train_config.log.values.hparams:
|
398 |
+
hparam_list.append(get_item_in_config(self._config, hparam["path"]))
|
399 |
+
if isinstance(hparam_list[-1], Config):
|
400 |
+
hparam_list[-1] = hparam_list[-1].as_dict()
|
401 |
+
hparam_name_list.append(hparam["name"])
|
402 |
+
|
403 |
+
val_keys, val_values = zip(*val_scores.items())
|
404 |
+
train_keys, train_values = zip(*train_scores.items())
|
405 |
+
val_keys = list(val_keys)
|
406 |
+
train_keys = list(train_keys)
|
407 |
+
val_values = list(val_values)
|
408 |
+
train_values = list(train_values)
|
409 |
+
for i, key in enumerate(val_keys):
|
410 |
+
val_keys[i] = f"hparams/{desc}_val_" + val_keys[i]
|
411 |
+
for i, key in enumerate(train_keys):
|
412 |
+
train_keys[i] = f"hparams/{desc}_train_" + train_keys[i]
|
413 |
+
# train_logger.save_hyperparams(hparam_list, hparam_name_list,train_values+val_values,train_keys+val_keys, )
|
414 |
+
return (
|
415 |
+
hparam_list,
|
416 |
+
hparam_name_list,
|
417 |
+
train_values + val_values,
|
418 |
+
train_keys + val_keys,
|
419 |
+
)
|
420 |
+
|
421 |
+
def save(self, store_dict, path, save_flag=0):
|
422 |
+
if save_flag:
|
423 |
+
dirs = "/".join(path.split("/")[:-1])
|
424 |
+
if not os.path.exists(dirs):
|
425 |
+
os.makedirs(dirs)
|
426 |
+
torch.save(store_dict, path)
|
427 |
+
|
    def log(
        self,
        loss,
        loss_name,
        metric_list,
        metric_name_list,
        logger,
        log_values,
        global_step,
        append_text,
    ):

        return_dic = dict(zip([loss_name,] + metric_name_list, [loss,] + metric_list,))

        loss_name = f"{append_text}_{self.log_label}_{loss_name}"
        if log_values["loss"]:
            logger.save_params(
                [loss],
                [loss_name],
                combine=True,
                combine_name="losses",
                global_step=global_step,
            )

        for i in range(len(metric_name_list)):
            metric_name_list[
                i
            ] = f"{append_text}_{self.log_label}_{metric_name_list[i]}"
        if log_values["metrics"]:
            logger.save_params(
                metric_list,
                metric_name_list,
                combine=True,
                combine_name="metrics",
                global_step=global_step,
            )
        # print(hparams_list)
        # print(hparam_name_list)

        # for k,v in dict(zip([loss_name],[loss])).items():
        #     print(f"{k}:{v}")
        # for k,v in dict(zip(metric_name_list,metric_list)).items():
        #     print(f"{k}:{v}")
        return return_dic

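    # Runs a full pass over the validation set with gradients disabled,
    # accumulates outputs and labels, and reports the averaged loss plus the
    # configured metrics, optionally logging them through self.log.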
    def val(
        self,
        model,
        dataset,
        criterion,
        device,
        global_step,
        train_logger=None,
        train_log_values=None,
        log=True,
    ):
        append_text = self.val_config.append_text
        if train_logger is not None:
            val_logger = train_logger
        else:
            val_logger = Logger(**self.val_config.log.logger_params.as_dict())

        if train_log_values is not None:
            val_log_values = train_log_values
        else:
            val_log_values = self.val_config.log.values.as_dict()
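        # Use the dataset's own collate function when it provides one
        # (e.g. for padding variable-length token sequences).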
        if "custom_collate_fn" in dir(dataset):
            val_loader = DataLoader(
                dataset=dataset,
                collate_fn=dataset.custom_collate_fn,
                **self.val_config.loader_params.as_dict(),
            )
        else:
            val_loader = DataLoader(
                dataset=dataset, **self.val_config.loader_params.as_dict()
            )

        all_outputs = torch.Tensor().to(device)
        if self.train_config.label_type == 'float':
            all_labels = torch.FloatTensor().to(device)
        else:
            all_labels = torch.LongTensor().to(device)

        batch_size = self.val_config.loader_params.batch_size

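        # Evaluation loop: no gradient tracking, model in eval mode; predictions
        # are concatenated across batches (raw outputs for float labels,
        # argmax over the class dimension otherwise).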
        with torch.no_grad():
            model.eval()
            val_loss = 0
            for j, batch in enumerate(val_loader):

                inputs, labels = batch

                if self.train_config.label_type == 'float':
                    labels = labels.float()

                for key in inputs:
                    inputs[key] = inputs[key].to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                loss = criterion(torch.squeeze(outputs), labels)
                val_loss += loss.item()

                all_labels = torch.cat((all_labels, labels), 0)

                if self.train_config.label_type == 'float':
                    all_outputs = torch.cat((all_outputs, outputs), 0)
                else:
                    all_outputs = torch.cat((all_outputs, torch.argmax(outputs, axis=1)), 0)

            val_loss = val_loss / len(val_loader)

            val_loss_name = self.train_config.criterion.type

            # print(all_outputs, all_labels)
            metric_list = [
                metric(all_labels.cpu(), all_outputs.detach().cpu(), **self.metrics[metric])
                for metric in self.metrics
            ]
            metric_name_list = [metric['type'] for metric in self._config.main_config.metrics]
            return_dic = dict(
                zip([val_loss_name,] + metric_name_list, [val_loss,] + metric_list,)
            )
            if log:
                val_scores = self.log(
                    val_loss,
                    val_loss_name,
                    metric_list,
                    metric_name_list,
                    val_logger,
                    val_log_values,
                    global_step,
                    append_text,
                )
                return val_scores
            return return_dic

src/utils/__init__.py ADDED
File without changes

src/utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (164 Bytes)