This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. src/datasets/__init__.py +7 -0
  2. src/datasets/__pycache__/__init__.cpython-38.pyc +0 -0
  3. src/datasets/__pycache__/toxic_spans_crf_3cls_tokens.cpython-38.pyc +0 -0
  4. src/datasets/__pycache__/toxic_spans_crf_tokens.cpython-38.pyc +0 -0
  5. src/datasets/__pycache__/toxic_spans_multi_spans.cpython-38.pyc +0 -0
  6. src/datasets/__pycache__/toxic_spans_spans.cpython-38.pyc +0 -0
  7. src/datasets/__pycache__/toxic_spans_tokens.cpython-38.pyc +0 -0
  8. src/datasets/__pycache__/toxic_spans_tokens_3cls.cpython-38.pyc +0 -0
  9. src/datasets/__pycache__/toxic_spans_tokens_spans.cpython-38.pyc +0 -0
  10. src/datasets/toxic_spans_crf_3cls_tokens.py +132 -0
  11. src/datasets/toxic_spans_crf_tokens.py +111 -0
  12. src/datasets/toxic_spans_multi_spans.py +237 -0
  13. src/datasets/toxic_spans_spans.py +238 -0
  14. src/datasets/toxic_spans_tokens.py +81 -0
  15. src/datasets/toxic_spans_tokens_3cls.py +102 -0
  16. src/datasets/toxic_spans_tokens_spans.py +269 -0
  17. src/models/__init__.py +7 -0
  18. src/models/__pycache__/__init__.cpython-38.pyc +0 -0
  19. src/models/__pycache__/auto_models.cpython-38.pyc +0 -0
  20. src/models/__pycache__/bert_crf_token.cpython-38.pyc +0 -0
  21. src/models/__pycache__/bert_multi_spans.cpython-38.pyc +0 -0
  22. src/models/__pycache__/bert_token_spans.cpython-38.pyc +0 -0
  23. src/models/__pycache__/roberta_crf_token.cpython-38.pyc +0 -0
  24. src/models/__pycache__/roberta_multi_spans.cpython-38.pyc +0 -0
  25. src/models/__pycache__/roberta_token_spans.cpython-38.pyc +0 -0
  26. src/models/auto_models.py +6 -0
  27. src/models/bert_crf_token.py +72 -0
  28. src/models/bert_multi_spans.py +84 -0
  29. src/models/bert_token_spans.py +100 -0
  30. src/models/roberta_crf_token.py +66 -0
  31. src/models/roberta_multi_spans.py +82 -0
  32. src/models/roberta_token_spans.py +97 -0
  33. src/models/two_layer_nn.py +46 -0
  34. src/modules/__init__.py +0 -0
  35. src/modules/__pycache__/__init__.cpython-38.pyc +0 -0
  36. src/modules/__pycache__/embeddings.cpython-38.pyc +0 -0
  37. src/modules/__pycache__/preprocessors.cpython-38.pyc +0 -0
  38. src/modules/__pycache__/tokenizers.cpython-38.pyc +0 -0
  39. src/modules/activations.py +6 -0
  40. src/modules/embeddings.py +37 -0
  41. src/modules/losses.py +6 -0
  42. src/modules/metrics.py +17 -0
  43. src/modules/optimizers.py +7 -0
  44. src/modules/preprocessors.py +112 -0
  45. src/modules/schedulers.py +14 -0
  46. src/modules/tokenizers.py +107 -0
  47. src/trainers/__init__.py +0 -0
  48. src/trainers/base_trainer.py +563 -0
  49. src/utils/__init__.py +0 -0
  50. src/utils/__pycache__/__init__.cpython-38.pyc +0 -0
src/datasets/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from src.datasets.toxic_spans_tokens import *
+ from src.datasets.toxic_spans_tokens_3cls import *
+ from src.datasets.toxic_spans_spans import *
+ from src.datasets.toxic_spans_tokens_spans import *
+ from src.datasets.toxic_spans_multi_spans import *
+ from src.datasets.toxic_spans_crf_tokens import *
+ from src.datasets.toxic_spans_crf_3cls_tokens import *
src/datasets/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (503 Bytes).
src/datasets/__pycache__/toxic_spans_crf_3cls_tokens.cpython-38.pyc ADDED
Binary file (2.99 kB).
src/datasets/__pycache__/toxic_spans_crf_tokens.cpython-38.pyc ADDED
Binary file (2.76 kB).
src/datasets/__pycache__/toxic_spans_multi_spans.cpython-38.pyc ADDED
Binary file (5.55 kB).
src/datasets/__pycache__/toxic_spans_spans.cpython-38.pyc ADDED
Binary file (5.2 kB).
src/datasets/__pycache__/toxic_spans_tokens.cpython-38.pyc ADDED
Binary file (2.35 kB).
src/datasets/__pycache__/toxic_spans_tokens_3cls.cpython-38.pyc ADDED
Binary file (2.59 kB).
src/datasets/__pycache__/toxic_spans_tokens_spans.cpython-38.pyc ADDED
Binary file (5.97 kB).
src/datasets/toxic_spans_crf_3cls_tokens.py ADDED
@@ -0,0 +1,132 @@
+ from src.utils.mapper import configmapper
+ from transformers import AutoTokenizer
+ from datasets import load_dataset
+ import numpy as np
+
+
+ @configmapper.map("datasets", "toxic_spans_crf_3cls_tokens")
+ class ToxicSpansCRF3ClsTokenDataset:
+     def __init__(self, config):
+         self.config = config
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             self.config.model_checkpoint_name
+         )
+         self.dataset = load_dataset("csv", data_files=dict(self.config.train_files))
+         self.test_dataset = load_dataset("csv", data_files=dict(self.config.eval_files))
+
+         self.tokenized_inputs = self.dataset.map(
+             self.tokenize_and_align_labels_for_train, batched=True
+         )
+         self.test_tokenized_inputs = self.test_dataset.map(
+             self.tokenize_for_test, batched=True
+         )
+
+     def tokenize_and_align_labels_for_train(self, examples):
+         tokenized_inputs = self.tokenizer(
+             examples["text"], **self.config.tokenizer_params
+         )
+
+         # tokenized_inputs["text"] = examples["text"]
+         example_spans = []
+         labels = []
+         prediction_mask = np.zeros_like(np.array(tokenized_inputs["input_ids"]))
+         offsets_mapping = tokenized_inputs["offset_mapping"]
+
+         ## Wrong Code
+         # for i, offset_mapping in enumerate(offsets_mapping):
+         #     j = 0
+         #     while j < len(offset_mapping):  # [tok1, tok2, tok3] [(0,5),(1,4),(5,7)]
+         #         if tokenized_inputs["input_ids"][i][j] in [
+         #             self.tokenizer.sep_token_id,
+         #             self.tokenizer.pad_token_id,
+         #             self.tokenizer.cls_token_id,
+         #         ]:
+         #             j = j + 1
+         #             continue
+         #         else:
+         #             k = j + 1
+         #             while self.tokenizer.convert_ids_to_tokens(
+         #                 tokenized_inputs["input_ids"][i][k]
+         #             ).startswith("##"):
+         #                 offset_mapping[i][j][1] = offset_mapping[i][k][1]
+         #             j = k
+
+         for i, offset_mapping in enumerate(offsets_mapping):
+             labels.append([])
+
+             spans = eval(examples["spans"][i])
+             Bs = eval(examples["Bs"][i])
+             Is = eval(examples["Is"][i])
+
+             example_spans.append(spans)
+             # cls_label = 2  ## DUMMY LABEL
+             cls_label = 3  ## DUMMY LABEL
+             for j, offsets in enumerate(offset_mapping):
+                 if tokenized_inputs["input_ids"][i][j] in [
+                     self.tokenizer.sep_token_id,
+                     self.tokenizer.pad_token_id,
+                 ]:
+                     tokenized_inputs["attention_mask"][i][j] = 0
+
+                 if tokenized_inputs["input_ids"][i][j] == self.tokenizer.cls_token_id:
+                     labels[-1].append(cls_label)
+                     prediction_mask[i][j] = 1
+
+                 elif offsets[0] == offsets[1] and offsets[0] == 0:
+                     # labels[-1].append(2)  ## DUMMY
+                     labels[-1].append(cls_label)  ## DUMMY
+
+                 else:
+                     # toxic_offsets = [x in spans for x in range(offsets[0], offsets[1])]
+                     # ## If any part of the token is in span, mark it as Toxic
+                     # if (
+                     #     len(toxic_offsets) > 0
+                     #     and sum(toxic_offsets) / len(toxic_offsets) > 0.0
+                     # ):
+                     #     labels[-1].append(1)
+                     # else:
+                     #     labels[-1].append(0)
+                     # prediction_mask[i][j] = 1
+
+                     b_off = [x in Bs for x in range(offsets[0], offsets[1])]
+                     b_off = sum(b_off)
+                     i_off = [x in Is for x in range(offsets[0], offsets[1])]
+                     i_off = sum(i_off)
+                     # if len(b_off) == len(i_off) and len(i_off) == 0:
+                     if b_off == 0 and i_off == 0:
+                         labels[-1].append(0)
+                     # elif len(b_off) >= len(i_off) == 1:
+                     elif b_off >= i_off:
+                         labels[-1].append(1)
+                         # print(b_off)
+                         # print(i_off)
+                         # print(j)
+                     else:
+                         labels[-1].append(2)
+
+         tokenized_inputs["labels"] = labels
+         tokenized_inputs["prediction_mask"] = prediction_mask
+         return tokenized_inputs
+
+     def tokenize_for_test(self, examples):
+         tokenized_inputs = self.tokenizer(
+             examples["text"], **self.config.tokenizer_params
+         )
+         prediction_mask = np.zeros_like(np.array(tokenized_inputs["input_ids"]))
+         labels = np.zeros_like(np.array(tokenized_inputs["input_ids"]))
+
+         offsets_mapping = tokenized_inputs["offset_mapping"]
+
+         for i, offset_mapping in enumerate(offsets_mapping):
+             for j, offsets in enumerate(offset_mapping):
+                 if tokenized_inputs["input_ids"][i][j] in [
+                     self.tokenizer.sep_token_id,
+                     self.tokenizer.pad_token_id,
+                 ]:
+                     tokenized_inputs["attention_mask"][i][j] = 0
+                 else:
+                     prediction_mask[i][j] = 1
+
+         tokenized_inputs["prediction_mask"] = prediction_mask
+         tokenized_inputs["labels"] = labels
+         return tokenized_inputs
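
The labelling loop above hinges on the fast tokenizer's `offset_mapping` to project the character-level `Bs`/`Is` annotations onto wordpiece tokens. A minimal sketch of that projection follows; the sentence, hand-written offsets, and B/I character sets are illustrative stand-ins (in the dataset class they come from the tokenizer and the CSV columns), not values from this PR.

```python
# Toy projection of character-level B/I annotations onto tokens, mirroring the
# labelling branch above. offset_mapping holds (start_char, end_char) per token;
# (0, 0) marks special tokens such as [CLS]/[SEP].
# Offsets correspond to the text "you are such an idiots", with "idiots" split
# into two pieces (16, 18) and (18, 22).
offset_mapping = [(0, 0), (0, 3), (4, 7), (8, 12), (13, 15), (16, 18), (18, 22), (0, 0)]
Bs = {16}                # assumed: first character of the toxic span
Is = set(range(17, 22))  # assumed: remaining characters of the span

labels = []
for start, end in offset_mapping:
    if start == end == 0:
        labels.append(3)  # dummy label for special tokens, as in the class above
        continue
    b_hits = sum(c in Bs for c in range(start, end))
    i_hits = sum(c in Is for c in range(start, end))
    if b_hits == 0 and i_hits == 0:
        labels.append(0)  # outside any toxic span
    elif b_hits >= i_hits:
        labels.append(1)  # token begins a span
    else:
        labels.append(2)  # token continues a span

print(labels)  # [3, 0, 0, 0, 0, 1, 2, 3]
```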
src/datasets/toxic_spans_crf_tokens.py ADDED
@@ -0,0 +1,111 @@
+ from src.utils.mapper import configmapper
+ from transformers import AutoTokenizer
+ from datasets import load_dataset
+ import numpy as np
+
+
+ @configmapper.map("datasets", "toxic_spans_crf_tokens")
+ class ToxicSpansCRFTokenDataset:
+     def __init__(self, config):
+         self.config = config
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             self.config.model_checkpoint_name
+         )
+         self.dataset = load_dataset("csv", data_files=dict(self.config.train_files))
+         self.test_dataset = load_dataset("csv", data_files=dict(self.config.eval_files))
+
+         self.tokenized_inputs = self.dataset.map(
+             self.tokenize_and_align_labels_for_train, batched=True
+         )
+         self.test_tokenized_inputs = self.test_dataset.map(
+             self.tokenize_for_test, batched=True
+         )
+
+     def tokenize_and_align_labels_for_train(self, examples):
+         tokenized_inputs = self.tokenizer(
+             examples["text"], **self.config.tokenizer_params
+         )
+
+         # tokenized_inputs["text"] = examples["text"]
+         example_spans = []
+         labels = []
+         prediction_mask = np.zeros_like(np.array(tokenized_inputs["input_ids"]))
+         offsets_mapping = tokenized_inputs["offset_mapping"]
+
+         ## Wrong Code
+         # for i, offset_mapping in enumerate(offsets_mapping):
+         #     j = 0
+         #     while j < len(offset_mapping):  # [tok1, tok2, tok3] [(0,5),(1,4),(5,7)]
+         #         if tokenized_inputs["input_ids"][i][j] in [
+         #             self.tokenizer.sep_token_id,
+         #             self.tokenizer.pad_token_id,
+         #             self.tokenizer.cls_token_id,
+         #         ]:
+         #             j = j + 1
+         #             continue
+         #         else:
+         #             k = j + 1
+         #             while self.tokenizer.convert_ids_to_tokens(
+         #                 tokenized_inputs["input_ids"][i][k]
+         #             ).startswith("##"):
+         #                 offset_mapping[i][j][1] = offset_mapping[i][k][1]
+         #             j = k
+
+         for i, offset_mapping in enumerate(offsets_mapping):
+             labels.append([])
+
+             spans = eval(examples["spans"][i])
+             example_spans.append(spans)
+             cls_label = 2  ## DUMMY LABEL
+             for j, offsets in enumerate(offset_mapping):
+                 if tokenized_inputs["input_ids"][i][j] in [
+                     self.tokenizer.sep_token_id,
+                     self.tokenizer.pad_token_id,
+                 ]:
+                     tokenized_inputs["attention_mask"][i][j] = 0
+
+                 if tokenized_inputs["input_ids"][i][j] == self.tokenizer.cls_token_id:
+                     labels[-1].append(cls_label)
+                     prediction_mask[i][j] = 1
+
+                 elif offsets[0] == offsets[1] and offsets[0] == 0:
+                     labels[-1].append(2)  ## DUMMY
+
+                 else:
+                     toxic_offsets = [x in spans for x in range(offsets[0], offsets[1])]
+                     ## If any part of the token is in span, mark it as Toxic
+                     if (
+                         len(toxic_offsets) > 0
+                         and sum(toxic_offsets) / len(toxic_offsets) > 0.0
+                     ):
+                         labels[-1].append(1)
+                     else:
+                         labels[-1].append(0)
+                     prediction_mask[i][j] = 1
+
+         tokenized_inputs["labels"] = labels
+         tokenized_inputs["prediction_mask"] = prediction_mask
+         return tokenized_inputs
+
+     def tokenize_for_test(self, examples):
+         tokenized_inputs = self.tokenizer(
+             examples["text"], **self.config.tokenizer_params
+         )
+         prediction_mask = np.zeros_like(np.array(tokenized_inputs["input_ids"]))
+         labels = np.zeros_like(np.array(tokenized_inputs["input_ids"]))
+
+         offsets_mapping = tokenized_inputs["offset_mapping"]
+
+         for i, offset_mapping in enumerate(offsets_mapping):
+             for j, offsets in enumerate(offset_mapping):
+                 if tokenized_inputs["input_ids"][i][j] in [
+                     self.tokenizer.sep_token_id,
+                     self.tokenizer.pad_token_id,
+                 ]:
+                     tokenized_inputs["attention_mask"][i][j] = 0
+                 else:
+                     prediction_mask[i][j] = 1
+
+         tokenized_inputs["prediction_mask"] = prediction_mask
+         tokenized_inputs["labels"] = labels
+         return tokenized_inputs
src/datasets/toxic_spans_multi_spans.py ADDED
@@ -0,0 +1,237 @@
+ from src.utils.mapper import configmapper
+ from transformers import AutoTokenizer
+ import pandas as pd
+ from datasets import load_dataset, Dataset
+ from evaluation.fix_spans import _contiguous_ranges
+
+
+ @configmapper.map("datasets", "toxic_spans_multi_spans")
+ class ToxicSpansMultiSpansDataset:
+     def __init__(self, config):
+         self.config = config
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             self.config.model_checkpoint_name
+         )
+
+         self.dataset = load_dataset("csv", data_files=dict(self.config.train_files))
+         self.test_dataset = load_dataset("csv", data_files=dict(self.config.eval_files))
+
+         temp_key_train = list(self.dataset.keys())[0]
+         self.intermediate_dataset = self.dataset.map(
+             self.create_train_features,
+             batched=True,
+             batch_size=1000000,  ## Unusually Large Batch Size ## Needed For Correct ID mapping
+             remove_columns=self.dataset[temp_key_train].column_names,
+         )
+
+         temp_key_test = list(self.test_dataset.keys())[0]
+         self.intermediate_test_dataset = self.test_dataset.map(
+             self.create_test_features,
+             batched=True,
+             batch_size=1000000,  ## Unusually Large Batch Size ## Needed For Correct ID mapping
+             remove_columns=self.test_dataset[temp_key_test].column_names,
+         )
+
+         self.tokenized_inputs = self.intermediate_dataset.map(
+             self.prepare_train_features,
+             batched=True,
+             remove_columns=self.intermediate_dataset[temp_key_train].column_names,
+         )
+         self.test_tokenized_inputs = self.intermediate_test_dataset.map(
+             self.prepare_test_features,
+             batched=True,
+             remove_columns=self.intermediate_test_dataset[temp_key_test].column_names,
+         )
+
+     def create_train_features(self, examples):
+         features = {
+             "context": [],
+             "id": [],
+             "question": [],
+             "title": [],
+             "start_positions": [],
+             "end_positions": [],
+         }
+         id = 0
+         # print(examples)
+         for row_number in range(len(examples["text"])):
+             context = examples["text"][row_number]
+             question = "offense"
+             title = context.split(" ")[0]
+             start_positions = []
+             end_positions = []
+             span = eval(examples["spans"][row_number])
+             contiguous_spans = _contiguous_ranges(span)
+             for lst in contiguous_spans:
+                 lst = list(lst)
+                 dict_to_write = {}
+
+                 start_positions.append(lst[0])
+                 end_positions.append(lst[1])
+
+             features["context"].append(context)
+             features["id"].append(str(id))
+             features["question"].append(question)
+             features["title"].append(title)
+             features["start_positions"].append(start_positions)
+             features["end_positions"].append(end_positions)
+             id += 1
+
+         return features
+
+     def create_test_features(self, examples):
+         features = {"context": [], "id": [], "question": [], "title": []}
+         id = 0
+         for row_number in range(len(examples["text"])):
+             context = examples["text"][row_number]
+             question = "offense"
+             title = context.split(" ")[0]
+             features["context"].append(context)
+             features["id"].append(str(id))
+             features["question"].append(question)
+             features["title"].append(title)
+             id += 1
+         return features
+
+     def prepare_train_features(self, examples):
+         """Generate tokenized features from examples.
+
+         Args:
+             examples (dict): The examples to be tokenized.
+
+         Returns:
+             transformers.tokenization_utils_base.BatchEncoding:
+                 The tokenized features/examples after processing.
+         """
+         # Tokenize our examples with truncation and padding, but keep the
+         # overflows using a stride. This results in one example possibly
+         # giving several features when a context is long, each of those
+         # features having a context that overlaps a bit the context
+         # of the previous feature.
+         pad_on_right = self.tokenizer.padding_side == "right"
+         print("### Batch Tokenizing Examples ###")
+         tokenized_examples = self.tokenizer(
+             examples["question" if pad_on_right else "context"],
+             examples["context" if pad_on_right else "question"],
+             **dict(self.config.tokenizer_params),
+         )
+
+         # Since one example might give us several features if it has
+         # a long context, we need a map from a feature to
+         # its corresponding example. This key gives us just that.
+         sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
+         # The offset mappings will give us a map from token to
+         # character position in the original context. This will
+         # help us compute the start_positions and end_positions.
+         offset_mapping = tokenized_examples.pop("offset_mapping")
+
+         # Let's label those examples!
+         tokenized_examples["start_positions"] = []
+         tokenized_examples["end_positions"] = []
+
+         for i, offsets in enumerate(offset_mapping):
+             # We will label impossible answers with the index of the CLS token.
+             input_ids = tokenized_examples["input_ids"][i]
+
+             # Grab the sequence corresponding to that example
+             # (to know what is the context and what is the question).
+             sequence_ids = tokenized_examples.sequence_ids(i)
+
+             # One example can give several spans, this is the index of
+             # the example containing this span of text.
+             sample_index = sample_mapping[i]
+             start_positions = examples["start_positions"][sample_index]
+             end_positions = examples["end_positions"][sample_index]
+
+             start_positions_token_wise = [0 for x in range(len(input_ids))]
+             end_positions_token_wise = [0 for x in range(len(input_ids))]
+             # If no answers are given, set the cls_index as answer.
+             if len(start_positions) != 0:
+                 for position in range(len(start_positions)):
+                     start_char = start_positions[position]
+                     end_char = end_positions[position] + 1
+
+                     # Start token index of the current span in the text.
+                     token_start_index = 0
+                     while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
+                         token_start_index += 1
+
+                     # End token index of the current span in the text.
+                     token_end_index = len(input_ids) - 1
+                     while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
+                         token_end_index -= 1
+
+                     # Detect if the answer is out of the span (in which case we continue).
+                     if not (
+                         offsets[token_start_index][0] <= start_char
+                         and offsets[token_end_index][1] >= end_char
+                     ):
+                         continue
+                     else:
+                         # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
+                         # Note: we could go after the last offset if the answer is the last word (edge case).
+                         while (
+                             token_start_index < len(offsets)
+                             and offsets[token_start_index][0] <= start_char
+                         ):
+                             token_start_index += 1
+                         start_positions_token_wise[token_start_index - 1] = 1
+                         while offsets[token_end_index][1] >= end_char:
+                             token_end_index -= 1
+                         end_positions_token_wise[token_end_index + 1] = 1
+             tokenized_examples["start_positions"].append(start_positions_token_wise)
+             tokenized_examples["end_positions"].append(end_positions_token_wise)
+         return tokenized_examples
+
+     def prepare_test_features(self, examples):
+
+         """Generate tokenized validation features from examples.
+
+         Args:
+             examples (dict): The validation examples to be tokenized.
+
+         Returns:
+             transformers.tokenization_utils_base.BatchEncoding:
+                 The tokenized features/examples for validation set after processing.
+         """
+
+         # Tokenize our examples with truncation and maybe
+         # padding, but keep the overflows using a stride.
+         # This results in one example possibly giving several features
+         # when a context is long, each of those features having a
+         # context that overlaps a bit the context of the previous feature.
+         print("### Tokenizing Validation Examples")
+         pad_on_right = self.tokenizer.padding_side == "right"
+         tokenized_examples = self.tokenizer(
+             examples["question" if pad_on_right else "context"],
+             examples["context" if pad_on_right else "question"],
+             **dict(self.config.tokenizer_params),
+         )
+
+         # Since one example might give us several features if it has a long context,
+         # we need a map from a feature to its corresponding example. This key gives us just that.
+         sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
+
+         # We keep the example_id that gave us this feature and we will store the offset mappings.
+         tokenized_examples["example_id"] = []
+
+         for i in range(len(tokenized_examples["input_ids"])):
+             # Grab the sequence corresponding to that example
+             # (to know what is the context and what is the question).
+             sequence_ids = tokenized_examples.sequence_ids(i)
+             context_index = 1 if pad_on_right else 0
+
+             # One example can give several spans,
+             # this is the index of the example containing this span of text.
+             sample_index = sample_mapping[i]
+             tokenized_examples["example_id"].append(str(examples["id"][sample_index]))
+
+             # Set to None the offset_mapping that are not part
+             # of the context so it's easy to determine if a token
+             # position is part of the context or not.
+             tokenized_examples["offset_mapping"][i] = [
+                 (o if sequence_ids[k] == context_index else None)
+                 for k, o in enumerate(tokenized_examples["offset_mapping"][i])
+             ]
+
+         return tokenized_examples
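
`_contiguous_ranges` is imported from `evaluation/fix_spans.py`, which is outside this 50-file view. It is assumed to collapse a sorted list of toxic character offsets into inclusive `(start, end)` pairs, which is what `create_train_features` relies on. An equivalent reference helper, for review purposes only:

```python
# Sketch of the assumed behaviour of evaluation.fix_spans._contiguous_ranges:
# group sorted character offsets into inclusive contiguous ranges.
from itertools import groupby


def contiguous_ranges(span):
    output = []
    # Offsets in the same contiguous run share the same (value - index) key.
    for _, group in groupby(enumerate(sorted(span)), lambda pair: pair[1] - pair[0]):
        group = list(group)
        output.append((group[0][1], group[-1][1]))
    return output


assert contiguous_ranges([3, 4, 5, 9, 10]) == [(3, 5), (9, 10)]
```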
src/datasets/toxic_spans_spans.py ADDED
@@ -0,0 +1,238 @@
+ from src.utils.mapper import configmapper
+ from transformers import AutoTokenizer
+ import pandas as pd
+ from datasets import load_dataset, Dataset
+ from evaluation.fix_spans import _contiguous_ranges
+
+
+ @configmapper.map("datasets", "toxic_spans_spans")
+ class ToxicSpansSpansDataset:
+     def __init__(self, config):
+         # print("### ToxicSpansSpansDataset ###"); exit()
+         self.config = config
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             self.config.model_checkpoint_name
+         )
+
+         self.dataset = load_dataset("csv", data_files=dict(self.config.train_files))
+         self.test_dataset = load_dataset("csv", data_files=dict(self.config.eval_files))
+
+         temp_key_train = list(self.dataset.keys())[0]
+         self.intermediate_dataset = self.dataset.map(
+             self.create_train_features,
+             batched=True,
+             batch_size=1000000,  ## Unusually Large Batch Size ## Needed For Correct ID mapping
+             remove_columns=self.dataset[temp_key_train].column_names,
+         )
+
+         temp_key_test = list(self.test_dataset.keys())[0]
+         self.intermediate_test_dataset = self.test_dataset.map(
+             self.create_test_features,
+             batched=True,
+             batch_size=1000000,  ## Unusually Large Batch Size ## Needed For Correct ID mapping
+             remove_columns=self.test_dataset[temp_key_test].column_names,
+         )
+
+         self.tokenized_inputs = self.intermediate_dataset.map(
+             self.prepare_train_features,
+             batched=True,
+             remove_columns=self.intermediate_dataset[temp_key_train].column_names,
+         )
+         self.test_tokenized_inputs = self.intermediate_test_dataset.map(
+             self.prepare_test_features,
+             batched=True,
+             remove_columns=self.intermediate_test_dataset[temp_key_test].column_names,
+         )
+
+     def create_train_features(self, examples):
+         features = {"context": [], "id": [], "question": [], "title": []}
+         id = 0
+         # print(examples)
+         for row_number in range(len(examples["text"])):
+             context = examples["text"][row_number]
+             # question = "offense"
+             question = "ভুল"  # Bengali for "wrong"
+             title = context.split(" ")[0]
+             span = eval(examples["spans"][row_number])
+             contiguous_spans = _contiguous_ranges(span)
+             for lst in contiguous_spans:
+                 lst = list(lst)
+                 dict_to_write = {}
+
+                 dict_to_write["answer_start"] = [lst[0]]
+                 dict_to_write["text"] = [context[lst[0] : lst[-1] + 1]]
+                 # print(dict_to_write)
+                 if "answers" in features.keys():
+                     features["answers"].append(dict_to_write)
+                 else:
+                     features["answers"] = [
+                         dict_to_write,
+                     ]
+                 features["context"].append(context)
+                 features["id"].append(str(id))
+                 features["question"].append(question)
+                 features["title"].append(title)
+                 id += 1
+
+         return features
+
+     def create_test_features(self, examples):
+         features = {"context": [], "id": [], "question": [], "title": []}
+         id = 0
+         for row_number in range(len(examples["text"])):
+             context = examples["text"][row_number]
+             # question = "offense"
+             question = "ভুল"  # Bengali for "wrong"
+             title = context.split(" ")[0]
+             features["context"].append(context)
+             features["id"].append(str(id))
+             features["question"].append(question)
+             features["title"].append(title)
+             id += 1
+         return features
+
+     def prepare_train_features(self, examples):
+         """Generate tokenized features from examples.
+
+         Args:
+             examples (dict): The examples to be tokenized.
+
+         Returns:
+             transformers.tokenization_utils_base.BatchEncoding:
+                 The tokenized features/examples after processing.
+         """
+         # Tokenize our examples with truncation and padding, but keep the
+         # overflows using a stride. This results in one example possibly
+         # giving several features when a context is long, each of those
+         # features having a context that overlaps a bit the context
+         # of the previous feature.
+         pad_on_right = self.tokenizer.padding_side == "right"
+         print("### Batch Tokenizing Examples ###")
+         tokenized_examples = self.tokenizer(
+             examples["question" if pad_on_right else "context"],
+             examples["context" if pad_on_right else "question"],
+             **dict(self.config.tokenizer_params),
+         )
+
+         # Since one example might give us several features if it has
+         # a long context, we need a map from a feature to
+         # its corresponding example. This key gives us just that.
+         sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
+         # The offset mappings will give us a map from token to
+         # character position in the original context. This will
+         # help us compute the start_positions and end_positions.
+         offset_mapping = tokenized_examples.pop("offset_mapping")
+
+         # Let's label those examples!
+         tokenized_examples["start_positions"] = []
+         tokenized_examples["end_positions"] = []
+
+         for i, offsets in enumerate(offset_mapping):
+             # We will label impossible answers with the index of the CLS token.
+             input_ids = tokenized_examples["input_ids"][i]
+             cls_index = input_ids.index(self.tokenizer.cls_token_id)
+
+             # Grab the sequence corresponding to that example
+             # (to know what is the context and what is the question).
+             sequence_ids = tokenized_examples.sequence_ids(i)
+
+             # One example can give several spans, this is the index of
+             # the example containing this span of text.
+             sample_index = sample_mapping[i]
+             answers = examples["answers"][sample_index]
+             # If no answers are given, set the cls_index as answer.
+             if len(answers["answer_start"]) == 0:
+                 tokenized_examples["start_positions"].append(cls_index)
+                 tokenized_examples["end_positions"].append(cls_index)
+             else:
+                 # Start/end character index of the answer in the text.
+                 start_char = answers["answer_start"][0]
+                 end_char = start_char + len(answers["text"][0])
+
+                 # Start token index of the current span in the text.
+                 token_start_index = 0
+                 while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
+                     token_start_index += 1
+
+                 # End token index of the current span in the text.
+                 token_end_index = len(input_ids) - 1
+                 while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
+                     token_end_index -= 1
+
+                 # Detect if the answer is out of the span
+                 # (in which case this feature is labeled with the CLS index).
+                 if not (
+                     offsets[token_start_index][0] <= start_char
+                     and offsets[token_end_index][1] >= end_char
+                 ):
+                     tokenized_examples["start_positions"].append(cls_index)
+                     tokenized_examples["end_positions"].append(cls_index)
+                 else:
+                     # Otherwise move the token_start_index and
+                     # token_end_index to the two ends of the answer.
+                     # Note: we could go after the last offset
+                     # if the answer is the last word (edge case).
+                     while (
+                         token_start_index < len(offsets)
+                         and offsets[token_start_index][0] <= start_char
+                     ):
+                         token_start_index += 1
+                     tokenized_examples["start_positions"].append(token_start_index - 1)
+                     while offsets[token_end_index][1] >= end_char:
+                         token_end_index -= 1
+                     tokenized_examples["end_positions"].append(token_end_index + 1)
+
+         return tokenized_examples
+
+     def prepare_test_features(self, examples):
+
+         """Generate tokenized validation features from examples.
+
+         Args:
+             examples (dict): The validation examples to be tokenized.
+
+         Returns:
+             transformers.tokenization_utils_base.BatchEncoding:
+                 The tokenized features/examples for validation set after processing.
+         """
+
+         # Tokenize our examples with truncation and maybe
+         # padding, but keep the overflows using a stride.
+         # This results in one example possibly giving several features
+         # when a context is long, each of those features having a
+         # context that overlaps a bit the context of the previous feature.
+         print("### Tokenizing Validation Examples")
+         pad_on_right = self.tokenizer.padding_side == "right"
+         tokenized_examples = self.tokenizer(
+             examples["question" if pad_on_right else "context"],
+             examples["context" if pad_on_right else "question"],
+             **dict(self.config.tokenizer_params),
+         )
+
+         # Since one example might give us several features if it has a long context,
+         # we need a map from a feature to its corresponding example. This key gives us just that.
+         sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
+
+         # We keep the example_id that gave us this feature and we will store the offset mappings.
+         tokenized_examples["example_id"] = []
+
+         for i in range(len(tokenized_examples["input_ids"])):
+             # Grab the sequence corresponding to that example
+             # (to know what is the context and what is the question).
+             sequence_ids = tokenized_examples.sequence_ids(i)
+             context_index = 1 if pad_on_right else 0
+
+             # One example can give several spans,
+             # this is the index of the example containing this span of text.
+             sample_index = sample_mapping[i]
+             tokenized_examples["example_id"].append(str(examples["id"][sample_index]))
+
+             # Set to None the offset_mapping that are not part
+             # of the context so it's easy to determine if a token
+             # position is part of the context or not.
+             tokenized_examples["offset_mapping"][i] = [
+                 (o if sequence_ids[k] == context_index else None)
+                 for k, o in enumerate(tokenized_examples["offset_mapping"][i])
+             ]
+
+         return tokenized_examples
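
Both QA-style datasets above pop `overflow_to_sample_mapping` and `offset_mapping` from the tokenizer output, so the `tokenizer_params` in the (unshown) config must ask the tokenizer to produce them. A plausible setting for the question-first (`pad_on_right`) case, with illustrative lengths that are not taken from this PR:

```python
# Assumed shape of config.tokenizer_params for the SQuAD-style datasets;
# lengths and stride are placeholders, the boolean flags are what the code
# above actually depends on.
tokenizer_params = {
    "truncation": "only_second",        # truncate the context, never the question
    "max_length": 384,                  # illustrative
    "stride": 128,                      # illustrative
    "return_overflowing_tokens": True,  # yields overflow_to_sample_mapping
    "return_offsets_mapping": True,     # yields offset_mapping
    "padding": "max_length",
}
```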
src/datasets/toxic_spans_tokens.py ADDED
@@ -0,0 +1,81 @@
+ from src.utils.mapper import configmapper
+ from transformers import AutoTokenizer
+ from datasets import load_dataset
+
+ # import pdb
+
+ @configmapper.map("datasets", "toxic_spans_tokens")
+ class ToxicSpansTokenDataset:
+     def __init__(self, config):
+         # print("### ToxicSpansTokenDataset ###"); exit()
+         self.config = config
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             self.config.model_checkpoint_name
+         )
+         # if self.config.model_checkpoint_name == "sberbank-ai/mGPT":
+         #     self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+         self.dataset = load_dataset("csv", data_files=dict(self.config.train_files))
+         self.test_dataset = load_dataset("csv", data_files=dict(self.config.eval_files))
+
+         self.tokenized_inputs = self.dataset.map(
+             self.tokenize_and_align_labels_for_train, batched=True
+         )
+         self.test_tokenized_inputs = self.test_dataset.map(
+             self.tokenize_for_test, batched=True
+         )
+
+     def tokenize_and_align_labels_for_train(self, examples):
+
+         tokenized_inputs = self.tokenizer(
+             examples["text"], **self.config.tokenizer_params
+         )
+
+         # tokenized_inputs["text"] = examples["text"]
+         example_spans = []
+         labels = []
+
+         offsets_mapping = tokenized_inputs["offset_mapping"]
+         # pdb.set_trace()
+         for i, offset_mapping in enumerate(offsets_mapping):
+             labels.append([])
+
+             spans = eval(examples["spans"][i])
+             example_spans.append(spans)
+             if self.config.label_cls:
+                 cls_label = (
+                     1
+                     if (
+                         len(examples["text"][i]) > 0
+                         and len(spans) / len(examples["text"][i])
+                         > self.config.cls_threshold
+                     )
+                     else 0
+                 )  ## Make class label based on threshold
+             else:
+                 cls_label = -100
+             for j, offsets in enumerate(offset_mapping):
+                 if tokenized_inputs["input_ids"][i][j] == self.tokenizer.cls_token_id:
+                     labels[-1].append(cls_label)
+                 elif offsets[0] == offsets[1] and offsets[0] == 0:  # All zero
+                     labels[-1].append(-100)  ## SPECIAL TOKEN
+                 else:
+                     toxic_offsets = [x in spans for x in range(offsets[0], offsets[1])]
+                     ## If any part of the token is in span, mark it as Toxic
+                     if (
+                         len(toxic_offsets) > 0
+                         and sum(toxic_offsets) / len(toxic_offsets)
+                         > self.config.token_threshold
+                     ):
+                         labels[-1].append(1)
+                     else:
+                         labels[-1].append(0)
+
+         tokenized_inputs["labels"] = labels
+         # print("tokenized_inputs", tokenized_inputs); exit()
+         return tokenized_inputs
+
+     def tokenize_for_test(self, examples):
+         tokenized_inputs = self.tokenizer(
+             examples["text"], **self.config.tokenizer_params
+         )
+         return tokenized_inputs
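
For reference while reviewing: the attributes read from `config` in `ToxicSpansTokenDataset` suggest a shape like the sketch below. The actual config files (and the attribute-style config object they are loaded into) are not part of this view, so the checkpoint name, file paths, and threshold values are placeholders.

```python
# Assumed config shape for ToxicSpansTokenDataset; only the key names are
# taken from the code above, all values are illustrative.
config = {
    "model_checkpoint_name": "google/electra-base-discriminator",  # placeholder
    "train_files": {"train": "data/train.csv", "validation": "data/val.csv"},  # placeholder paths
    "eval_files": {"test": "data/test.csv"},  # placeholder path
    "tokenizer_params": {
        "padding": "max_length",
        "truncation": True,
        "max_length": 256,               # illustrative
        "return_offsets_mapping": True,  # required: the labeler reads offset_mapping
    },
    "label_cls": True,
    "cls_threshold": 0.0,    # fraction of toxic characters that flips the [CLS] label
    "token_threshold": 0.0,  # fraction of a token's characters that marks it toxic
}
```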
src/datasets/toxic_spans_tokens_3cls.py ADDED
@@ -0,0 +1,102 @@
+ from src.utils.mapper import configmapper
+ from transformers import AutoTokenizer
+ from datasets import load_dataset
+
+ import pdb
+
+ @configmapper.map("datasets", "toxic_spans_tokens_3cls")
+ class ToxicSpansToken3CLSDataset:
+     def __init__(self, config):
+         # print("### ToxicSpansTokenDataset ###"); exit()
+         self.config = config
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             self.config.model_checkpoint_name
+         )
+         # if self.config.model_checkpoint_name == "sberbank-ai/mGPT":
+         #     self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+         self.dataset = load_dataset("csv", data_files=dict(self.config.train_files))
+         self.test_dataset = load_dataset("csv", data_files=dict(self.config.eval_files))
+
+         self.tokenized_inputs = self.dataset.map(
+             self.tokenize_and_align_labels_for_train, batched=True
+         )
+         self.test_tokenized_inputs = self.test_dataset.map(
+             self.tokenize_for_test, batched=True
+         )
+
+     def tokenize_and_align_labels_for_train(self, examples):
+
+         tokenized_inputs = self.tokenizer(
+             examples["text"], **self.config.tokenizer_params
+         )
+
+         # tokenized_inputs["text"] = examples["text"]
+         example_spans = []
+         labels = []
+
+         offsets_mapping = tokenized_inputs["offset_mapping"]
+         # pdb.set_trace()
+         for i, offset_mapping in enumerate(offsets_mapping):
+             labels.append([])
+
+             spans = eval(examples["spans"][i])
+             Bs = eval(examples["Bs"][i])
+             Is = eval(examples["Is"][i])
+             example_spans.append(spans)
+             if self.config.label_cls:
+                 cls_label = (
+                     1
+                     if (
+                         len(examples["text"][i]) > 0
+                         and len(spans) / len(examples["text"][i])
+                         > self.config.cls_threshold
+                     )
+                     else 0
+                 )  ## Make class label based on threshold
+             else:
+                 cls_label = -100
+             for j, offsets in enumerate(offset_mapping):
+                 if tokenized_inputs["input_ids"][i][j] == self.tokenizer.cls_token_id:
+                     labels[-1].append(cls_label)
+                 elif offsets[0] == offsets[1] and offsets[0] == 0:  # All zero
+                     labels[-1].append(-100)  ## SPECIAL TOKEN
+                 else:
+                     # toxic_offsets = [x in spans for x in range(offsets[0], offsets[1])]
+                     ## If any part of the token is in span, mark it as Toxic
+                     # if (
+                     #     len(toxic_offsets) > 0
+                     #     and sum(toxic_offsets) / len(toxic_offsets)
+                     #     > self.config.token_threshold
+                     # ):
+                     #     labels[-1].append(1)
+                     # else:
+                     #     labels[-1].append(0)
+                     b_off = [x in Bs for x in range(offsets[0], offsets[1])]
+                     b_off = sum(b_off)
+                     i_off = [x in Is for x in range(offsets[0], offsets[1])]
+                     i_off = sum(i_off)
+                     # if len(b_off) == len(i_off) and len(i_off) == 0:
+                     if b_off == 0 and i_off == 0:
+                         labels[-1].append(0)
+                     # elif len(b_off) >= len(i_off) == 1:
+                     elif b_off >= i_off:
+                         labels[-1].append(1)
+                         # print(b_off)
+                         # print(i_off)
+                         # print(j)
+                     else:
+                         labels[-1].append(2)
+
+         # pdb.set_trace()
+
+         tokenized_inputs["labels"] = labels
+         # print("tokenized_inputs", tokenized_inputs); exit()
+         return tokenized_inputs
+
+     def tokenize_for_test(self, examples):
+         tokenized_inputs = self.tokenizer(
+             examples["text"], **self.config.tokenizer_params
+         )
+         return tokenized_inputs
src/datasets/toxic_spans_tokens_spans.py ADDED
@@ -0,0 +1,269 @@
+ from src.utils.mapper import configmapper
+ from transformers import AutoTokenizer
+ import pandas as pd
+ from datasets import load_dataset, Dataset
+ from evaluation.fix_spans import _contiguous_ranges
+
+
+ @configmapper.map("datasets", "toxic_spans_tokens_spans")
+ class ToxicSpansTokensSpansDataset:
+     def __init__(self, config):
+         self.config = config
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             self.config.model_checkpoint_name
+         )
+
+         self.dataset = load_dataset("csv", data_files=dict(self.config.train_files))
+         self.test_dataset = load_dataset("csv", data_files=dict(self.config.eval_files))
+
+         temp_key_train = list(self.dataset.keys())[0]
+         self.intermediate_dataset = self.dataset.map(
+             self.create_train_features,
+             batched=True,
+             batch_size=1000000,  ## Unusually Large Batch Size ## Needed For Correct ID mapping
+             remove_columns=self.dataset[temp_key_train].column_names,
+         )
+
+         temp_key_test = list(self.test_dataset.keys())[0]
+         self.intermediate_test_dataset = self.test_dataset.map(
+             self.create_test_features,
+             batched=True,
+             batch_size=1000000,  ## Unusually Large Batch Size ## Needed For Correct ID mapping
+             remove_columns=self.test_dataset[temp_key_test].column_names,
+         )
+
+         self.tokenized_inputs = self.intermediate_dataset.map(
+             self.prepare_train_features,
+             batched=True,
+             remove_columns=self.intermediate_dataset[temp_key_train].column_names,
+         )
+         self.test_tokenized_inputs = self.intermediate_test_dataset.map(
+             self.prepare_test_features,
+             batched=True,
+             remove_columns=self.intermediate_test_dataset[temp_key_test].column_names,
+         )
+
+     def create_train_features(self, examples):
+         features = {"context": [], "id": [], "question": [], "title": [], "spans": []}
+         id = 0
+         # print(examples)
+         for row_number in range(len(examples["text"])):
+             context = examples["text"][row_number]
+             question = "offense"
+             title = context.split(" ")[0]
+             span = eval(examples["spans"][row_number])
+             contiguous_spans = _contiguous_ranges(span)
+             for lst in contiguous_spans:
+                 lst = list(lst)
+                 dict_to_write = {}
+
+                 dict_to_write["answer_start"] = [lst[0]]
+                 dict_to_write["text"] = [context[lst[0] : lst[-1] + 1]]
+                 # print(dict_to_write)
+                 if "answers" in features.keys():
+                     features["answers"].append(dict_to_write)
+                 else:
+                     features["answers"] = [
+                         dict_to_write,
+                     ]
+                 features["context"].append(context)
+                 features["id"].append(str(id))
+                 features["question"].append(question)
+                 features["title"].append(title)
+                 features["spans"].append(span)
+                 id += 1
+
+         return features
+
+     def create_test_features(self, examples):
+         features = {"context": [], "id": [], "question": [], "title": []}
+         id = 0
+         for row_number in range(len(examples["text"])):
+             context = examples["text"][row_number]
+             question = "offense"
+             title = context.split(" ")[0]
+             features["context"].append(context)
+             features["id"].append(str(id))
+             features["question"].append(question)
+             features["title"].append(title)
+             id += 1
+         return features
+
+     def prepare_train_features(self, examples):
+         """Generate tokenized features from examples.
+
+         Args:
+             examples (dict): The examples to be tokenized.
+
+         Returns:
+             transformers.tokenization_utils_base.BatchEncoding:
+                 The tokenized features/examples after processing.
+         """
+         # Tokenize our examples with truncation and padding, but keep the
+         # overflows using a stride. This results in one example possibly
+         # giving several features when a context is long, each of those
+         # features having a context that overlaps a bit the context
+         # of the previous feature.
+         pad_on_right = self.tokenizer.padding_side == "right"
+         print("### Batch Tokenizing Examples ###")
+         tokenized_examples = self.tokenizer(
+             examples["question" if pad_on_right else "context"],
+             examples["context" if pad_on_right else "question"],
+             **dict(self.config.tokenizer_params),
+         )
+
+         # Since one example might give us several features if it has
+         # a long context, we need a map from a feature to
+         # its corresponding example. This key gives us just that.
+         sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
+         # The offset mappings will give us a map from token to
+         # character position in the original context. This will
+         # help us compute the start_positions and end_positions.
+         offset_mapping = tokenized_examples.pop("offset_mapping")
+
+         # Let's label those examples!
+         token_labels = []
+         tokenized_examples["start_positions"] = []
+         tokenized_examples["end_positions"] = []
+
+         for i, offsets in enumerate(offset_mapping):
+             # We will label impossible answers with the index of the CLS token.
+
+             token_labels.append([])
+             input_ids = tokenized_examples["input_ids"][i]
+             spans = examples["spans"][i]
+             if self.config.label_cls:
+                 cls_label = (
+                     1
+                     if (
+                         len(examples["context"][i]) > 0
+                         and len(spans) / len(examples["context"][i])
+                         > self.config.cls_threshold
+                     )
+                     else 0
+                 )  ## Make class label based on threshold
+             else:
+                 cls_label = -100
+             for j, offset in enumerate(offsets):
+                 if tokenized_examples["input_ids"][i][j] == self.tokenizer.cls_token_id:
+                     token_labels[-1].append(cls_label)
+                 elif offset[0] == offset[1] and offset[0] == 0:
+                     token_labels[-1].append(-100)  ## SPECIAL TOKEN
+                 else:
+                     toxic_offsets = [x in spans for x in range(offset[0], offset[1])]
+                     ## If any part of the token is in span, mark it as Toxic
+                     if (
+                         len(toxic_offsets) > 0
+                         and sum(toxic_offsets) / len(toxic_offsets)
+                         > self.config.token_threshold
+                     ):
+                         token_labels[-1].append(1)
+                     else:
+                         token_labels[-1].append(0)
+
+             cls_index = input_ids.index(self.tokenizer.cls_token_id)
+
+             # Grab the sequence corresponding to that example
+             # (to know what is the context and what is the question).
+             sequence_ids = tokenized_examples.sequence_ids(i)
+
+             # One example can give several spans, this is the index of
+             # the example containing this span of text.
+             sample_index = sample_mapping[i]
+             answers = examples["answers"][sample_index]
+             # If no answers are given, set the cls_index as answer.
+             if len(answers["answer_start"]) == 0:
+                 tokenized_examples["start_positions"].append(cls_index)
+                 tokenized_examples["end_positions"].append(cls_index)
+             else:
+                 # Start/end character index of the answer in the text.
+                 start_char = answers["answer_start"][0]
+                 end_char = start_char + len(answers["text"][0])
+
+                 # Start token index of the current span in the text.
+                 token_start_index = 0
+                 while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
+                     token_start_index += 1
+
+                 # End token index of the current span in the text.
+                 token_end_index = len(input_ids) - 1
+                 while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
+                     token_end_index -= 1
+
+                 # Detect if the answer is out of the span
+                 # (in which case this feature is labeled with the CLS index).
+                 if not (
+                     offsets[token_start_index][0] <= start_char
+                     and offsets[token_end_index][1] >= end_char
+                 ):
+                     tokenized_examples["start_positions"].append(cls_index)
+                     tokenized_examples["end_positions"].append(cls_index)
+                 else:
+                     # Otherwise move the token_start_index and
+                     # token_end_index to the two ends of the answer.
+                     # Note: we could go after the last offset
+                     # if the answer is the last word (edge case).
+                     while (
+                         token_start_index < len(offsets)
+                         and offsets[token_start_index][0] <= start_char
+                     ):
+                         token_start_index += 1
+                     tokenized_examples["start_positions"].append(token_start_index - 1)
+                     while offsets[token_end_index][1] >= end_char:
+                         token_end_index -= 1
+                     tokenized_examples["end_positions"].append(token_end_index + 1)
+         tokenized_examples["labels"] = token_labels
+         return tokenized_examples
+
+     def prepare_test_features(self, examples):
+
+         """Generate tokenized validation features from examples.
+
+         Args:
+             examples (dict): The validation examples to be tokenized.
+
+         Returns:
+             transformers.tokenization_utils_base.BatchEncoding:
+                 The tokenized features/examples for validation set after processing.
+         """
+
+         # Tokenize our examples with truncation and maybe
+         # padding, but keep the overflows using a stride.
+         # This results in one example possibly giving several features
+         # when a context is long, each of those features having a
+         # context that overlaps a bit the context of the previous feature.
+         print("### Tokenizing Validation Examples")
+         pad_on_right = self.tokenizer.padding_side == "right"
+         tokenized_examples = self.tokenizer(
+             examples["question" if pad_on_right else "context"],
+             examples["context" if pad_on_right else "question"],
+             **dict(self.config.tokenizer_params),
+         )
+
+         # Since one example might give us several features if it has a long context,
+         # we need a map from a feature to its corresponding example. This key gives us just that.
+         sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
+
+         # We keep the example_id that gave us this feature and we will store the offset mappings.
+         tokenized_examples["example_id"] = []
+
+         for i in range(len(tokenized_examples["input_ids"])):
+             # Grab the sequence corresponding to that example
+             # (to know what is the context and what is the question).
+             sequence_ids = tokenized_examples.sequence_ids(i)
+             context_index = 1 if pad_on_right else 0
+
+             # One example can give several spans,
+             # this is the index of the example containing this span of text.
+             sample_index = sample_mapping[i]
+             tokenized_examples["example_id"].append(str(examples["id"][sample_index]))
+
+             # Set to None the offset_mapping that are not part
+             # of the context so it's easy to determine if a token
+             # position is part of the context or not.
+             tokenized_examples["offset_mapping"][i] = [
+                 (o if sequence_ids[k] == context_index else None)
+                 for k, o in enumerate(tokenized_examples["offset_mapping"][i])
+             ]
+
+         return tokenized_examples
src/models/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from src.models.auto_models import *
+ from src.models.bert_token_spans import *
+ from src.models.roberta_token_spans import *
+ from src.models.bert_multi_spans import *
+ from src.models.roberta_multi_spans import *
+ from src.models.bert_crf_token import *
+ from src.models.roberta_crf_token import *
src/models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (443 Bytes).
src/models/__pycache__/auto_models.cpython-38.pyc ADDED
Binary file (436 Bytes).
src/models/__pycache__/bert_crf_token.cpython-38.pyc ADDED
Binary file (1.68 kB).
src/models/__pycache__/bert_multi_spans.cpython-38.pyc ADDED
Binary file (1.72 kB).
src/models/__pycache__/bert_token_spans.cpython-38.pyc ADDED
Binary file (2.34 kB).
src/models/__pycache__/roberta_crf_token.cpython-38.pyc ADDED
Binary file (1.69 kB).
src/models/__pycache__/roberta_multi_spans.cpython-38.pyc ADDED
Binary file (1.79 kB).
src/models/__pycache__/roberta_token_spans.cpython-38.pyc ADDED
Binary file (2.42 kB).
src/models/auto_models.py ADDED
@@ -0,0 +1,6 @@
+ from transformers import AutoModelForTokenClassification, AutoModelForQuestionAnswering
+ from src.utils.mapper import configmapper
+
+ configmapper.map("models", "autotoken")(AutoModelForTokenClassification)
+ configmapper.map("models", "autotoken_3cls")(AutoModelForTokenClassification)
+ configmapper.map("models", "autospans")(AutoModelForQuestionAnswering)
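
The `autotoken`/`autotoken_3cls`/`autospans` entries register the stock Hugging Face factories with `configmapper`; how the trainer looks them up lives in `src/utils/mapper.py`, which is outside this view. Resolving them ultimately amounts to the usual `from_pretrained` calls, shown below with an illustrative checkpoint name:

```python
# Sketch of what the registered factories produce (checkpoint name is illustrative).
from transformers import AutoModelForQuestionAnswering, AutoModelForTokenClassification

token_model = AutoModelForTokenClassification.from_pretrained(
    "google/electra-base-discriminator", num_labels=2  # 3 for the *_3cls datasets
)
span_model = AutoModelForQuestionAnswering.from_pretrained(
    "google/electra-base-discriminator"
)
```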
src/models/bert_crf_token.py ADDED
@@ -0,0 +1,72 @@
+ import torch
+ # from transformers import BertForTokenClassification
+ from transformers import ElectraForTokenClassification
+ from torchcrf import CRF
+ from src.utils.mapper import configmapper
+ # import pdb
+
+
+ @configmapper.map("models", "bert_crf_token")
+ # class BertLSTMCRF(BertForTokenClassification):
+ class BertLSTMCRF(ElectraForTokenClassification):
+     def __init__(self, config, lstm_hidden_size, lstm_layers):
+         super().__init__(config)
+         # ipdb.set_trace()
+         self.lstm = torch.nn.LSTM(
+             input_size=config.hidden_size,
+             hidden_size=lstm_hidden_size,
+             num_layers=lstm_layers,
+             dropout=0.2,
+             batch_first=True,
+             bidirectional=True,
+         )
+         self.crf = CRF(config.num_labels, batch_first=True)
+
+         del self.classifier
+         self.classifier = torch.nn.Linear(2 * lstm_hidden_size, config.num_labels)
+
+     def forward(
+         self,
+         input_ids,
+         attention_mask=None,
+         token_type_ids=None,
+         labels=None,
+         prediction_mask=None,
+     ):
+         # pdb.set_trace()
+
+         # outputs = self.bert(
+         outputs = self.electra(
+             input_ids,
+             attention_mask,
+             token_type_ids,
+             output_hidden_states=True,
+             return_dict=False,
+         )
+         # seq_output, all_hidden_states, all_self_attentions, all_cross_attentions
+
+         sequence_output = outputs[0]  # outputs[1] is pooled output which is none.
+
+         sequence_output = self.dropout(sequence_output)
+
+         lstm_out, *_ = self.lstm(sequence_output)
+         sequence_output = self.dropout(lstm_out)
+
+         logits = self.classifier(sequence_output)
+
+         ## CRF
+         mask = prediction_mask
+         mask = mask[:, : logits.size(1)].contiguous()
+
+         # print(logits)
+
+         if labels is not None:
+             labels = labels[:, : logits.size(1)].contiguous()
+             loss = -self.crf(logits, labels, mask=mask.bool(), reduction="token_mean")
+
+         tags = self.crf.decode(logits, mask.bool())
+         # print(tags)
+         if labels is not None:
+             return (loss, logits, tags)
+         else:
+             return (logits, tags)
+
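
The CRF head comes from the `pytorch-crf` package (`torchcrf`). A standalone toy example of the same loss and decoding pattern used in `forward` above, independent of the ELECTRA encoder (shapes and values are arbitrary):

```python
# Minimal sketch of the torchcrf usage pattern from BertLSTMCRF.forward.
import torch
from torchcrf import CRF

num_labels, batch, seq_len = 3, 2, 5
crf = CRF(num_labels, batch_first=True)

emissions = torch.randn(batch, seq_len, num_labels)   # stand-in for classifier logits
tags = torch.randint(0, num_labels, (batch, seq_len))  # stand-in for token labels
mask = torch.ones(batch, seq_len, dtype=torch.bool)    # stand-in for prediction_mask

# forward() returns the log-likelihood, so the training loss is its negative.
loss = -crf(emissions, tags, mask=mask, reduction="token_mean")
best_paths = crf.decode(emissions, mask)  # list of best label sequences per example
```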
src/models/bert_multi_spans.py ADDED
@@ -0,0 +1,84 @@
+ import torch.nn as nn
+ from torch.nn import BCEWithLogitsLoss
+ # from transformers import BertModel, BertPreTrainedModel
+ from transformers import ElectraPreTrainedModel, ElectraModel
+ from src.utils.mapper import configmapper
+
+
+ @configmapper.map("models", "bert_multi_spans")
+ # class BertForMultiSpans(BertPreTrainedModel):
+ class BertForMultiSpans(ElectraPreTrainedModel):
+     def __init__(self, config):
+         super(BertForMultiSpans, self).__init__(config)
+         # self.bert = BertModel(config)
+         self.bert = ElectraModel(config)
+         self.num_labels = config.num_labels
+
+         # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
+         # self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+         self.init_weights()
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         start_positions=None,
+         end_positions=None,
+         output_attentions=None,
+         output_hidden_states=None,
+     ):
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=None,
+         )
+
+         sequence_output = outputs[0]
+
+         logits = self.qa_outputs(sequence_output)
+         start_logits, end_logits = logits.split(1, dim=-1)
+         start_logits = start_logits.squeeze(-1)
+         end_logits = end_logits.squeeze(-1)  # batch_size
+         # print(start_logits.shape, end_logits.shape, start_positions.shape, end_positions.shape)
+
+         total_loss = None
+         if (
+             start_positions is not None and end_positions is not None
+         ):  # [batch_size/seq_length]
+             # # If we are on multi-GPU, split add a dimension
+             # if len(start_positions.size()) > 1:
+             #     start_positions = start_positions.squeeze(-1)
+             # if len(end_positions.size()) > 1:
+             #     end_positions = end_positions.squeeze(-1)
+             # sometimes the start/end positions are outside our model inputs, we ignore these terms
+             # ignored_index = start_logits.size(1)
+             # start_positions.clamp_(0, ignored_index)
+             # end_positions.clamp_(0, ignored_index)
+
+             # start_positions = start_logits.view()
+
+             loss_fct = BCEWithLogitsLoss()
+
+             start_loss = loss_fct(
+                 start_logits,
+                 start_positions.float(),
+             )
+             end_loss = loss_fct(
+                 end_logits,
+                 end_positions.float(),
+             )
+             total_loss = (start_loss + end_loss) / 2
+
+         output = (start_logits, end_logits) + outputs[2:]
+         return ((total_loss,) + output) if total_loss is not None else output
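Unlike the usual single-span QA head, the loss here is a per-token binary cross-entropy, so one sequence can mark several start and several end positions. A small sketch of the target shape this implies; the tensors below are dummies, the real targets come from the project's dataset classes:

import torch
from torch.nn import BCEWithLogitsLoss

batch, seq_len = 2, 8
start_logits = torch.randn(batch, seq_len)
end_logits = torch.randn(batch, seq_len)

# Multi-hot targets: a 1 wherever any toxic span starts / ends in the sequence.
start_positions = torch.zeros(batch, seq_len)
end_positions = torch.zeros(batch, seq_len)
start_positions[0, 2], end_positions[0, 4] = 1.0, 1.0
start_positions[0, 6], end_positions[0, 7] = 1.0, 1.0  # a second span in the same text

loss_fct = BCEWithLogitsLoss()
total_loss = (loss_fct(start_logits, start_positions) + loss_fct(end_logits, end_positions)) / 2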
src/models/bert_token_spans.py ADDED
@@ -0,0 +1,100 @@
+ import torch.nn as nn
+ import torch
+ from torch.nn import CrossEntropyLoss
+ # from transformers import BertPreTrainedModel, BertModel
+ from transformers import ElectraPreTrainedModel, ElectraModel
+ from src.utils.mapper import configmapper
+
+
+ @configmapper.map("models", "bert_token_spans")
+ # class BertModelForTokenAndSpans(BertPreTrainedModel):
+ class BertModelForTokenAndSpans(ElectraPreTrainedModel):
+     def __init__(self, config, num_token_labels=2, num_qa_labels=2):
+         super(BertModelForTokenAndSpans, self).__init__(config)
+         # self.bert = BertModel(config)
+         self.bert = ElectraModel(config)
+         self.num_token_labels = num_token_labels
+         self.num_qa_labels = num_qa_labels
+         # print("Number of Token Labels: ", num_token_labels); exit()
+
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.classifier = nn.Linear(config.hidden_size, num_token_labels)
+         self.qa_outputs = nn.Linear(config.hidden_size, num_qa_labels)
+         self.init_weights()
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         start_positions=None,
+         end_positions=None,
+         labels=None,  # Token Wise Labels
+         output_attentions=None,
+         output_hidden_states=None,
+     ):
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=None,
+         )
+
+         sequence_output = outputs[0]
+
+         qa_logits = self.qa_outputs(sequence_output)
+         start_logits, end_logits = qa_logits.split(1, dim=-1)
+         start_logits = start_logits.squeeze(-1)
+         end_logits = end_logits.squeeze(-1)
+
+         sequence_output = self.dropout(sequence_output)
+         token_logits = self.classifier(sequence_output)
+
+         total_loss = None
+         if (
+             start_positions is not None
+             and end_positions is not None
+             and labels is not None
+         ):
+             # If we are on multi-GPU, split add a dimension
+             if len(start_positions.size()) > 1:
+                 start_positions = start_positions.squeeze(-1)
+             if len(end_positions.size()) > 1:
+                 end_positions = end_positions.squeeze(-1)
+
+             ignored_index = start_logits.size(1)
+             start_positions.clamp_(0, ignored_index)
+             end_positions.clamp_(0, ignored_index)
+
+             loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+             start_loss = loss_fct(start_logits, start_positions)
+             end_loss = loss_fct(end_logits, end_positions)
+
+             loss_fct = CrossEntropyLoss()
+             if attention_mask is not None:
+                 active_loss = attention_mask.view(-1) == 1
+                 active_logits = token_logits.view(-1, self.num_token_labels)
+                 active_labels = torch.where(
+                     active_loss,
+                     labels.view(-1),
+                     torch.tensor(loss_fct.ignore_index).type_as(labels),
+                 )
+                 token_loss = loss_fct(active_logits, active_labels)
+             else:
+                 token_loss = loss_fct(
+                     token_logits.view(-1, self.num_token_labels), labels.view(-1)
+                 )
+
+             total_loss = (start_loss + end_loss) / 2 + token_loss
+
+         output = (start_logits, end_logits, token_logits) + outputs[2:]
+         return ((total_loss,) + output) if total_loss is not None else output
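The joint head returns start/end logits for span prediction plus per-token logits, and prepends the combined loss when labels are supplied. A sketch of consuming that output tuple; the `model` instance and the contents of `batch` are assumed here, not shown in this diff:

# Assumes `model` is a BertModelForTokenAndSpans and `batch` is a dict with
# input_ids / attention_mask plus start_positions, end_positions and labels.
outputs = model(**batch)
if "labels" in batch:
    loss, start_logits, end_logits, token_logits = outputs[:4]
    loss.backward()
else:
    start_logits, end_logits, token_logits = outputs[:3]

# Greedy readout: per-token toxic/non-toxic tags and a single best span.
token_preds = token_logits.argmax(dim=-1)
span = (start_logits.argmax(dim=-1), end_logits.argmax(dim=-1))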
src/models/roberta_crf_token.py ADDED
@@ -0,0 +1,66 @@
+ import torch
+ from transformers import RobertaForTokenClassification
+ from torchcrf import CRF
+ from src.utils.mapper import configmapper
+
+
+ @configmapper.map("models", "roberta_crf_token")
+ class RobertaLSTMCRF(RobertaForTokenClassification):
+     def __init__(self, config, lstm_hidden_size, lstm_layers):
+         super().__init__(config)
+         self.lstm = torch.nn.LSTM(
+             input_size=config.hidden_size,
+             hidden_size=lstm_hidden_size,
+             num_layers=lstm_layers,
+             dropout=0.2,
+             batch_first=True,
+             bidirectional=True,
+         )
+         self.crf = CRF(config.num_labels, batch_first=True)
+
+         del self.classifier
+         self.classifier = torch.nn.Linear(2 * lstm_hidden_size, config.num_labels)
+
+     def forward(
+         self,
+         input_ids,
+         attention_mask=None,
+         token_type_ids=None,
+         labels=None,
+         prediction_mask=None,
+     ):
+         outputs = self.roberta(
+             input_ids,
+             attention_mask,
+             token_type_ids,
+             output_hidden_states=True,
+             return_dict=False,
+         )
+         # seq_output, all_hidden_states, all_self_attentions, all_cross_attentions
+
+         sequence_output = outputs[0]  # outputs[1] is the pooled output, which is None.
+
+         sequence_output = self.dropout(sequence_output)
+
+         lstm_out, *_ = self.lstm(sequence_output)
+         sequence_output = self.dropout(lstm_out)
+
+         logits = self.classifier(sequence_output)
+
+         ## CRF
+         mask = prediction_mask
+         mask = mask[:, : logits.size(1)].contiguous()
+
+         # print(logits)
+
+         if labels is not None:
+             labels = labels[:, : logits.size(1)].contiguous()
+             loss = -self.crf(logits, labels, mask=mask.bool(), reduction="token_mean")
+
+         tags = self.crf.decode(logits, mask.bool())
+         # print(tags)
+         if labels is not None:
+             return (loss, logits, tags)
+         else:
+             return (logits, tags)
src/models/roberta_multi_spans.py ADDED
@@ -0,0 +1,82 @@
+ import torch.nn as nn
+ from torch.nn import BCEWithLogitsLoss
+ from transformers import RobertaModel
+ from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel
+ from src.utils.mapper import configmapper
+
+
+ @configmapper.map("models", "roberta_multi_spans")
+ class RobertaForMultiSpans(RobertaPreTrainedModel):
+     def __init__(self, config):
+         super(RobertaForMultiSpans, self).__init__(config)
+         self.roberta = RobertaModel(config)
+         self.num_labels = config.num_labels
+
+         # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
+         # self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+         self.init_weights()
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         start_positions=None,
+         end_positions=None,
+         output_attentions=None,
+         output_hidden_states=None,
+     ):
+         outputs = self.roberta(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=None,
+         )
+
+         sequence_output = outputs[0]
+
+         logits = self.qa_outputs(sequence_output)
+         start_logits, end_logits = logits.split(1, dim=-1)
+         start_logits = start_logits.squeeze(-1)
+         end_logits = end_logits.squeeze(-1)  # batch_size
+         # print(start_logits.shape, end_logits.shape, start_positions.shape, end_positions.shape)
+
+         total_loss = None
+         if (
+             start_positions is not None and end_positions is not None
+         ):  # [batch_size/seq_length]
+             # # If we are on multi-GPU, split add a dimension
+             # if len(start_positions.size()) > 1:
+             #     start_positions = start_positions.squeeze(-1)
+             # if len(end_positions.size()) > 1:
+             #     end_positions = end_positions.squeeze(-1)
+             # sometimes the start/end positions are outside our model inputs, we ignore these terms
+             # ignored_index = start_logits.size(1)
+             # start_positions.clamp_(0, ignored_index)
+             # end_positions.clamp_(0, ignored_index)
+
+             # start_positions = start_logits.view()
+
+             loss_fct = BCEWithLogitsLoss()
+
+             start_loss = loss_fct(
+                 start_logits,
+                 start_positions.float(),
+             )
+             end_loss = loss_fct(
+                 end_logits,
+                 end_positions.float(),
+             )
+             total_loss = (start_loss + end_loss) / 2
+
+         output = (start_logits, end_logits) + outputs[2:]
+         return ((total_loss,) + output) if total_loss is not None else output
src/models/roberta_token_spans.py ADDED
@@ -0,0 +1,97 @@
+ import torch.nn as nn
+ import torch
+ from torch.nn import CrossEntropyLoss
+ from transformers import RobertaModel
+ from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel
+ from src.utils.mapper import configmapper
+
+
+ @configmapper.map("models", "roberta_token_spans")
+ class RobertaModelForTokenAndSpans(RobertaPreTrainedModel):
+     def __init__(self, config, num_token_labels=2, num_qa_labels=2):
+         super(RobertaModelForTokenAndSpans, self).__init__(config)
+         self.roberta = RobertaModel(config)
+         self.num_token_labels = num_token_labels
+         self.num_qa_labels = num_qa_labels
+
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.classifier = nn.Linear(config.hidden_size, num_token_labels)
+         self.qa_outputs = nn.Linear(config.hidden_size, num_qa_labels)
+         self.init_weights()
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         start_positions=None,
+         end_positions=None,
+         labels=None,  # Token Wise Labels
+         output_attentions=None,
+         output_hidden_states=None,
+     ):
+         outputs = self.roberta(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=None,
+         )
+
+         sequence_output = outputs[0]
+
+         qa_logits = self.qa_outputs(sequence_output)
+         start_logits, end_logits = qa_logits.split(1, dim=-1)
+         start_logits = start_logits.squeeze(-1)
+         end_logits = end_logits.squeeze(-1)
+
+         sequence_output = self.dropout(sequence_output)
+         token_logits = self.classifier(sequence_output)
+
+         total_loss = None
+         if (
+             start_positions is not None
+             and end_positions is not None
+             and labels is not None
+         ):
+             # If we are on multi-GPU, split add a dimension
+             if len(start_positions.size()) > 1:
+                 start_positions = start_positions.squeeze(-1)
+             if len(end_positions.size()) > 1:
+                 end_positions = end_positions.squeeze(-1)
+
+             ignored_index = start_logits.size(1)
+             start_positions.clamp_(0, ignored_index)
+             end_positions.clamp_(0, ignored_index)
+
+             loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+             start_loss = loss_fct(start_logits, start_positions)
+             end_loss = loss_fct(end_logits, end_positions)
+
+             loss_fct = CrossEntropyLoss()
+             if attention_mask is not None:
+                 active_loss = attention_mask.view(-1) == 1
+                 active_logits = token_logits.view(-1, self.num_token_labels)
+                 active_labels = torch.where(
+                     active_loss,
+                     labels.view(-1),
+                     torch.tensor(loss_fct.ignore_index).type_as(labels),
+                 )
+                 token_loss = loss_fct(active_logits, active_labels)
+             else:
+                 token_loss = loss_fct(
+                     token_logits.view(-1, self.num_token_labels), labels.view(-1)
+                 )
+
+             total_loss = (start_loss + end_loss) / 2 + token_loss
+
+         output = (start_logits, end_logits, token_logits) + outputs[2:]
+         return ((total_loss,) + output) if total_loss is not None else output
src/models/two_layer_nn.py ADDED
@@ -0,0 +1,46 @@
+ """Implements a two layer Neural Network."""
+
+ from torch.nn import Module, Linear, ReLU
+ from src.utils.mapper import configmapper
+
+
+ @configmapper.map("models", "two_layer_nn")
+ class TwoLayerNN(Module):
+     """Implements two layer neural network.
+
+     Methods:
+         forward(x_input): Returns the output of the neural network.
+     """
+
+     def __init__(self, embedding, dims):
+         """Construct the two layer Neural Network.
+
+         This method is used to initialize the two layer neural network,
+         with a given embedding type and corresponding arguments.
+
+         Args:
+             embedding (torch.nn.Module): The embedding layer for the model.
+             dims (list): List of dimensions for the neural network, input to output.
+         """
+         super(TwoLayerNN, self).__init__()
+
+         self.embedding = embedding
+         self.linear1 = Linear(dims[0], dims[1])
+         self.relu = ReLU()
+         self.linear2 = Linear(dims[1], dims[2])
+
+     def forward(self, x_input):
+         """
+         Return the output of the neural network for an input.
+
+         Args:
+             x_input (torch.Tensor): The input tensor to the neural network.
+
+         Returns:
+             x_output (torch.Tensor): The output tensor for the neural network.
+         """
+         output = self.embedding(x_input)
+         output = self.linear1(output)
+         output = self.relu(output)
+         x_output = self.linear2(output)
+         return x_output
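TwoLayerNN expects an embedding module whose flattened output size matches dims[0]. A self-contained sketch with a toy embedding standing in for GloveEmbedding; all sizes are made up for illustration:

import torch
from torch.nn import Embedding, Flatten, Sequential
from src.models.two_layer_nn import TwoLayerNN

vocab_size, emb_dim, fix_length = 100, 8, 4
# Toy stand-in for GloveEmbedding: lookup, then flatten the fix_length token vectors.
embedding = Sequential(Embedding(vocab_size, emb_dim), Flatten(start_dim=1))

model = TwoLayerNN(embedding, dims=[fix_length * emb_dim, 16, 1])
tokens = torch.randint(0, vocab_size, (2, fix_length))  # batch of token ids
print(model(tokens).shape)  # torch.Size([2, 1])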
src/modules/__init__.py ADDED
File without changes
src/modules/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (166 Bytes). View file
 
src/modules/__pycache__/embeddings.cpython-38.pyc ADDED
Binary file (1.67 kB). View file
 
src/modules/__pycache__/preprocessors.cpython-38.pyc ADDED
Binary file (3.42 kB). View file
 
src/modules/__pycache__/tokenizers.cpython-38.pyc ADDED
Binary file (4.87 kB). View file
 
src/modules/activations.py ADDED
@@ -0,0 +1,6 @@
+ import torch.nn as nn
+ from src.utils.mapper import configmapper
+
+ configmapper.map("activations", "relu")(nn.ReLU)
+ configmapper.map("activations", "logsoftmax")(nn.LogSoftmax)
+ configmapper.map("activations", "softmax")(nn.Softmax)
src/modules/embeddings.py ADDED
@@ -0,0 +1,37 @@
+ """Contains various kinds of embeddings like Glove, BERT, etc."""
+
+ from torch.nn import Module, Embedding, Flatten
+ from src.utils.mapper import configmapper
+
+
+ @configmapper.map("embeddings", "glove")
+ class GloveEmbedding(Module):
+     """Implement Glove based Word Embedding."""
+
+     def __init__(self, embedding_matrix, padding_idx, static=True):
+         """Construct GloveEmbedding.
+
+         Args:
+             embedding_matrix (torch.Tensor): The matrix containing the embedding weights.
+             padding_idx (int): The padding index in the tokenizer.
+             static (bool): Whether or not to freeze embeddings.
+         """
+         super(GloveEmbedding, self).__init__()
+         self.embedding = Embedding.from_pretrained(embedding_matrix)
+         self.embedding.padding_idx = padding_idx
+         if static:
+             self.embedding.weight.requires_grad = False
+         self.flatten = Flatten(start_dim=1)
+
+     def forward(self, x_input):
+         """Pass the input through the embedding.
+
+         Args:
+             x_input (torch.Tensor): The numericalized tokenized input
+
+         Returns:
+             x_output (torch.Tensor): The output from the embedding
+         """
+         x_output = self.embedding(x_input)
+         x_output = self.flatten(x_output)
+         return x_output
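A quick sketch of GloveEmbedding in isolation, with a random matrix standing in for the real GloVe vectors; shapes are illustrative only:

import torch
from src.modules.embeddings import GloveEmbedding

vocab_size, dim, pad_idx = 50, 300, 1
embedding_matrix = torch.randn(vocab_size, dim)   # would normally be text_field.vocab.vectors
emb = GloveEmbedding(embedding_matrix, padding_idx=pad_idx, static=True)

token_ids = torch.randint(0, vocab_size, (2, 4))  # batch of 2, fix_length of 4
print(emb(token_ids).shape)  # torch.Size([2, 1200]) -- 4 tokens * 300 dims, flattened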
src/modules/losses.py ADDED
@@ -0,0 +1,6 @@
+ """All criterion functions."""
+ from torch.nn import MSELoss, CrossEntropyLoss
+ from src.utils.mapper import configmapper
+
+ configmapper.map("losses", "mse")(MSELoss)
+ configmapper.map("losses", "CrossEntropyLoss")(CrossEntropyLoss)
src/modules/metrics.py ADDED
@@ -0,0 +1,17 @@
+ """Metrics."""
+ from sklearn.metrics import (
+     mean_squared_error,
+     f1_score,
+     precision_score,
+     recall_score,
+     roc_auc_score,
+     accuracy_score,
+ )
+ from src.utils.mapper import configmapper
+
+ configmapper.map("metrics", "sklearn_f1")(f1_score)
+ configmapper.map("metrics", "sklearn_p")(precision_score)
+ configmapper.map("metrics", "sklearn_r")(recall_score)
+ configmapper.map("metrics", "sklearn_roc")(roc_auc_score)
+ configmapper.map("metrics", "sklearn_acc")(accuracy_score)
+ configmapper.map("metrics", "sklearn_mse")(mean_squared_error)
src/modules/optimizers.py ADDED
@@ -0,0 +1,7 @@
+ """All optimizer classes."""
+ from torch.optim import Adam, AdamW, SGD
+ from src.utils.mapper import configmapper
+
+ configmapper.map("optimizers", "adam")(Adam)
+ configmapper.map("optimizers", "adam_w")(AdamW)
+ configmapper.map("optimizers", "sgd")(SGD)
src/modules/preprocessors.py ADDED
@@ -0,0 +1,112 @@
+ from src.modules.tokenizers import *
+ from src.modules.embeddings import *
+ from src.utils.mapper import configmapper
+
+
+ class Preprocessor:
+     def preprocess(self):
+         pass
+
+
+ @configmapper.map("preprocessors", "glove")
+ class GlovePreprocessor(Preprocessor):
+     """GlovePreprocessor."""
+
+     def __init__(self, config):
+         """
+         Args:
+             config (src.utils.module.Config): configuration for preprocessor
+         """
+         super(GlovePreprocessor, self).__init__()
+         self.config = config
+         self.tokenizer = configmapper.get_object(
+             "tokenizers", self.config.main.preprocessor.tokenizer.name
+         )(**self.config.main.preprocessor.tokenizer.init_params.as_dict())
+         self.tokenizer_params = (
+             self.config.main.preprocessor.tokenizer.init_vector_params.as_dict()
+         )
+
+         self.tokenizer.initialize_vectors(**self.tokenizer_params)
+         self.embeddings = configmapper.get_object(
+             "embeddings", self.config.main.preprocessor.embedding.name
+         )(
+             self.tokenizer.text_field.vocab.vectors,
+             self.tokenizer.text_field.vocab.stoi[self.tokenizer.text_field.pad_token],
+         )
+
+     def preprocess(self, model_config, data_config):
+         train_dataset = configmapper.get_object("datasets", data_config.main.name)(
+             data_config.train, self.tokenizer
+         )
+         val_dataset = configmapper.get_object("datasets", data_config.main.name)(
+             data_config.val, self.tokenizer
+         )
+         model = configmapper.get_object("models", model_config.name)(
+             self.embeddings, **model_config.params.as_dict()
+         )
+
+         return model, train_dataset, val_dataset
+
+
+ @configmapper.map("preprocessors", "clozePreprocessor")
+ class ClozePreprocessor(Preprocessor):
+     """ClozePreprocessor."""
+
+     def __init__(self, config):
+         """
+         Args:
+             config (src.utils.module.Config): configuration for preprocessor
+         """
+         super(ClozePreprocessor, self).__init__()
+         self.config = config
+         self.tokenizer = configmapper.get_object(
+             "tokenizers", self.config.main.preprocessor.tokenizer.name
+         ).from_pretrained(
+             **self.config.main.preprocessor.tokenizer.init_params.as_dict()
+         )
+
+     def preprocess(self, model_config, data_config):
+         train_dataset = configmapper.get_object("datasets", data_config.main.name)(
+             data_config.train, self.tokenizer
+         )
+         val_dataset = configmapper.get_object("datasets", data_config.main.name)(
+             data_config.val, self.tokenizer
+         )
+         model = configmapper.get_object("models", model_config.name).from_pretrained(
+             **model_config.params.as_dict()
+         )
+
+         return model, train_dataset, val_dataset
+
+
+ @configmapper.map("preprocessors", "transformersConcretenessPreprocessor")
+ class TransformersConcretenessPreprocessor(Preprocessor):
+     """TransformersConcretenessPreprocessor."""
+
+     def __init__(self, config):
+         """
+         Args:
+             config (src.utils.module.Config): configuration for preprocessor
+         """
+         super(TransformersConcretenessPreprocessor, self).__init__()
+         self.config = config
+         self.tokenizer = configmapper.get_object(
+             "tokenizers", self.config.main.preprocessor.tokenizer.name
+         ).from_pretrained(
+             **self.config.main.preprocessor.tokenizer.init_params.as_dict()
+         )
+
+     def preprocess(self, model_config, data_config):
+         train_dataset = configmapper.get_object("datasets", data_config.main.name)(
+             data_config.train, self.tokenizer
+         )
+         val_dataset = configmapper.get_object("datasets", data_config.main.name)(
+             data_config.val, self.tokenizer
+         )
+
+         model = configmapper.get_object("models", model_config.name)(
+             **model_config.params.as_dict()
+         )
+
+         return model, train_dataset, val_dataset
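All three preprocessors share one contract: the constructor takes the main Config and builds a tokenizer, and preprocess(model_config, data_config) returns a (model, train_dataset, val_dataset) triple. A hedged sketch of the calling side; the three Config objects are assumed to be loaded elsewhere, and their exact schema lives in the project's config files rather than in this diff:

from src.utils.mapper import configmapper

# main_config, model_config, data_config: already-loaded Config objects (assumed).
preprocessor = configmapper.get_object("preprocessors", "clozePreprocessor")(main_config)
model, train_dataset, val_dataset = preprocessor.preprocess(model_config, data_config)
# The datasets and model can then be handed to BaseTrainer.train (see src/trainers/base_trainer.py below).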
src/modules/schedulers.py ADDED
@@ -0,0 +1,14 @@
+ from torch.optim.lr_scheduler import (
+     StepLR,
+     CosineAnnealingLR,
+     ReduceLROnPlateau,
+     CyclicLR,
+     CosineAnnealingWarmRestarts,
+ )
+ from src.utils.mapper import configmapper
+
+ configmapper.map("schedulers", "step")(StepLR)
+ configmapper.map("schedulers", "cosineanneal")(CosineAnnealingLR)
+ configmapper.map("schedulers", "reduceplateau")(ReduceLROnPlateau)
+ configmapper.map("schedulers", "cyclic")(CyclicLR)
+ configmapper.map("schedulers", "cosineannealrestart")(CosineAnnealingWarmRestarts)
src/modules/tokenizers.py ADDED
@@ -0,0 +1,107 @@
+ """Contains tokenizers like GloveTokenizers and BERT Tokenizer."""
+
+ import torch
+ from torchtext.vocab import GloVe
+ from torchtext.data import Field, TabularDataset
+ from src.utils.mapper import configmapper
+ from transformers import AutoTokenizer
+
+
+ class Tokenizer:
+     """Abstract Class for Tokenizers."""
+
+     def tokenize(self):
+         """Abstract Method for tokenization."""
+
+
+ @configmapper.map("tokenizers", "glove")
+ class GloveTokenizer(Tokenizer):
+     """Implement GloveTokenizer for tokenizing text for Glove Embeddings.
+
+     Attributes:
+         embeddings (torchtext.vocab.Vectors): Loaded pre-trained embeddings.
+         text_field (torchtext.data.Field): Text field for vector creation.
+
+     Methods:
+         __init__(self, name='840B', dim='300', cache='../embeddings/'): Constructor method.
+         initialize_vectors(fix_length=4, tokenize='spacy', tokenizer_file_paths=None,
+             file_format='tsv', fields=None): Initialize vocab vectors based on data.
+         tokenize(x_input, **init_vector_params): Tokenize given input and return the output.
+     """
+
+     def __init__(self, name="840B", dim="300", cache="../embeddings/"):
+         """Construct GloveTokenizer.
+
+         Args:
+             name (str): Name of the GloVe embedding file
+             dim (str): Dimensions of the GloVe embedding file
+             cache (str): Path to the embeddings directory
+         """
+         super(GloveTokenizer, self).__init__()
+         self.embeddings = GloVe(name=name, dim=dim, cache=cache)
+         self.text_field = None
+
+     def initialize_vectors(
+         self,
+         fix_length=4,
+         tokenize="spacy",
+         tokenizer_file_paths=None,
+         file_format="tsv",
+         fields=None,
+     ):
+         """Initialize words/sequences based on GloVe embedding.
+
+         Args:
+             fields (list): The list containing the fields to be taken
+                 and processed from the file (see documentation for
+                 torchtext.data.TabularDataset)
+             fix_length (int): The length of the tokenized text,
+                 padding or cropping is done accordingly
+             tokenize (function or string): Method to tokenize the data.
+                 If 'spacy', uses the spacy tokenizer;
+                 else the specified method.
+             tokenizer_file_paths (list of str): The paths of the files containing the data
+             file_format (str): The format of the file: 'csv', 'tsv' or 'json'
+         """
+         text_field = Field(batch_first=True, fix_length=fix_length, tokenize=tokenize)
+         tab_dats = [
+             TabularDataset(
+                 i, format=file_format, fields={k: (k, text_field) for k in fields}
+             )
+             for i in tokenizer_file_paths
+         ]
+         text_field.build_vocab(*tab_dats)
+         text_field.vocab.load_vectors(self.embeddings)
+         self.text_field = text_field
+
+     def tokenize(self, x_input, **init_vector_params):
+         """Tokenize given input based on initialized vectors.
+
+         Initialize the vectors with given parameters if not already initialized.
+
+         Args:
+             x_input (str): Unprocessed input text to be tokenized
+             **init_vector_params (Keyword arguments): Parameters to initialize vectors
+
+         Returns:
+             x_output (str): Processed and tokenized text
+         """
+         if self.text_field is None:
+             self.initialize_vectors(**init_vector_params)
+         try:
+             x_output = torch.squeeze(
+                 self.text_field.process([self.text_field.preprocess(x_input)])
+             )
+         except Exception as e:
+             print(x_input)
+             print(self.text_field.preprocess(x_input))
+             print(e)
+         return x_output
+
+
+ @configmapper.map("tokenizers", "AutoTokenizer")
+ class AutoTokenizer(AutoTokenizer):
+     def __init__(self, *args):
+         super(AutoTokenizer, self).__init__()
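As the preprocessors earlier in this diff show, the registered AutoTokenizer is used through from_pretrained rather than direct construction. A small sketch of that path; the checkpoint name is illustrative, not taken from the project's configs:

from src.utils.mapper import configmapper
import src.modules.tokenizers  # run the registrations above

tokenizer_cls = configmapper.get_object("tokenizers", "AutoTokenizer")
tokenizer = tokenizer_cls.from_pretrained("google/electra-base-discriminator")  # illustrative checkpoint
enc = tokenizer("stay away from me", truncation=True)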
src/trainers/__init__.py ADDED
File without changes
src/trainers/base_trainer.py ADDED
@@ -0,0 +1,563 @@
+ import math
+ import os
+ import torch
+ from src.modules.optimizers import *
+ from src.modules.embeddings import *
+ from src.modules.schedulers import *
+ from src.modules.tokenizers import *
+ from src.modules.metrics import *
+ from src.modules.losses import *
+ from src.utils.misc import *
+ from src.utils.logger import Logger
+ from src.utils.mapper import configmapper
+ from src.utils.configuration import Config
+
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+
+
+ @configmapper.map("trainers", "base")
+ class BaseTrainer:
+     def __init__(self, config):
+         self._config = config
+         self.metrics = {
+             configmapper.get_object("metrics", metric["type"]): metric["params"]
+             for metric in self._config.main_config.metrics
+         }
+         self.train_config = self._config.train
+         self.val_config = self._config.val
+         self.log_label = self.train_config.log.log_label
+         self.val_log_together = False  # default when log/val intervals are configured separately
+         if self.train_config.log_and_val_interval is not None:
+             self.val_log_together = True
+         print("Logging with label: ", self.log_label)
+
+     def train(self, model, train_dataset, val_dataset=None, logger=None):
+         device = torch.device(self._config.main_config.device.name)
+         model.to(device)
+         optim_params = self.train_config.optimizer.params
+         if optim_params:
+             optimizer = configmapper.get_object(
+                 "optimizers", self.train_config.optimizer.type
+             )(model.parameters(), **optim_params.as_dict())
+         else:
+             optimizer = configmapper.get_object(
+                 "optimizers", self.train_config.optimizer.type
+             )(model.parameters())
+
+         if self.train_config.scheduler is not None:
+             scheduler_params = self.train_config.scheduler.params
+             if scheduler_params:
+                 scheduler = configmapper.get_object(
+                     "schedulers", self.train_config.scheduler.type
+                 )(optimizer, **scheduler_params.as_dict())
+             else:
+                 scheduler = configmapper.get_object(
+                     "schedulers", self.train_config.scheduler.type
+                 )(optimizer)
+
+         criterion_params = self.train_config.criterion.params
+         if criterion_params:
+             criterion = configmapper.get_object(
+                 "losses", self.train_config.criterion.type
+             )(**criterion_params.as_dict())
+         else:
+             criterion = configmapper.get_object(
+                 "losses", self.train_config.criterion.type
+             )()
+         if "custom_collate_fn" in dir(train_dataset):
+             train_loader = DataLoader(
+                 dataset=train_dataset,
+                 collate_fn=train_dataset.custom_collate_fn,
+                 **self.train_config.loader_params.as_dict(),
+             )
+         else:
+             train_loader = DataLoader(
+                 dataset=train_dataset, **self.train_config.loader_params.as_dict()
+             )
+         # train_logger = Logger(**self.train_config.log.logger_params.as_dict())
+
+         max_epochs = self.train_config.max_epochs
+         batch_size = self.train_config.loader_params.batch_size
+
+         if self.val_log_together:
+             val_interval = self.train_config.log_and_val_interval
+             log_interval = val_interval
+         else:
+             val_interval = self.train_config.val_interval
+             log_interval = self.train_config.log.log_interval
+
+         if logger is None:
+             train_logger = Logger(**self.train_config.log.logger_params.as_dict())
+         else:
+             train_logger = logger
+
+         train_log_values = self.train_config.log.values.as_dict()
+
+         best_score = (
+             -math.inf if self.train_config.save_on.desired == "max" else math.inf
+         )
+         save_on_score = self.train_config.save_on.score
+         best_step = -1
+         best_model = None
+
+         best_hparam_list = None
+         best_hparam_name_list = None
+         best_metrics_list = None
+         best_metrics_name_list = None
+
+         # print("\nTraining\n")
+         # print(max_steps)
+
+         global_step = 0
+         for epoch in range(1, max_epochs + 1):
+             print(
+                 "Epoch: {}/{}, Global Step: {}".format(epoch, max_epochs, global_step)
+             )
+             train_loss = 0
+             val_loss = 0
+
+             if self.train_config.label_type == "float":
+                 all_labels = torch.FloatTensor().to(device)
+             else:
+                 all_labels = torch.LongTensor().to(device)
+
+             all_outputs = torch.Tensor().to(device)
+
+             train_scores = None
+             val_scores = None
+
+             pbar = tqdm(total=math.ceil(len(train_dataset) / batch_size))
+             pbar.set_description("Epoch " + str(epoch))
+
+             val_counter = 0
+
+             for step, batch in enumerate(train_loader):
+                 model.train()
+                 optimizer.zero_grad()
+                 inputs, labels = batch
+
+                 if self.train_config.label_type == "float":  # specific to float labels
+                     labels = labels.float()
+
+                 for key in inputs:
+                     inputs[key] = inputs[key].to(device)
+                 labels = labels.to(device)
+                 outputs = model(inputs)
+                 loss = criterion(torch.squeeze(outputs), labels)
+                 loss.backward()
+
+                 all_labels = torch.cat((all_labels, labels), 0)
+
+                 if self.train_config.label_type == "float":
+                     all_outputs = torch.cat((all_outputs, outputs), 0)
+                 else:
+                     all_outputs = torch.cat((all_outputs, torch.argmax(outputs, dim=1)), 0)
+
+                 train_loss += loss.item()
+                 optimizer.step()
+
+                 if self.train_config.scheduler is not None:
+                     if isinstance(scheduler, ReduceLROnPlateau):
+                         scheduler.step(train_loss / (step + 1))
+                     else:
+                         scheduler.step()
+
+                 # print(train_loss)
+                 # print(step+1)
+
+                 pbar.set_postfix_str(f"Train Loss: {train_loss / (step + 1)}")
+                 pbar.update(1)
+
+                 global_step += 1
+
+                 # Need to check if we want global_step or local_step
+
+                 if val_dataset is not None and (global_step - 1) % val_interval == 0:
+                     # print("\nEvaluating\n")
+                     val_scores = self.val(
+                         model,
+                         val_dataset,
+                         criterion,
+                         device,
+                         global_step,
+                         train_logger,
+                         train_log_values,
+                     )
+
+                     # save_flag = 0
+                     if self.train_config.save_on is not None:
+
+                         ## BEST SCORES UPDATING
+
+                         train_scores = self.get_scores(
+                             train_loss,
+                             global_step,
+                             self.train_config.criterion.type,
+                             all_outputs,
+                             all_labels,
+                         )
+
+                         best_score, best_step, save_flag = self.check_best(
+                             val_scores, save_on_score, best_score, global_step
+                         )
+
+                         store_dict = {
+                             "model_state_dict": model.state_dict(),
+                             "best_step": best_step,
+                             "best_score": best_score,
+                             "save_on_score": save_on_score,
+                         }
+
+                         path = self.train_config.save_on.best_path.format(
+                             self.log_label
+                         )
+
+                         self.save(store_dict, path, save_flag)
+
+                         if save_flag and train_log_values["hparams"] is not None:
+                             (
+                                 best_hparam_list,
+                                 best_hparam_name_list,
+                                 best_metrics_list,
+                                 best_metrics_name_list,
+                             ) = self.update_hparams(
+                                 train_scores, val_scores, desc="best_val"
+                             )
+                 # pbar.close()
+                 if (global_step - 1) % log_interval == 0:
+                     # print("\nLogging\n")
+                     train_loss_name = self.train_config.criterion.type
+                     metric_list = [
+                         metric(all_labels.cpu(), all_outputs.detach().cpu(), **self.metrics[metric])
+                         for metric in self.metrics
+                     ]
+                     metric_name_list = [
+                         metric["type"] for metric in self._config.main_config.metrics
+                     ]
+
+                     train_scores = self.log(
+                         train_loss / (step + 1),
+                         train_loss_name,
+                         metric_list,
+                         metric_name_list,
+                         train_logger,
+                         train_log_values,
+                         global_step,
+                         append_text=self.train_config.append_text,
+                     )
+             pbar.close()
+             if not os.path.exists(self.train_config.checkpoint.checkpoint_dir):
+                 os.makedirs(self.train_config.checkpoint.checkpoint_dir)
+
+             if self.train_config.save_after_epoch:
+                 store_dict = {
+                     "model_state_dict": model.state_dict(),
+                 }
+
+                 path = f"{self.train_config.checkpoint.checkpoint_dir}_{str(self.train_config.log.log_label)}_{str(epoch)}.pth"
+
+                 self.save(store_dict, path, save_flag=1)
+
+             if epoch == max_epochs:
+                 # print("\nEvaluating\n")
+                 val_scores = self.val(
+                     model,
+                     val_dataset,
+                     criterion,
+                     device,
+                     global_step,
+                     train_logger,
+                     train_log_values,
+                 )
+
+                 # print("\nLogging\n")
+                 train_loss_name = self.train_config.criterion.type
+                 metric_list = [
+                     metric(all_labels.cpu(), all_outputs.detach().cpu(), **self.metrics[metric])
+                     for metric in self.metrics
+                 ]
+                 metric_name_list = [
+                     metric["type"] for metric in self._config.main_config.metrics
+                 ]
+
+                 train_scores = self.log(
+                     train_loss / len(train_loader),
+                     train_loss_name,
+                     metric_list,
+                     metric_name_list,
+                     train_logger,
+                     train_log_values,
+                     global_step,
+                     append_text=self.train_config.append_text,
+                 )
+
+                 if self.train_config.save_on is not None:
+
+                     ## BEST SCORES UPDATING
+
+                     train_scores = self.get_scores(
+                         train_loss,
+                         len(train_loader),
+                         self.train_config.criterion.type,
+                         all_outputs,
+                         all_labels,
+                     )
+
+                     best_score, best_step, save_flag = self.check_best(
+                         val_scores, save_on_score, best_score, global_step
+                     )
+
+                     store_dict = {
+                         "model_state_dict": model.state_dict(),
+                         "best_step": best_step,
+                         "best_score": best_score,
+                         "save_on_score": save_on_score,
+                     }
+
+                     path = self.train_config.save_on.best_path.format(self.log_label)
+
+                     self.save(store_dict, path, save_flag)
+
+                     if save_flag and train_log_values["hparams"] is not None:
+                         (
+                             best_hparam_list,
+                             best_hparam_name_list,
+                             best_metrics_list,
+                             best_metrics_name_list,
+                         ) = self.update_hparams(train_scores, val_scores, desc="best_val")
+
+                     ## FINAL SCORES UPDATING + STORING
+                     train_scores = self.get_scores(
+                         train_loss,
+                         len(train_loader),
+                         self.train_config.criterion.type,
+                         all_outputs,
+                         all_labels,
+                     )
+
+                     store_dict = {
+                         "model_state_dict": model.state_dict(),
+                         "final_step": global_step,
+                         "final_score": train_scores[save_on_score],
+                         "save_on_score": save_on_score,
+                     }
+
+                     path = self.train_config.save_on.final_path.format(self.log_label)
+
+                     self.save(store_dict, path, save_flag=1)
+                     if train_log_values["hparams"] is not None:
+                         (
+                             final_hparam_list,
+                             final_hparam_name_list,
+                             final_metrics_list,
+                             final_metrics_name_list,
+                         ) = self.update_hparams(train_scores, val_scores, desc="final")
+                         train_logger.save_hyperparams(
+                             best_hparam_list,
+                             best_hparam_name_list,
+                             [int(self.log_label)] + best_metrics_list + final_metrics_list,
+                             ["hparams/log_label"]
+                             + best_metrics_name_list
+                             + final_metrics_name_list,
+                         )
+
+     ## Need to check if we want same loggers or different loggers for train and eval
+     ## Evaluate
+
+     def get_scores(self, loss, divisor, loss_name, all_outputs, all_labels):
+         avg_loss = loss / divisor
+
+         metric_list = [
+             metric(all_labels.cpu(), all_outputs.detach().cpu(), **self.metrics[metric])
+             for metric in self.metrics
+         ]
+         metric_name_list = [
+             metric["type"] for metric in self._config.main_config.metrics
+         ]
+
+         return dict(zip([loss_name] + metric_name_list, [avg_loss] + metric_list))
+
+     def check_best(self, val_scores, save_on_score, best_score, global_step):
+         save_flag = 0
+         best_step = global_step
+         if self.train_config.save_on.desired == "min":
+             if val_scores[save_on_score] < best_score:
+                 save_flag = 1
+                 best_score = val_scores[save_on_score]
+                 best_step = global_step
+         else:
+             if val_scores[save_on_score] > best_score:
+                 save_flag = 1
+                 best_score = val_scores[save_on_score]
+                 best_step = global_step
+         return best_score, best_step, save_flag
+
+     def update_hparams(self, train_scores, val_scores, desc):
+         hparam_list = []
+         hparam_name_list = []
+         for hparam in self.train_config.log.values.hparams:
+             hparam_list.append(get_item_in_config(self._config, hparam["path"]))
+             if isinstance(hparam_list[-1], Config):
+                 hparam_list[-1] = hparam_list[-1].as_dict()
+             hparam_name_list.append(hparam["name"])
+
+         val_keys, val_values = zip(*val_scores.items())
+         train_keys, train_values = zip(*train_scores.items())
+         val_keys = list(val_keys)
+         train_keys = list(train_keys)
+         val_values = list(val_values)
+         train_values = list(train_values)
+         for i, key in enumerate(val_keys):
+             val_keys[i] = f"hparams/{desc}_val_" + val_keys[i]
+         for i, key in enumerate(train_keys):
+             train_keys[i] = f"hparams/{desc}_train_" + train_keys[i]
+         # train_logger.save_hyperparams(hparam_list, hparam_name_list, train_values + val_values, train_keys + val_keys)
+         return (
+             hparam_list,
+             hparam_name_list,
+             train_values + val_values,
+             train_keys + val_keys,
+         )
+
+     def save(self, store_dict, path, save_flag=0):
+         if save_flag:
+             dirs = "/".join(path.split("/")[:-1])
+             if not os.path.exists(dirs):
+                 os.makedirs(dirs)
+             torch.save(store_dict, path)
+
+     def log(
+         self,
+         loss,
+         loss_name,
+         metric_list,
+         metric_name_list,
+         logger,
+         log_values,
+         global_step,
+         append_text,
+     ):
+         return_dic = dict(zip([loss_name] + metric_name_list, [loss] + metric_list))
+
+         loss_name = f"{append_text}_{self.log_label}_{loss_name}"
+         if log_values["loss"]:
+             logger.save_params(
+                 [loss],
+                 [loss_name],
+                 combine=True,
+                 combine_name="losses",
+                 global_step=global_step,
+             )
+
+         for i in range(len(metric_name_list)):
+             metric_name_list[i] = f"{append_text}_{self.log_label}_{metric_name_list[i]}"
+         if log_values["metrics"]:
+             logger.save_params(
+                 metric_list,
+                 metric_name_list,
+                 combine=True,
+                 combine_name="metrics",
+                 global_step=global_step,
+             )
+         # print(hparams_list)
+         # print(hparam_name_list)
+
+         # for k, v in dict(zip([loss_name], [loss])).items():
+         #     print(f"{k}:{v}")
+         # for k, v in dict(zip(metric_name_list, metric_list)).items():
+         #     print(f"{k}:{v}")
+         return return_dic
+
+     def val(
+         self,
+         model,
+         dataset,
+         criterion,
+         device,
+         global_step,
+         train_logger=None,
+         train_log_values=None,
+         log=True,
+     ):
+         append_text = self.val_config.append_text
+         if train_logger is not None:
+             val_logger = train_logger
+         else:
+             val_logger = Logger(**self.val_config.log.logger_params.as_dict())
+
+         if train_log_values is not None:
+             val_log_values = train_log_values
+         else:
+             val_log_values = self.val_config.log.values.as_dict()
+         if "custom_collate_fn" in dir(dataset):
+             val_loader = DataLoader(
+                 dataset=dataset,
+                 collate_fn=dataset.custom_collate_fn,
+                 **self.val_config.loader_params.as_dict(),
+             )
+         else:
+             val_loader = DataLoader(
+                 dataset=dataset, **self.val_config.loader_params.as_dict()
+             )
+
+         all_outputs = torch.Tensor().to(device)
+         if self.train_config.label_type == "float":
+             all_labels = torch.FloatTensor().to(device)
+         else:
+             all_labels = torch.LongTensor().to(device)
+
+         batch_size = self.val_config.loader_params.batch_size
+
+         with torch.no_grad():
+             model.eval()
+             val_loss = 0
+             for j, batch in enumerate(val_loader):
+                 inputs, labels = batch
+
+                 if self.train_config.label_type == "float":
+                     labels = labels.float()
+
+                 for key in inputs:
+                     inputs[key] = inputs[key].to(device)
+                 labels = labels.to(device)
+
+                 outputs = model(inputs)
+                 loss = criterion(torch.squeeze(outputs), labels)
+                 val_loss += loss.item()
+
+                 all_labels = torch.cat((all_labels, labels), 0)
+
+                 if self.train_config.label_type == "float":
+                     all_outputs = torch.cat((all_outputs, outputs), 0)
+                 else:
+                     all_outputs = torch.cat((all_outputs, torch.argmax(outputs, dim=1)), 0)
+
+             val_loss = val_loss / len(val_loader)
+
+             val_loss_name = self.train_config.criterion.type
+
+             # print(all_outputs, all_labels)
+             metric_list = [
+                 metric(all_labels.cpu(), all_outputs.detach().cpu(), **self.metrics[metric])
+                 for metric in self.metrics
+             ]
+             metric_name_list = [
+                 metric["type"] for metric in self._config.main_config.metrics
+             ]
+             return_dic = dict(
+                 zip([val_loss_name] + metric_name_list, [val_loss] + metric_list)
+             )
+             if log:
+                 val_scores = self.log(
+                     val_loss,
+                     val_loss_name,
+                     metric_list,
+                     metric_name_list,
+                     val_logger,
+                     val_log_values,
+                     global_step,
+                     append_text,
+                 )
+                 return val_scores
+             return return_dic
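Putting the trainer together: it is resolved from the same registry and driven entirely by the config tree. A hedged sketch of the top-level wiring; the config object and the (model, datasets) triple produced by a preprocessor are assumed, not shown in this diff:

from src.utils.mapper import configmapper

# `trainer_config` is an assumed, already-loaded Config with the .train / .val /
# .main_config sections BaseTrainer reads; `model`, `train_dataset`, `val_dataset`
# come from a preprocessor's preprocess() call (see src/modules/preprocessors.py).
trainer = configmapper.get_object("trainers", "base")(trainer_config)
trainer.train(model, train_dataset, val_dataset)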
src/utils/__init__.py ADDED
File without changes
src/utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (164 Bytes). View file