kevin110211 committed on
Commit 5d58b52 · 1 Parent(s): bc0481a

Upload 51 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete set.
Files changed (50)
  1. KTeleBERT/__pycache__/config.cpython-38.pyc +0 -0
  2. KTeleBERT/config.py +234 -0
  3. KTeleBERT/data_trans.py +56 -0
  4. KTeleBERT/get_chinese_ref.py +454 -0
  5. KTeleBERT/main.py +851 -0
  6. KTeleBERT/model/HWBert.py +146 -0
  7. KTeleBERT/model/KE_model.py +451 -0
  8. KTeleBERT/model/Numeric.py +218 -0
  9. KTeleBERT/model/OD_model.py +74 -0
  10. KTeleBERT/model/Tool_model.py +34 -0
  11. KTeleBERT/model/__init__.py +26 -0
  12. KTeleBERT/model/__pycache__/HWBert.cpython-38.pyc +0 -0
  13. KTeleBERT/model/__pycache__/KE_model.cpython-38.pyc +0 -0
  14. KTeleBERT/model/__pycache__/Numeric.cpython-38.pyc +0 -0
  15. KTeleBERT/model/__pycache__/OD_model.cpython-38.pyc +0 -0
  16. KTeleBERT/model/__pycache__/Tool_model.cpython-38.pyc +0 -0
  17. KTeleBERT/model/__pycache__/__init__.cpython-38.pyc +0 -0
  18. KTeleBERT/model/bert/__init__.py +201 -0
  19. KTeleBERT/model/bert/__pycache__/__init__.cpython-38.pyc +0 -0
  20. KTeleBERT/model/bert/__pycache__/configuration_bert.cpython-38.pyc +0 -0
  21. KTeleBERT/model/bert/__pycache__/modeling_bert.cpython-38.pyc +0 -0
  22. KTeleBERT/model/bert/__pycache__/tokenization_bert.cpython-38.pyc +0 -0
  23. KTeleBERT/model/bert/configuration_bert.py +191 -0
  24. KTeleBERT/model/bert/modeling_bert.py +2010 -0
  25. KTeleBERT/model/bert/tokenization_bert.py +574 -0
  26. KTeleBERT/requirements.txt +10 -0
  27. KTeleBERT/run.sh +35 -0
  28. KTeleBERT/run_get_ref.sh +22 -0
  29. KTeleBERT/special_token_pre_emb.py +119 -0
  30. KTeleBERT/src/__init__.py +1 -0
  31. KTeleBERT/src/__pycache__/__init__.cpython-38.pyc +0 -0
  32. KTeleBERT/src/__pycache__/data.cpython-38.pyc +0 -0
  33. KTeleBERT/src/__pycache__/distributed_utils.cpython-38.pyc +0 -0
  34. KTeleBERT/src/__pycache__/utils.cpython-38.pyc +0 -0
  35. KTeleBERT/src/data.py +651 -0
  36. KTeleBERT/src/distributed_utils.py +79 -0
  37. KTeleBERT/src/utils.py +374 -0
  38. KTeleBERT/test.sh +12 -0
  39. KTeleBERT/torchlight/__init__.py +20 -0
  40. KTeleBERT/torchlight/__pycache__/__init__.cpython-38.pyc +0 -0
  41. KTeleBERT/torchlight/__pycache__/logger.cpython-38.pyc +0 -0
  42. KTeleBERT/torchlight/__pycache__/metric.cpython-38.pyc +0 -0
  43. KTeleBERT/torchlight/__pycache__/module.cpython-38.pyc +0 -0
  44. KTeleBERT/torchlight/__pycache__/utils.cpython-38.pyc +0 -0
  45. KTeleBERT/torchlight/__pycache__/vocab.cpython-38.pyc +0 -0
  46. KTeleBERT/torchlight/logger.py +147 -0
  47. KTeleBERT/torchlight/metric.py +121 -0
  48. KTeleBERT/torchlight/module.py +133 -0
  49. KTeleBERT/torchlight/utils.py +195 -0
  50. KTeleBERT/torchlight/vocab.py +137 -0
KTeleBERT/__pycache__/config.cpython-38.pyc ADDED
Binary file (7.01 kB).
 
KTeleBERT/config.py ADDED
@@ -0,0 +1,234 @@
import os.path as osp
import numpy as np
import random
import torch
from easydict import EasyDict as edict
import argparse


LAYER_MAPPING = {
    0: 'od_layer_0',
    1: 'od_layer_1',
    2: 'od_layer_2',
}


class cfg():
    def __init__(self):
        self.this_dir = osp.dirname(__file__)
        # change
        self.data_root = osp.abspath(osp.join(self.this_dir, '..', '..', 'data', ''))

        # TODO: add some static variables (their frequency of change is low)

    def get_args(self):
        parser = argparse.ArgumentParser()
        # ------------ base ------------
        parser.add_argument('--train_strategy', default=1, type=int)
        parser.add_argument('--batch_size', default=64, type=int)
        parser.add_argument('--batch_size_ke', default=14, type=int)
        parser.add_argument('--batch_size_od', default=8, type=int)
        parser.add_argument('--batch_size_ad', default=32, type=int)

        parser.add_argument('--epoch', default=15, type=int)
        parser.add_argument("--save_model", default=1, type=int, choices=[0, 1])
        # save using the transformers save_pretrained format
        parser.add_argument("--save_pretrain", default=0, type=int, choices=[0, 1])
        parser.add_argument("--from_pretrain", default=0, type=int, choices=[0, 1])

        # torchlight
        parser.add_argument("--no_tensorboard", default=False, action="store_true")
        parser.add_argument("--exp_name", default="huawei_exp", type=str, help="Experiment name")
        parser.add_argument("--dump_path", default="dump/", type=str, help="Experiment dump path")
        parser.add_argument("--exp_id", default="ke256_raekt_ernie2_bs20_p3_c3_5e-6", type=str, help="Experiment ID")
        # or 3407
        parser.add_argument("--random_seed", default=42, type=int)
        # data parameters
        parser.add_argument("--data_path", default="huawei", type=str, help="Experiment path")
        parser.add_argument('--train_ratio', default=1, type=float, help='ratio for train/test')
        parser.add_argument("--seq_data_name", default='Seq_data_base', type=str, help="seq_data file name")
        parser.add_argument("--kg_data_name", default='KG_data_base_rule', type=str, help="kg_data file name")
        parser.add_argument("--order_data_name", default='event_order_data', type=str, help="order_data file name")
        # TODO: add some dynamic variables
        parser.add_argument("--model_name", default="MacBert", type=str, help="model name")

        # ------------ training stage ------------
        parser.add_argument("--scheduler", default="cos", type=str, choices=["linear", "cos"])
        parser.add_argument("--optim", default="adamw", type=str)
        parser.add_argument("--adam_epsilon", default=1e-8, type=float)
        parser.add_argument('--workers', type=int, default=8)
        parser.add_argument('--accumulation_steps', type=int, default=6)
        parser.add_argument('--accumulation_steps_ke', type=int, default=6)
        parser.add_argument('--accumulation_steps_ad', type=int, default=6)
        parser.add_argument('--accumulation_steps_od', type=int, default=6)
        parser.add_argument("--train_together", default=0, type=int)

        # 3e-5
        parser.add_argument('--lr', type=float, default=1e-5)
        # layer-wise learning-rate decay
        parser.add_argument("--LLRD", default=0, type=int, choices=[0, 1])
        parser.add_argument('--weight_decay', type=float, default=0.01)
        parser.add_argument('--clip', type=float, default=1., help='gradient clipping')
        parser.add_argument('--scheduler_steps', type=int, default=None,
                            help='total number of step for the scheduler, if None then scheduler_total_step = total_step')
        parser.add_argument('--eval_step', default=100, type=int, help='evaluate each n step')

        # ------------ PLM ------------
        parser.add_argument('--maxlength', type=int, default=200)
        parser.add_argument('--mlm_probability', type=float, default=0.15)
        parser.add_argument('--final_mlm_probability', type=float, default=0.4)
        parser.add_argument('--mlm_probability_increase', type=str, default="curve", choices=["linear", "curve"])
        parser.add_argument("--mask_stratege", default="rand", type=str, choices=["rand", "wwm", "domain"])
        # use "rand" for the first n epochs, then "wwm": multi-stage knowledge masking strategy
        parser.add_argument("--ernie_stratege", default=-1, type=int)
        # train with the MLM task; by default this uses chinese_ref and adds the new special words
        parser.add_argument("--use_mlm_task", default=1, type=int, choices=[0, 1])
        # add the new special words
        parser.add_argument("--add_special_word", default=1, type=int, choices=[0, 1])
        # freeze
        parser.add_argument("--freeze_layer", default=0, type=int, choices=[0, 1, 2, 3, 4])
        # whether to mask special tokens
        parser.add_argument("--special_token_mask", default=0, type=int, choices=[0, 1])
        parser.add_argument("--emb_init", default=1, type=int, choices=[0, 1])
        parser.add_argument("--cls_head_init", default=1, type=int, choices=[0, 1])
        # whether to use adaptive loss weighting
        parser.add_argument("--use_awl", default=1, type=int, choices=[0, 1])
        parser.add_argument("--mask_loss_scale", default=1.0, type=float)

        # ------------ KGE ------------
        parser.add_argument('--ke_norm', type=int, default=1)
        parser.add_argument('--ke_dim', type=int, default=768)
        parser.add_argument('--ke_margin', type=float, default=1.0)
        parser.add_argument('--neg_num', type=int, default=10)
        parser.add_argument('--adv_temp', type=float, default=1.0, help='The temperature of sampling in self-adversarial negative sampling.')
        # 5e-4
        parser.add_argument('--ke_lr', type=float, default=3e-5)
        parser.add_argument('--only_ke_loss', type=int, default=0)

        # ------------ numeric embedding ------------
        parser.add_argument('--use_NumEmb', type=int, default=1)
        parser.add_argument("--contrastive_loss", default=1, type=int, choices=[0, 1])
        parser.add_argument("--l_layers", default=2, type=int)
        parser.add_argument('--use_kpi_loss', type=int, default=1)

        # ------------ test stage ------------
        parser.add_argument("--only_test", default=0, type=int, choices=[0, 1])
        parser.add_argument("--mask_test", default=0, type=int, choices=[0, 1])
        parser.add_argument("--embed_gen", default=0, type=int, choices=[0, 1])
        parser.add_argument("--ke_test", default=0, type=int, choices=[0, 1])
        # -1: test on the full set
        parser.add_argument("--ke_test_num", default=-1, type=int)
        parser.add_argument("--path_gen", default="", type=str)

        # ------------ temporal (event order) stage ------------
        # 1: pre-training
        # 2: temporal fine-tuning
        # 3: anomaly-detection fine-tuning + temporal, run iteratively
        # whether to load the OD model
        parser.add_argument("--order_load", default=0, type=int)
        parser.add_argument("--order_num", default=2, type=int)
        parser.add_argument("--od_type", default='linear_cat', type=str, choices=['linear_cat', 'vertical_attention'])
        parser.add_argument("--eps", default=0.2, type=float, help='label smoothing..')
        parser.add_argument("--num_od_layer", default=0, type=int)
        parser.add_argument("--plm_emb_type", default='cls', type=str, choices=['cls', 'last_avg'])
        parser.add_argument("--order_test_name", default='', type=str)
        parser.add_argument("--order_threshold", default=0.5, type=float)
        # ------------ distributed training ------------
        # whether to train in parallel
        parser.add_argument('--rank', type=int, default=0, help='rank to dist')
        parser.add_argument('--dist', type=int, default=0, help='whether to dist')
        # do not change this parameter; it is assigned automatically
        parser.add_argument('--device', default='cuda', help='device id (i.e. 0 or 0,1 or cpu)')
        # number of processes to launch (not threads); no need to set it, it is derived from nproc_per_node
        parser.add_argument('--world-size', default=4, type=int,
                            help='number of distributed processes')
        parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
        parser.add_argument("--local_rank", default=-1, type=int)
        self.cfg = parser.parse_args()

    def update_train_configs(self):
        # add some constraints for parameters
        # e.g. cannot save and test at the same time
        # fix up default parameters
        # TODO: the test logic is flawed and needs revision
        if len(self.cfg.order_test_name) > 0:
            self.cfg.save_model = 0
        if len(self.cfg.order_test_name) == 0:
            self.cfg.train_ratio = min(0.8, self.cfg.train_ratio)
        # adaptively derive the file name to load
        else:
            print("od test ... ")
            self.cfg.train_strategy = 5  # was a no-op comparison (== 5); presumably intended as an assignment
            self.cfg.plm_emb_type = 'last_avg' if 'last_avg' in self.cfg.model_name else 'cls'
            for key in LAYER_MAPPING.keys():
                if LAYER_MAPPING[key] in self.cfg.model_name:
                    self.cfg.num_od_layer = key
            self.cfg.order_test_name = osp.join('downstream_task', f'{self.cfg.order_test_name}')

        if self.cfg.mask_test or self.cfg.embed_gen or self.cfg.ke_test or len(self.cfg.order_test_name) > 0:
            assert len(self.cfg.model_name) > 0
            self.cfg.only_test = 1
        if self.cfg.only_test == 1:
            # was `self.save_model` / `self.save_pretrain`, which only set attributes on the wrapper object
            self.cfg.save_model = 0
            self.cfg.save_pretrain = 0

        # TODO: update some dynamic variables
        self.cfg.data_root = self.data_root
        self.cfg.data_path = osp.join(self.data_root, self.cfg.data_path)
        self.cfg.plm_path = osp.join(self.data_root, 'transformer')
        self.cfg.dump_path = osp.join(self.cfg.data_path, self.cfg.dump_path)
        # try to keep the effective batch size around 32

        # number of adaptive loss weights
        self.cfg.awl_num = 1
        # ------------ numeric embedding ------------
        self.cfg.hidden_size = 768
        self.cfg.num_attention_heads = 8
        self.cfg.hidden_dropout_prob = 0.1
        self.cfg.num_kpi = 304
        self.cfg.specail_emb_path = None
        if self.cfg.emb_init:
            self.cfg.specail_emb_path = osp.join(self.cfg.data_path, 'added_vocab_embedding.pt')

        # ------------- multi-task learning -------------
        # four stages
        self.cfg.mask_epoch, self.cfg.ke_epoch, self.cfg.ad_epoch, self.cfg.od_epoch = None, None, None, None
        # trigger multi-task learning
        if self.cfg.train_strategy > 1:
            self.cfg.mask_epoch = [0, 1, 1, 1, 0]
            self.cfg.ke_epoch = [4, 3, 2, 2, 0]
            if self.cfg.only_ke_loss:
                self.cfg.mask_epoch = [0, 0, 0, 0, 0]
            self.cfg.epoch = sum(self.cfg.mask_epoch) + sum(self.cfg.ke_epoch)
        if self.cfg.train_strategy > 2:
            self.cfg.ad_epoch = [0, 6, 3, 1, 0]
            self.cfg.epoch += sum(self.cfg.ad_epoch)
        if self.cfg.train_strategy > 3 and not self.cfg.only_ke_loss:
            self.cfg.od_epoch = [0, 0, 9, 1, 0]
            # self.cfg.mask_epoch[3] = 1
            self.cfg.epoch += sum(self.cfg.od_epoch)
        self.cfg.epoch_matrix = []
        for epochs in [self.cfg.mask_epoch, self.cfg.ke_epoch, self.cfg.ad_epoch, self.cfg.od_epoch]:
            if epochs is not None:
                self.cfg.epoch_matrix.append(epochs)
        if self.cfg.train_together:
            # losses are summed directly, so the training epochs are the mask epochs
            self.cfg.epoch = sum(self.cfg.mask_epoch)
            self.cfg.batch_size = int((self.cfg.batch_size - 16) / self.cfg.train_strategy)
            self.cfg.batch_size_ke = int(self.cfg.batch_size_ke / self.cfg.train_strategy) - 2
            self.cfg.batch_size_ad = int(self.cfg.batch_size_ad / self.cfg.train_strategy) - 1
            self.cfg.batch_size_od = int(self.cfg.batch_size_od / self.cfg.train_strategy) - 1
            self.cfg.accumulation_steps = (self.cfg.accumulation_steps - 1) * self.cfg.train_strategy

        self.cfg.neg_num = max(min(self.cfg.neg_num, self.cfg.batch_size_ke - 3), 1)

        self.cfg.accumulation_steps_dict = {0: self.cfg.accumulation_steps, 1: self.cfg.accumulation_steps_ke, 2: self.cfg.accumulation_steps_ad, 3: self.cfg.accumulation_steps_od}

        # using numeric embeddings also requires the new special words, because the position info is bound to the tokenizer
        if self.cfg.use_mlm_task or self.cfg.use_NumEmb:
            assert self.cfg.add_special_word == 1

        if self.cfg.use_NumEmb:
            self.cfg.awl_num += 1

        return self.cfg
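
A minimal usage sketch of the configuration class above, following the same cfg() → get_args() → update_train_configs() call pattern used by main.py and get_chinese_ref.py; the variable name and the print lines are illustrative only:

# sketch: driving config.cfg the way the training scripts do
from config import cfg

config_wrapper = cfg()
config_wrapper.get_args()                     # parses the argparse flags defined above (reads sys.argv)
args = config_wrapper.update_train_configs()  # derives data_path, dump_path, epoch_matrix, batch sizes, ...

print(args.data_path)      # <data_root>/huawei with the defaults
print(args.epoch_matrix)   # only populated when --train_strategy > 1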
KTeleBERT/data_trans.py ADDED
@@ -0,0 +1,56 @@
import os.path as osp
import numpy as np
import random
import torch
import argparse
import pdb
import json

'''
Merge the data sources,
and at the same time extract the subset that is needed.
'''

this_dir = osp.dirname(__file__)

data_root = osp.abspath(osp.join(this_dir, '..', '..', 'data', ''))

data_path = "huawei"
data_path = osp.join(data_root, data_path)


with open(osp.join(data_path, 'product_corpus.json'), "r") as f:
    data_doc = json.load(f)

with open(osp.join(data_path, '831_alarm_serialize.json'), "r") as f:
    data_alarm = json.load(f)
# kpi_info.json
with open(osp.join(data_path, '917_kpi_serialize_50_mn.json'), "r") as f:
    data_kpi = json.load(f)


# serialized entities
with open(osp.join(data_path, '5GC_KB/database_entity_serialize.json'), "r") as f:
    data_entity = json.load(f)

random.shuffle(data_kpi)
random.shuffle(data_doc)
random.shuffle(data_alarm)
random.shuffle(data_entity)
data = data_alarm + data_kpi + data_entity + data_doc
random.shuffle(data)

# 241527 merged sequences; breakpoint left in for manual inspection
pdb.set_trace()
with open(osp.join(data_path, 'Seq_data_large.json'), "w") as fp:
    json.dump(data, fp, ensure_ascii=False)


# triples
with open(osp.join(data_path, '5GC_KB/database_triples.json'), "r") as f:
    data = json.load(f)
random.shuffle(data)


with open(osp.join(data_path, 'KG_data_base.json'), "w") as fp:
    json.dump(data, fp, ensure_ascii=False)
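
After data_trans.py has been run, the two files it writes can be sanity-checked with a few lines; a sketch assuming the same data_root layout as the script above (the counts printed are whatever the merge produced, e.g. the 241527 noted in the comment):

# sketch: inspect the outputs written by data_trans.py
import json
import os.path as osp

this_dir = osp.dirname(__file__)
data_path = osp.join(osp.abspath(osp.join(this_dir, '..', '..', 'data', '')), 'huawei')

with open(osp.join(data_path, 'Seq_data_large.json')) as f:
    seq = json.load(f)        # merged, shuffled list of serialized alarm / KPI / entity / document texts
with open(osp.join(data_path, 'KG_data_base.json')) as f:
    triples = json.load(f)    # shuffled (head, relation, tail) triples from the 5GC knowledge base

print(len(seq), len(triples))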
KTeleBERT/get_chinese_ref.py ADDED
@@ -0,0 +1,454 @@
import os.path as osp
import numpy as np
import random
import torch
from easydict import EasyDict as edict
import argparse
import pdb
import json
from model import BertTokenizer
from collections import Counter
from ltp import LTP
from tqdm import tqdm
from src.utils import add_special_token
from functools import reduce
from time import time
from numpy import mean
import math

from src.utils import Loss_log, time_trans
from collections import defaultdict


class cfg():
    def __init__(self):
        self.this_dir = osp.dirname(__file__)
        # change
        self.data_root = osp.abspath(osp.join(self.this_dir, '..', '..', 'data', ''))

    def get_args(self):
        parser = argparse.ArgumentParser()
        # seq_data_name = "Seq_data_tiny_831"
        parser.add_argument("--data_path", default="huawei", type=str, help="Experiment path")
        # TODO: a freq of 150 could also be considered
        parser.add_argument("--freq", default=50, type=int, help="minimum frequency for a word to be considered important")
        parser.add_argument("--batch_size", default=100, type=int, help="batch size for word segmentation")
        parser.add_argument("--seq_data_name", default='Seq_data_large', type=str, help="seq_data file name")
        parser.add_argument("--deal_numeric", default=0, type=int, help="whether to process numeric data")

        parser.add_argument("--read_cws", default=0, type=int, help="whether to read an already-segmented (cws) file")
        self.cfg = parser.parse_args()

    def update_train_configs(self):
        # TODO: update some dynamic variables
        self.cfg.data_root = self.data_root
        self.cfg.data_path = osp.join(self.data_root, self.cfg.data_path)

        return self.cfg


def refresh_data(ref, freq, special_token):
    '''
    Purpose: on top of the user-defined special tokens, use a minimum frequency threshold to collect
             additional new words as references for the word segmenter; these serve as the basis for
             whole word masking (wwm).
    Input:
        freq: minimum frequency in the (~370k-entry) semantic vocabulary (whitespace-delimited)
        special_token: the manually defined special tokens (there may be overlap)
    Output:
        add_words: the new words selected by the minimum frequency threshold
    '''
    # frequently occurring sub tokens
    seq_sub_data = [line.split() for line in ref]
    all_data = []
    for data in seq_sub_data:
        all_data.extend(data)
    sub_word_times = dict(Counter(all_data))
    asub_word_time_order = sorted(sub_word_times.items(), key=lambda x: x[1], reverse=True)
    # ('LST', 1218), ('RMV', 851), ('DSP', 821), ('ADD', 820), ('MOD', 590), ('SET', 406), ('AWS', 122)
    # ADD、ACT、ALM-XXX、DEL、DSP、LST
    add_words = []

    for i in asub_word_time_order:
        # keep the words that occur frequently enough
        if i[1] >= freq and len(i[0]) > 1 and len(i[0]) < 20 and not str.isdigit(i[0]):
            add_words.append(i[0])
    add_words.extend(special_token)
    # with a threshold of 100 this yields 935 special tokens
    print(f"[{len(add_words)}] special words will be added with frequency [{freq}]!")
    return add_words


def cws(seq_data, add_words, batch_size):
    '''
    Purpose: convert all sequence inputs into their word-segmented form.
    Input:
        seq_data: all sequence inputs, e.g. ['KPI异常下降', 'KPI异常上升']
        add_words: the added special words
        batch_size: how many sentences to segment per batch
    Output:
        all_segment: segmented output for all sequences, e.g. [['KPI', '异常', '下降'], ['KPI', '异常', '上升']]
        data_size: the number of input/output sequences (e.g. 2)
    '''
    # seq_data = seq_data.cuda()
    print(f"loading...")
    ltp = LTP("LTP/base2")  # load the base2 model by default
    # ltp = LTP()
    print(f"begin adding words ...")
    # ltp.add_words(words=add_words, max_window=5)  # 4.1.5
    ltp.add_words(words=add_words)  # 4.2.8
    ltp.to("cuda")
    # for word in add_words:
    #     ltp.add_word(word)
    print(f"{len(add_words)} special words are added!")

    #
    # for data in seq_data:
    #     output = ltp.pipeline([data], tasks=["cws"])
    data_size = len(seq_data)
    seq_data_cws = []
    size = int(data_size / batch_size) + 1
    b = 0
    e = b + batch_size
    # pdb.set_trace()

    log = Loss_log()

    with tqdm(total=size) as _tqdm:
        # pdb.set_trace()
        # log.time_init()
        # pdb.set_trace()
        error_data = []
        for i in range(size):

            output = []
            try:
                _output = ltp.pipeline(seq_data[b:e], tasks=["cws"])
                for data in _output.cws:
                    try:
                        data_out = ltp.pipeline(data, tasks=["cws"])
                        # data_out_ = reduce(lambda x, y: x.extend(y) or x, data_out.cws)
                        data_out_ = []
                        for seg in data_out.cws:  # renamed from `i`, which shadowed the outer loop variable
                            data_out_.extend([k.strip() for k in seg])
                        output.append(data_out_)
                    except Exception:
                        print(f"second-stage segmentation failed! range: [{b}]-[{e}]")
                        error_data.append(data)

                # pdb.set_trace()
            except Exception:
                print(f"first-stage segmentation failed! range: [{b}]-[{e}]")
                error_data.append(f"first-stage segmentation failed! range: [{b}]-[{e}]")
                # continue
            seq_data_cws.extend(output)
            b = e
            e += batch_size

            # time bookkeeping
            if e >= data_size:
                if b >= data_size:
                    break
                e = data_size
            _tqdm.set_description(f'from {b} to {e}:')
            _tqdm.update(1)

    print(f"filtered out {data_size - len(seq_data_cws)} sentences")

    return seq_data_cws, data_size, error_data


def ltp_debug(ltp, op):
    output = []
    for data in op:
        data_out = ltp.pipeline(data, tasks=["cws"])
        # data_out_ = reduce(lambda x, y: x.extend(y) or x, data_out.cws)
        data_out_ = []
        for i in data_out.cws:
            # if whitespace is kept, it has to be stripped manually
            data_out_.append(i[0].strip())
            # previously there was no whitespace
            # data_out_.extend(i)
        output.append(data_out_)
    return output


def deal_sub_words(subwords, special_token):
    '''
    Purpose: within each whole word, prefix every non-initial sub-token with '##'; special_token entries should not be masked.
    '''
    for i in range(len(subwords)):
        if i == 0:
            continue
        if subwords[i] in special_token:
            continue
        if subwords[i].startswith("##"):
            continue

        subwords[i] = "##" + subwords[i]
    return subwords


def generate_chinese_ref(seq_data_cws, special_token, deal_numeric, kpi_dic):
    '''
    Input:
        seq_data_cws: segmented output for all sequences, e.g. [['KPI', '异常', '下降'], ['KPI', '异常', '上升']]
        special_token: tokens that should not be masked ['[SEP]', '[MASK]', '[ALM]', '[KPI]', '[CLS]', '[LOC]', '[EOS]', '[ENT]', '[ATTR]', '[NUM]', '|']
        data_size: number of samples, e.g. 2
    Output:
        ww_return (whole word return): the labelled chinese ref, e.g. [['KPI', '异', '##常', '下', '##降'], ['KPI', '异', '##常', '上', '##升']]
    '''
    # global set and reverse dictionary used to track which KPIs end up uncovered
    data_size = len(seq_data_cws)
    kpi_static_set = set()
    rev_kpi_dic = dict(zip(kpi_dic.values(), kpi_dic.keys()))
    max_len = 0
    sten_that_over_maxl = []
    with tqdm(total=data_size) as _tqdm:
        ww_return = []
        ww_list = []
        kpi_info = []
        not_in_KPI = defaultdict(int)
        for i in range(data_size):
            _tqdm.set_description(f'checking...[{i}/{data_size}] max len: [120,972]')
            orig = tokenizer.tokenize(" ".join(seq_data_cws[i]))

            if deal_numeric:
                # get the tuple info; the first two entries are the KPI index range
                _kpi_info, kpi_type_list = extract_kpi(orig, kpi_dic, not_in_KPI)
                kpi_info.append(_kpi_info)
                kpi_static_set.update(kpi_type_list)

            sub_total = []
            ww_seq_tmp = []
            ww_tmp = []
            for sub_data in seq_data_cws[i]:
                sub = tokenizer.tokenize(sub_data)
                sub_total.extend(sub)
                # add the '##' markers inside the whole word
                # input:  ['异', '常']
                ref_token = deal_sub_words(sub, special_token)
                # output: ['异', '##常']
                ww_seq_tmp.extend(ref_token)
                ww_tmp.append(ref_token)

            if sub_total != orig:
                print("error in match... ")
            if len(orig) > 512:
                print("the length is over the max length")
                pdb.set_trace()

            # becomes [[...], [...], [...], ...]
            # ww_return.append(ww_tmp)
            sz_ww_seq = len(ww_seq_tmp)
            # track the maximum length
            max_len = sz_ww_seq if sz_ww_seq > max_len else max_len
            if sz_ww_seq > 500:
                sten_that_over_maxl.append((ww_seq_tmp, sz_ww_seq))

            assert len(sub_total) == sz_ww_seq
            ww_return.append(ww_seq_tmp)
            ww_list.append(ww_tmp)
            # pdb.set_trace()
            _tqdm.update(1)
    # pdb.set_trace()
    if deal_numeric:
        in_kpi = []
        # pdb.set_trace()
        for key in rev_kpi_dic.keys():
            if key in kpi_static_set:
                in_kpi.append(rev_kpi_dic[key])
        if len(in_kpi) < len(rev_kpi_dic):
            print(f"[{len(in_kpi)}] KPI are covered by data: {in_kpi}")
            print(f" [{len(not_in_KPI)}] KPI could not be matched: {not_in_KPI}")
        else:
            print("all KPI are covered!")
    return ww_return, kpi_info, sten_that_over_maxl


def extract_num(seq_data_cws):
    '''
    Purpose: extract the numeric values from each sequence, filtering out nan values at the same time.
    '''
    num_ref = []
    seq_data_cws_new = []
    for j in range(len(seq_data_cws)):
        num_index = [i for i, x in enumerate(seq_data_cws[j]) if x == '[NUM]']
        # kpi_score = [float(seq_data_cws[i][index+1]) for index in num_index]
        kpi_score = []
        flag = 1
        for index in num_index:
            # if math.isnan(tmp):
            #     pdb.set_trace()
            try:
                tmp = float(seq_data_cws[j][index + 1])
            except Exception:
                # pdb.set_trace()
                flag = 0
                continue
            if math.isnan(tmp):
                flag = 0
            else:
                kpi_score.append(tmp)

        if len(num_index) > 0:
            for index in reversed(num_index):
                seq_data_cws[j].pop(index + 1)
        if flag == 1:
            num_ref.append(kpi_score)
            seq_data_cws_new.append(seq_data_cws[j])
    return seq_data_cws_new, num_ref


def extract_kpi(token_data, kpi_dic, not_in_KPI):
    '''
    Purpose: extract the [KPI] index ranges and the [NUM] indices from a sequence.
    Output format: [(1, 2, 4), (5, 6, 7)]
    '''
    kpi_and_num_info = []
    kpi_type = []
    kpi_index = [i for i, x in enumerate(token_data) if x.lower() == '[kpi]']
    num_index = [i for i, x in enumerate(token_data) if x.lower() == '[num]']
    sz = len(kpi_index)
    assert sz == len(num_index)
    for i in range(sz):
        # (KPI start, KPI end, position of the NUM token)
        # DONE: add the KPI category
        kpi_name = ''.join(token_data[kpi_index[i] + 1: num_index[i] - 1])
        kpi_name_clear = kpi_name.replace('##', '')

        if kpi_name in kpi_dic:
            kpi_id = int(kpi_dic[kpi_name])
        elif kpi_name_clear in kpi_dic:
            kpi_id = int(kpi_dic[kpi_name_clear])
        elif kpi_name_clear in not_in_KPI:
            kpi_id = -1
            not_in_KPI[kpi_name_clear] += 1
        else:
            # print only once
            not_in_KPI[kpi_name_clear] += 1
            kpi_id = -1
            # print(f"{kpi_name_clear} not in KPI dict")

        kpi_info = [kpi_index[i] + 1, num_index[i] - 2, num_index[i], kpi_id]
        kpi_and_num_info.append(kpi_info)
        kpi_type.append(kpi_id)
        # pdb.set_trace()

    return kpi_and_num_info, kpi_type


def kpi_combine(kpi_info, num_ref):
    sz = len(kpi_info)
    assert sz == len(num_ref)
    for i in range(sz):
        for j in range(len(kpi_info[i])):
            kpi_info[i][j].append(num_ref[i][j])
    # pdb.set_trace()
    return kpi_info

# lowercase everything


def kpi_lower_update(kpi_dic):
    new_dic = {}
    for key in kpi_dic:
        kk = key.lower().split()
        kk = ''.join(kk).strip()
        new_dic[kk] = kpi_dic[key]
    return new_dic


if __name__ == '__main__':
    '''
    Purpose: produce the chinese ref file and refresh the train/test files at the same time
             (only for the sequential text data).
    '''
    cfg = cfg()
    cfg.get_args()
    cfgs = cfg.update_train_configs()

    # paths
    domain_file_path = osp.join(cfgs.data_path, 'special_vocab.txt')
    with open(domain_file_path, encoding="utf-8") as f:
        ref = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
    tokenizer = BertTokenizer.from_pretrained(osp.join(cfgs.data_root, 'transformer', 'MacBert'), do_lower_case=True)
    seq_data_name = cfgs.seq_data_name
    with open(osp.join(cfgs.data_path, f'{seq_data_name}.json'), "r") as fp:
        seq_data = json.load(fp)
    kpi_dic_name = 'kpi2id'
    with open(osp.join(cfgs.data_path, f'{kpi_dic_name}.json'), "r") as fp:
        kpi_dic = json.load(fp)
    kpi_dic = kpi_lower_update(kpi_dic)
    # for testing
    random.shuffle(seq_data)
    # seq_data = seq_data[:500]
    print(f"tokenizer size before: {len(tokenizer)}")
    tokenizer, special_token, norm_token = add_special_token(tokenizer)
    special_token = special_token + norm_token

    print(f"tokenizer size after: {len(tokenizer)}")
    print('------------------------ refresh data --------------------------------')
    add_words = refresh_data(ref, cfgs.freq, special_token)

    if not cfgs.read_cws:
        print('------------------------ cws ----------------------------------')
        seq_data_cws, data_size, error_data = cws(seq_data, add_words, cfgs.batch_size)
        print(f'batch size is {cfgs.batch_size}')
        if len(error_data) > 0:
            with open(osp.join(cfgs.data_path, f'{seq_data_name}_error.json'), "w") as fp:
                json.dump(error_data, fp, ensure_ascii=False)
        save_path_cws_orig = osp.join(cfgs.data_path, f'{seq_data_name}_cws_orig.json')
        print("get the new training data! saving...")
        with open(save_path_cws_orig, 'w', ) as fp:
            json.dump(seq_data_cws, fp, ensure_ascii=False)
    else:
        print('------------------------ read ----------------------------------')
        save_path_cws = osp.join(cfgs.data_path, f'{seq_data_name}_cws_orig.json')
        print("get the new training data!")
        with open(save_path_cws, 'r', ) as fp:
            seq_data_cws = json.load(fp)
        data_size = len(seq_data_cws)

    sz_orig = len(seq_data_cws)
    if cfgs.deal_numeric:
        seq_data_cws, num_ref = extract_num(seq_data_cws)
        print(f"filtered out {sz_orig - len(seq_data_cws)} invalid sentences")
    data_size = len(seq_data_cws)

    print('---------------------- generate chinese ref ------------------------------')
    chinese_ref, kpi_info, sten_that_over_maxl = generate_chinese_ref(seq_data_cws, special_token, cfgs.deal_numeric, kpi_dic)

    if len(sten_that_over_maxl) > 0:
        print(f"{len(sten_that_over_maxl)} over the 500 len!")
        save_path_max = osp.join(cfgs.data_path, f'{seq_data_name}_max_len_500.json')
        with open(save_path_max, 'w') as fp:
            json.dump(sten_that_over_maxl, fp, ensure_ascii=False)

    if cfgs.deal_numeric:
        print("KPI info combine")
        kpi_ref = kpi_combine(kpi_info, num_ref)
        # pdb.set_trace()
    print('------------------------- match finished ------------------------------')

    # save the segmentation used for wwm during final training
    save_path_ref = osp.join(cfgs.data_path, f'{seq_data_name}_chinese_ref.json')
    with open(save_path_ref, 'w') as fp:
        json.dump(chinese_ref, fp, ensure_ascii=False)
    print(f"save chinese_ref done!")

    seq_data_cws_output = []
    for i in range(data_size):
        seq = " ".join(seq_data_cws[i])
        seq_data_cws_output.append(seq)

    save_path_cws = osp.join(cfgs.data_path, f'{seq_data_name}_cws.json')
    print("get the new training data!")
    with open(save_path_cws, 'w', ) as fp:
        json.dump(seq_data_cws_output, fp, ensure_ascii=False)

    print("save seq_data_cws done!")

    if cfgs.deal_numeric:
        kpi_ref_path = osp.join(cfgs.data_path, f'{seq_data_name}_kpi_ref.json')
        with open(kpi_ref_path, 'w', ) as fp:
            json.dump(kpi_ref, fp, ensure_ascii=False)
        print("save num and kpi done!")
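
To make the whole-word-masking reference format concrete, here is a small self-contained sketch of the '##' labelling that deal_sub_words applies per segmented word; the tokenization is reduced to character splitting for illustration (the real script tokenizes with BertTokenizer), and the helper name label_whole_word is introduced only for this sketch:

# sketch: reproduce the chinese_ref example from the generate_chinese_ref docstring
def label_whole_word(subwords, special_token):
    # same rule as deal_sub_words above: prefix non-initial sub-tokens with '##'
    for i in range(len(subwords)):
        if i == 0 or subwords[i] in special_token or subwords[i].startswith("##"):
            continue
        subwords[i] = "##" + subwords[i]
    return subwords

special = ['[kpi]', '[num]']
segmented = [['KPI'], list('异常'), list('下降')]    # the sentence ['KPI', '异常', '下降'], split per word
chinese_ref = [tok for word in segmented for tok in label_whole_word(word, special)]
print(chinese_ref)    # ['KPI', '异', '##常', '下', '##降']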
KTeleBERT/main.py ADDED
@@ -0,0 +1,851 @@
1
+ import os
2
+ import os.path as osp
3
+ import torch
4
+ from torch.utils.tensorboard import SummaryWriter
5
+ from torch.utils.data import DataLoader, RandomSampler
6
+ from torch.cuda.amp import GradScaler, autocast
7
+ from datetime import datetime
8
+ from easydict import EasyDict as edict
9
+ from tqdm import tqdm
10
+ import pdb
11
+ import pprint
12
+ import json
13
+ import pickle
14
+ from collections import defaultdict
15
+ import copy
16
+ from time import time
17
+
18
+ from config import cfg
19
+ from torchlight import initialize_exp, set_seed, get_dump_path
20
+ from src.data import load_data, load_data_kg, Collator_base, Collator_kg, SeqDataset, KGDataset, Collator_order, load_order_data
21
+ from src.utils import set_optim, Loss_log, add_special_token, time_trans
22
+ from src.distributed_utils import init_distributed_mode, dist_pdb, is_main_process, reduce_value, cleanup
23
+ import torch.distributed as dist
24
+
25
+ from itertools import cycle
26
+ from model import BertTokenizer, HWBert, KGEModel, OD_model, KE_model
27
+ import torch.multiprocessing
28
+ from torch.nn.parallel import DistributedDataParallel
29
+
30
+ # 默认用cuda就行
31
+
32
+
33
+ class Runner:
34
+ def __init__(self, args, writer=None, logger=None, rank=0):
35
+ self.datapath = edict()
36
+ self.datapath.log_dir = get_dump_path(args)
37
+ self.datapath.model_dir = os.path.join(self.datapath.log_dir, 'model')
38
+ self.rank = rank
39
+ # init code
40
+ self.mlm_probability = args.mlm_probability
41
+ self.args = args
42
+ self.writer = writer
43
+ self.logger = logger
44
+ # 模型选择
45
+ self.model_list = []
46
+ self.model = HWBert(self.args)
47
+ # 数据加载。添加special_token,同时把模型的embedding layer进行resize
48
+ self.data_init()
49
+ self.model.cuda()
50
+ # 模型加载
51
+ self.od_model, self.ke_model = None, None
52
+ self.scaler = GradScaler()
53
+
54
+ # 只要不是第一种训练策略就有新模型
55
+ if self.args.train_strategy >= 2:
56
+ self.ke_model = KE_model(self.args)
57
+ if self.args.train_strategy >= 3:
58
+ # TODO: 异常检测
59
+ pass
60
+ if self.args.train_strategy >= 4:
61
+ self.od_model = OD_model(self.args)
62
+
63
+ if self.args.model_name not in ['MacBert', 'TeleBert', 'TeleBert2', 'TeleBert3'] and not self.args.from_pretrain:
64
+ # 如果不存在模型会直接返回None或者原始模型
65
+ self.model = self._load_model(self.model, self.args.model_name)
66
+ self.od_model = self._load_model(self.od_model, f"od_{self.args.model_name}")
67
+ self.ke_model = self._load_model(self.ke_model, f"ke_{self.args.model_name}")
68
+ # TODO: 异常检测
69
+
70
+ # 测试的情况
71
+ if self.args.only_test:
72
+ self.dataloader_init(self.seq_test_set)
73
+ else:
74
+ # 训练
75
+ if self.args.ernie_stratege > 0:
76
+ self.args.mask_stratege = 'rand'
77
+ # 初始化dataloader
78
+ self.dataloader_init(self.seq_train_set, self.kg_train_set, self.order_train_set)
79
+ if self.args.dist:
80
+ # 并行训练需要权值共享
81
+ self.model_sync()
82
+ else:
83
+ self.model_list = [model for model in [self.model, self.od_model, self.ke_model] if model is not None]
84
+
85
+ self.optim_init(self.args)
86
+
87
+ def model_sync(self):
88
+ checkpoint_path = osp.join(self.args.data_path, "tmp", "initial_weights.pt")
89
+ checkpoint_path_od = osp.join(self.args.data_path, "tmp", "initial_weights_od.pt")
90
+ checkpoint_path_ke = osp.join(self.args.data_path, "tmp", "initial_weights_ke.pt")
91
+ if self.rank == 0:
92
+ torch.save(self.model.state_dict(), checkpoint_path)
93
+ if self.od_model is not None:
94
+ torch.save(self.od_model.state_dict(), checkpoint_path_od)
95
+ if self.ke_model is not None:
96
+ torch.save(self.ke_model.state_dict(), checkpoint_path_ke)
97
+ dist.barrier()
98
+
99
+ # if self.rank != 0:
100
+ # 这里注意,一定要指定map_location参数,否则会导致第一块GPU占用更多资源
101
+ self.model = self._model_sync(self.model, checkpoint_path)
102
+ if self.od_model is not None:
103
+ self.od_model = self._model_sync(self.od_model, checkpoint_path_od)
104
+ if self.ke_model is not None:
105
+ self.ke_model = self._model_sync(self.ke_model, checkpoint_path_ke)
106
+
107
+ def _model_sync(self, model, checkpoint_path):
108
+ model.load_state_dict(torch.load(checkpoint_path, map_location=self.args.device))
109
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(self.args.device)
110
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[self.args.gpu], find_unused_parameters=True)
111
+ self.model_list.append(model)
112
+ model = model.module
113
+ return model
114
+
115
+ def optim_init(self, opt, total_step=None, accumulation_step=None):
116
+ step_per_epoch = len(self.train_dataloader)
117
+ # 占总step 10% 的warmup_steps
118
+ opt.total_steps = int(step_per_epoch * opt.epoch) if total_step is None else int(total_step)
119
+ opt.warmup_steps = int(opt.total_steps * 0.15)
120
+
121
+ if self.rank == 0 and total_step is None:
122
+ self.logger.info(f"warmup_steps: {opt.warmup_steps}")
123
+ self.logger.info(f"total_steps: {opt.total_steps}")
124
+ self.logger.info(f"weight_decay: {opt.weight_decay}")
125
+
126
+ freeze_part = ['bert.encoder.layer.1.', 'bert.encoder.layer.2.', 'bert.encoder.layer.3.', 'bert.encoder.layer.4.'][:self.args.freeze_layer]
127
+ self.optimizer, self.scheduler = set_optim(opt, self.model_list, freeze_part, accumulation_step)
128
+
129
+ def data_init(self):
130
+ # 载入数据, 两部分数据包括:载入mask loss部分的数据(序列化的数据) 和 载入triple loss部分的数据(三元组)
131
+ # train_test_split: 训练集长度
132
+ self.seq_train_set, self.seq_test_set, self.kg_train_set, self.kg_data = None, None, None, None
133
+ self.order_train_set, self.order_test_set = None, None
134
+
135
+ if self.args.train_strategy >= 1 and self.args.train_strategy <= 4:
136
+ # 预训练 or multi task pretrain
137
+ self.seq_train_set, self.seq_test_set, train_test_split = load_data(self.logger, self.args)
138
+ if self.args.train_strategy >= 2:
139
+ self.kg_train_set, self.kg_data = load_data_kg(self.logger, self.args)
140
+ if self.args.train_strategy >= 3:
141
+ # TODO: 异常检测的数据载入
142
+ pass
143
+ if self.args.train_strategy >= 4:
144
+ self.order_train_set, self.order_test_set, train_test_split = load_order_data(self.logger, self.args)
145
+
146
+ if self.args.dist and not self.args.only_test:
147
+ # 测试不需要并行
148
+ if self.args.train_strategy >= 1 and self.args.train_strategy <= 4:
149
+ self.seq_train_sampler = torch.utils.data.distributed.DistributedSampler(self.seq_train_set)
150
+ if self.args.train_strategy >= 2:
151
+ self.kg_train_sampler = torch.utils.data.distributed.DistributedSampler(self.kg_train_set)
152
+ if self.args.train_strategy >= 3:
153
+ # TODO: 异常检测的数据载入
154
+ pass
155
+ if self.args.train_strategy >= 4:
156
+ self.order_train_sampler = torch.utils.data.distributed.DistributedSampler(self.order_train_set)
157
+
158
+ # self.seq_train_batch_sampler = torch.utils.data.BatchSampler(self.seq_train_sampler, self.args.batch_size, drop_last=True)
159
+ # self.kg_train_batch_sampler = torch.utils.data.BatchSampler(self.kg_train_sampler, int(self.args.batch_size / 4), drop_last=True)
160
+
161
+ # Tokenizer 载入
162
+ model_name = self.args.model_name
163
+ if self.args.model_name in ['TeleBert', 'TeleBert2', 'TeleBert3']:
164
+ self.tokenizer = BertTokenizer.from_pretrained(osp.join(self.args.data_root, 'transformer', model_name), do_lower_case=True)
165
+ else:
166
+ if not osp.exists(osp.join(self.args.data_root, 'transformer', self.args.model_name)):
167
+ model_name = 'MacBert'
168
+ self.tokenizer = BertTokenizer.from_pretrained(osp.join(self.args.data_root, 'transformer', model_name), do_lower_case=True)
169
+
170
+ # 添加special_token,同时把模型的embedding layer进行resize
171
+ self.special_token = None
172
+ # 单纯的telebert在测试时不需要特殊embedding
173
+ if self.args.add_special_word and not (self.args.only_test and self.args.model_name in ['MacBert', 'TeleBert', 'TeleBert2', 'TeleBert3']):
174
+ # tokenizer, special_token, norm_token
175
+ # special_token 不应该被MASK
176
+ self.tokenizer, special_token, _ = add_special_token(self.tokenizer, model=self.model.encoder, rank=self.rank, cache_path=self.args.specail_emb_path)
177
+ # pdb.set_trace()
178
+ self.special_token = [token.lower() for token in special_token]
179
+
180
+ def _dataloader_dist(self, train_set, train_sampler, batch_size, collator):
181
+ train_dataloader = DataLoader(
182
+ train_set,
183
+ sampler=train_sampler,
184
+ pin_memory=True,
185
+ num_workers=self.args.workers,
186
+ persistent_workers=True,
187
+ drop_last=True,
188
+ batch_size=batch_size,
189
+ collate_fn=collator
190
+ )
191
+ return train_dataloader
192
+
193
+ def _dataloader(self, train_set, batch_size, collator):
194
+ train_dataloader = DataLoader(
195
+ train_set,
196
+ num_workers=self.args.workers,
197
+ persistent_workers=True,
198
+ shuffle=(self.args.only_test == 0),
199
+ drop_last=(self.args.only_test == 0),
200
+ batch_size=batch_size,
201
+ collate_fn=collator
202
+ )
203
+ return train_dataloader
204
+
205
+ def dataloader_init(self, train_set=None, kg_train_set=None, order_train_set=None):
206
+ bs = self.args.batch_size
207
+ bs_ke = self.args.batch_size_ke
208
+ bs_od = self.args.batch_size_od
209
+ bs_ad = self.args.batch_size_ad
210
+ # 分布式
211
+ if self.args.dist and not self.args.only_test:
212
+ self.args.workers = min([os.cpu_count(), self.args.batch_size, self.args.workers])
213
+ # if self.rank == 0:
214
+ # print(f'Using {self.args.workers} dataloader workers every process')
215
+
216
+ if train_set is not None:
217
+ seq_collator = Collator_base(self.args, tokenizer=self.tokenizer, special_token=self.special_token)
218
+ self.train_dataloader = self._dataloader_dist(train_set, self.seq_train_sampler, bs, seq_collator)
219
+ if kg_train_set is not None:
220
+ kg_collator = Collator_kg(self.args, tokenizer=self.tokenizer, data=self.kg_data)
221
+ self.train_dataloader_kg = self._dataloader_dist(kg_train_set, self.kg_train_sampler, bs_ke, kg_collator)
222
+ if order_train_set is not None:
223
+ order_collator = Collator_order(self.args, tokenizer=self.tokenizer)
224
+ self.train_dataloader_order = self._dataloader_dist(order_train_set, self.order_train_sampler, bs_od, order_collator)
225
+ else:
226
+ if train_set is not None:
227
+ seq_collator = Collator_base(self.args, tokenizer=self.tokenizer, special_token=self.special_token)
228
+ self.train_dataloader = self._dataloader(train_set, bs, seq_collator)
229
+ if kg_train_set is not None:
230
+ kg_collator = Collator_kg(self.args, tokenizer=self.tokenizer, data=self.kg_data)
231
+ self.train_dataloader_kg = self._dataloader(kg_train_set, bs_ke, kg_collator)
232
+ if order_train_set is not None:
233
+ order_collator = Collator_order(self.args, tokenizer=self.tokenizer)
234
+ self.train_dataloader_order = self._dataloader(order_train_set, bs_od, order_collator)
235
+
236
+ def dist_step(self, task=0):
237
+ # 分布式训练需要额外step
238
+ if self.args.dist:
239
+ if task == 0:
240
+ self.seq_train_sampler.set_epoch(self.dist_epoch)
241
+ if task == 1:
242
+ self.kg_train_sampler.set_epoch(self.dist_epoch)
243
+ if task == 2:
244
+ # TODO:异常检测
245
+ pass
246
+ if task == 3:
247
+ self.order_train_sampler.set_epoch(self.dist_epoch)
248
+ self.dist_epoch += 1
249
+
250
+ def mask_rate_update(self, i):
251
+ # 这种策略是曲线地增加 mask rate
252
+ if self.args.mlm_probability_increase == "curve":
253
+ self.args.mlm_probability += (i + 1) * ((self.args.final_mlm_probability - self.args.mlm_probability) / self.args.epoch)
254
+ # 这种是线性的
255
+ else:
256
+ assert self.args.mlm_probability_increase == "linear"
257
+ self.args.mlm_probability += (self.args.final_mlm_probability - self.mlm_probability) / self.args.epoch
258
+
259
+ if self.rank == 0:
260
+ self.logger.info(f"Moving Mlm_probability in next epoch to: {self.args.mlm_probability*100}%")
261
+
262
+ def task_switch(self, training_strategy):
263
+ # 同时训练或者策略1训练不需要切换任务,epoch也安装初始epoch就行
264
+ if training_strategy == 1 or self.args.train_together:
265
+ return (0, 0), None
266
+
267
+ # 4 阶段
268
+ # self.total_epoch -= 1
269
+
270
+ for i in range(4):
271
+ for task in range(training_strategy):
272
+ if self.args.epoch_matrix[task][i] > 0:
273
+ self.args.epoch_matrix[task][i] -= 1
274
+ return (task, i), self.args.epoch_matrix[task][i] + 1
275
+
276
+ def run(self):
277
+ self.loss_log = Loss_log()
278
+ self.curr_loss = 0.
279
+ self.lr = self.args.lr
280
+ self.curr_loss_dic = defaultdict(float)
281
+ self.curr_kpi_loss_dic = defaultdict(float)
282
+ self.loss_weight = [1, 1]
283
+ self.kpi_loss_weight = [1, 1]
284
+ self.step = 0
285
+ # 不同task 的累计step
286
+ self.total_step_sum = 0
287
+ task_last = 0
288
+ stage_last = 0
289
+ self.dist_epoch = 0
290
+ # 后面可以变成混合训练模式
291
+ # self.total_epoch = self.args.epoch
292
+ # --------- train -------------
293
+ with tqdm(total=self.args.epoch) as _tqdm: # 使用需要的参数对tqdm进行初始化
294
+ for i in range(self.args.epoch):
295
+ # 切换Task
296
+ (task, stage), task_epoch = self.task_switch(self.args.train_strategy)
297
+ self.dist_step(task)
298
+ dataloader = self.task_dataloader_choose(task)
299
+ # 并行
300
+ if self.args.train_together and self.args.train_strategy > 1:
301
+ self.dataloader_list = ['#']
302
+ # 一个list 存下所有需要的dataloader的迭代
303
+ for t in range(1, self.args.train_strategy):
304
+ self.dist_step(t)
305
+ self.dataloader_list.append(iter(self.task_dataloader_choose(t)))
306
+
307
+ if task != task_last or stage != stage_last:
308
+ self.step = 0
309
+ if self.rank == 0:
310
+ print(f"switch to task [{task}] in stage [{stage}]...")
311
+ if stage != stage_last:
312
+ # 每一个阶段结束保存一次
313
+ self._save_model(stage=f'_stg{stage_last}')
314
+ # task 转换状态时需要重新初始化优化器
315
+ # 并行训练或者单一task (0) 训练不需要切换opti
316
+ if task_epoch is not None:
317
+ self.optim_init(self.args, total_step=len(dataloader) * task_epoch, accumulation_step=self.args.accumulation_steps_dict[task])
318
+ task_last = task
319
+ stage_last = stage
320
+
321
+ # 调整学习阶段
322
+ if task == 0 and self.args.ernie_stratege > 0 and i >= self.args.ernie_stratege:
323
+ # 不会再触发第二次
324
+ self.args.ernie_stratege = 10000000
325
+ if self.rank == 0:
326
+ self.logger.info("switch to wwm stratege...")
327
+ self.args.mask_stratege = 'wwm'
328
+
329
+ if self.args.mlm_probability != self.args.final_mlm_probability:
330
+ # 更新 MASK rate
331
+ # 初始化训练数据, 可以随epoch切换
332
+ # 混合训练
333
+ self.mask_rate_update(i)
334
+ self.dataloader_init(self.seq_train_set, self.kg_train_set, self.order_train_set)
335
+ # -------------------------------
336
+ # 针对task 进行训练
337
+ self.train(_tqdm, dataloader, task)
338
+ # -------------------------------
339
+ _tqdm.update(1)
340
+
341
+ # DONE: save or load
342
+ if self.rank == 0:
343
+ self.logger.info(f"min loss {self.loss_log.get_min_loss()}")
344
+ # DONE: save or load
345
+ if not self.args.only_test and self.args.save_model:
346
+ self._save_model()
347
+
348
+ def task_dataloader_choose(self, task):
349
+ self.model.train()
350
+ # 同时训练就用基础dataloader就行
351
+ if task == 0:
352
+ dataloader = self.train_dataloader
353
+ elif task == 1:
354
+ self.ke_model.train()
355
+ dataloader = self.train_dataloader_kg
356
+ elif task == 2:
357
+ pass
358
+ elif task == 3:
359
+ self.od_model.train()
360
+ dataloader = self.train_dataloader_order
361
+ return dataloader
362
+ # one time train
363
+
364
+ def loss_output(self, batch, task):
365
+ # -------- 模型输出 loss --------
366
+ if task == 0:
367
+ # 输出
368
+ _output = self.model(batch)
369
+ loss = _output['loss']
370
+ elif task == 1:
371
+ loss = self.ke_model(batch, self.model)
372
+ elif task == 2:
373
+ pass
374
+ elif task == 3:
375
+ # TODO: finetune的时候多任务 accumulation_steps 自适应
376
+ # OD task
377
+ emb = self.model.cls_embedding(batch[0], tp=self.args.plm_emb_type)
378
+ loss, loss_dic = self.od_model(emb, batch[1].cuda())
379
+ order_score = self.od_model.predict(emb)
380
+ token_right = self.od_model.right_caculate(order_score, batch[1], threshold=0.5)
381
+ self.loss_log.update_token(batch[1].shape[0], [token_right])
382
+ return loss
383
+
384
+ def train(self, _tqdm, dataloader, task=0):
385
+ # cycle train
386
+ loss_weight, kpi_loss_weight, kpi_loss_dict, _output = None, None, None, None
387
+ # dataloader = zip(self.train_dataloader, cycle(self.train_dataloader_kg))
388
+ self.loss_log.acc_init()
389
+ # 如果self.train_dataloader比self.train_dataloader_kg长则会使得后者训练不完全
390
+ accumulation_steps = self.args.accumulation_steps_dict[task]
391
+ torch.cuda.empty_cache()
392
+
393
+ for batch in dataloader:
394
+ # with autocast():
395
+ loss = self.args.mask_loss_scale * self.loss_output(batch, task)
396
+ # 如果是同时训练的话使用迭代器的方法得到另外的epoch
397
+ if self.args.train_together and self.args.train_strategy > 1:
398
+ for t in range(1, self.args.train_strategy):
399
+ try:
400
+ batch = next(self.dataloader_list[t])
401
+ except StopIteration:
402
+ self.dist_step(t)
403
+ self.dataloader_list[t] = iter(self.task_dataloader_choose(t))
404
+ batch = next(self.dataloader_list[t])
405
+ # 选择对应的模型得到loss
406
+ # torch.cuda.empty_cache()
407
+ loss += self.loss_output(batch, t)
408
+ # torch.cuda.empty_cache()
409
+ loss = loss / accumulation_steps
410
+ self.scaler.scale(loss).backward()
411
+ # loss.backward()
412
+ if self.args.dist:
413
+ loss = reduce_value(loss, average=True)
414
+ # torch.cuda.empty_cache()
415
+ self.step += 1
416
+ self.total_step_sum += 1
417
+
418
+ # -------- 模型统计 --------
419
+ if not self.args.dist or is_main_process():
420
+ self.output_statistic(loss, _output)
421
+ acc_descrip = f"Acc: {self.loss_log.get_token_acc()}" if self.loss_log.get_token_acc() > 0 else ""
422
+ _tqdm.set_description(f'Train | step [{self.step}/{self.args.total_steps}] {acc_descrip} LR [{self.lr}] Loss {self.loss_log.get_loss():.5f} ')
423
+ if self.step % self.args.eval_step == 0 and self.step > 0:
424
+ self.loss_log.update(self.curr_loss)
425
+ self.update_loss_log()
426
+ # -------- 梯度累计与模型更新 --------
427
+ if self.step % accumulation_steps == 0 and self.step > 0:
428
+ # 更新优化器
429
+ self.scaler.unscale_(self.optimizer)
430
+ for model in self.model_list:
431
+ torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.clip)
432
+
433
+ # self.optimizer.step()
434
+ scale = self.scaler.get_scale()
435
+ self.scaler.step(self.optimizer)
436
+
437
+ self.scaler.update()
438
+ skip_lr_sched = (scale > self.scaler.get_scale())
439
+ if not skip_lr_sched:
440
+ # pdb.set_trace()
441
+ self.scheduler.step()
442
+
443
+ if not self.args.dist or is_main_process():
444
+ # pdb.set_trace()
445
+ self.lr = self.scheduler.get_last_lr()[-1]
446
+ self.writer.add_scalars("lr", {"lr": self.lr}, self.total_step_sum)
447
+ # 模型update
448
+ for model in self.model_list:
449
+ model.zero_grad(set_to_none=True)
450
+
451
+ if self.args.dist:
452
+ torch.cuda.synchronize(self.args.device)
453
+ return self.curr_loss, self.curr_loss_dic
454
+
455
+ def output_statistic(self, loss, output):
456
+ # 统计模型的各种输出
457
+ self.curr_loss += loss.item()
458
+ if output is None:
459
+ return
460
+ for key in output['loss_dic'].keys():
461
+ self.curr_loss_dic[key] += output['loss_dic'][key]
462
+ if 'kpi_loss_dict' in output and output['kpi_loss_dict'] is not None:
463
+ for key in output['kpi_loss_dict'].keys():
464
+ self.curr_kpi_loss_dic[key] += output['kpi_loss_dict'][key]
465
+ if 'loss_weight' in output and output['loss_weight'] is not None:
466
+ self.loss_weight = output['loss_weight']
467
+ # 需要用dict来判断
468
+ if 'kpi_loss_weight' in output and output['kpi_loss_weight'] is not None:
469
+ self.kpi_loss_weight = output['kpi_loss_weight']
470
+
471
+ def update_loss_log(self, task=0):
472
+ # 把统计的模型各种输出存下来
473
+ # https://zhuanlan.zhihu.com/p/382950853
474
+ # "mask_loss": self.curr_loss_dic['mask_loss'], "ke_loss": self.curr_loss_dic['ke_loss']
475
+ vis_dict = {"train_loss": self.curr_loss}
476
+ vis_dict.update(self.curr_loss_dic)
477
+ self.writer.add_scalars("loss", vis_dict, self.total_step_sum)
478
+ if self.loss_weight is not None:
479
+ # 预训练
480
+ loss_weight_dic = {}
481
+ if self.args.train_strategy == 1:
482
+ loss_weight_dic["mask"] = 1 / (self.loss_weight[0]**2)
483
+ if self.args.use_NumEmb:
484
+ loss_weight_dic["kpi"] = 1 / (self.loss_weight[1]**2)
485
+ vis_kpi_dic = {"recover": 1 / (self.kpi_loss_weight[0]**2), "classifier": 1 / (self.kpi_loss_weight[1]**2)}
486
+ if self.args.contrastive_loss and len(self.kpi_loss_weight) > 2:
487
+ vis_kpi_dic.update({"contrastive": 1 / (self.kpi_loss_weight[2]**2)})
488
+ self.writer.add_scalars("kpi_loss_weight", vis_kpi_dic, self.total_step_sum)
489
+ self.writer.add_scalars("kpi_loss", self.curr_kpi_loss_dic, self.total_step_sum)
490
+ self.writer.add_scalars("loss_weight", loss_weight_dic, self.total_step_sum)
491
+ # TODO: Finetune
492
+
493
+ # init log loss
494
+ self.curr_loss = 0.
495
+ for key in self.curr_loss_dic:
496
+ self.curr_loss_dic[key] = 0.
497
+ if len(self.curr_kpi_loss_dic) > 0:
498
+ for key in self.curr_kpi_loss_dic:
499
+ self.curr_kpi_loss_dic[key] = 0.
500
+
501
+ # TODO: Finetune 阶段
502
+ def eval(self):
503
+ self.model.eval()
504
+ torch.cuda.empty_cache()
505
+
506
+ def mask_test(self, test_log):
507
+ # 如果大于1 就无法mask测试
508
+ assert self.args.train_ratio < 1
509
+ topk = (1, 100, 500)
510
+ test_log.acc_init(topk)
511
+ # 做 mask 预测的时候需要进入训练模式,以获得随机mask的token
512
+ self.args.only_test = 0
513
+ self.dataloader_init(self.seq_test_set)
514
+ # pdb.set_trace()
515
+ sz_test = len(self.train_dataloader)
516
+ loss_sum = 0
517
+ with tqdm(total=sz_test) as _tqdm: # 使用需要的参数对tqdm进行初始化
518
+ for step, batch in enumerate(self.train_dataloader):
519
+ # DONE: 写好mask_prediction实现mask预测
520
+ with torch.no_grad():
521
+ token_num, token_right, loss = self.model.mask_prediction(batch, len(self.tokenizer), topk)
522
+ test_log.update_token(token_num, token_right)
523
+ loss_sum += loss
524
+ # test_log.update_word(word_num, word_right)
525
+ _tqdm.update(1)
526
+ _tqdm.set_description(f'Test | step [{step}/{sz_test}] Top{topk} Token_Acc: {test_log.get_token_acc()}')
527
+ print(f"perplexity: {loss_sum}")
528
+ # 训练模式复位
529
+ self.args.only_test = 1
530
+ # if topk is not None:
531
+ print(f"Top{topk} acc is {test_log.get_token_acc()}")
532
+
533
+ def emb_generate(self, path_gen):
534
+ assert len(self.args.path_gen) > 0 or path_gen is not None
535
+ data_path = self.args.data_path
536
+ if path_gen is None:
537
+ path_gen = self.args.path_gen
538
+ with open(osp.join(data_path, 'downstream_task', f'{path_gen}.json'), "r") as fp:
539
+ data = json.load(fp)
540
+ print(f"read file {path_gen} done!")
541
+ test_set = SeqDataset(data)
542
+ self.dataloader_init(test_set)
543
+ sz_test = len(self.train_dataloader)
544
+ all_emb_dic = defaultdict(list)
545
+ emb_output = {}
546
+ all_emb_ent = []
547
+ # tps = ['cls', 'last_avg', 'last2avg', 'last3avg', 'first_last_avg']
548
+ tps = ['cls', 'last_avg']
549
+ # with tqdm(total=sz_test) as _tqdm:
550
+ for step, batch in enumerate(self.train_dataloader):
551
+ for tp in tps:
552
+ with torch.no_grad():
553
+ batch_embedding = self.model.cls_embedding(batch, tp=tp)
554
+ # batch_embedding = self.model.cls_embedding(batch, tp=tp)
555
+ if tp in self.args.model_name and self.ke_model is not None:
556
+ batch_embedding_ent = self.ke_model.get_embedding(batch_embedding, is_ent=True)
557
+ # batch_embedding_ent = self.ke_model(batch, self.model)
558
+ batch_embedding_ent = batch_embedding_ent.cpu()
559
+ all_emb_ent.append(batch_embedding_ent)
560
+
561
+ batch_embedding = batch_embedding.cpu()
562
+ all_emb_dic[tp].append(batch_embedding)
563
+ # _tqdm.update(1)
564
+ # _tqdm.set_description(f'Test | step [{step}/{sz_test}]')
565
+ torch.cuda.empty_cache()
566
+ for tp in tps:
567
+ emb_output[tp] = torch.cat(all_emb_dic[tp])
568
+ assert emb_output[tp].shape[0] == len(data)
569
+ if len(all_emb_ent) > 0:
570
+ emb_output_ent = torch.cat(all_emb_ent)
571
+ # suffix
572
+ save_path = osp.join(data_path, 'downstream_task', 'output')
573
+ os.makedirs(save_path, exist_ok=True)
574
+ for tp in tps:
575
+ save_dir = osp.join(save_path, f'{path_gen}_emb_{self.args.model_name.replace("DistributedDataParallel", "")}_{tp}.pt')
576
+ torch.save(emb_output[tp], save_dir)
577
+ # a trained entity embedding is available
578
+ if len(all_emb_ent) > 0:
579
+ save_dir = osp.join(save_path, f'{path_gen}_emb_{self.args.model_name.replace("DistributedDataParallel", "")}_ent.pt')
580
+ torch.save(emb_output_ent, save_dir)
581
+
582
+ def KGE_test(self):
583
+ # test KGE directly on the full KG
584
+ sz_test = len(self.kg_train_set)
585
+ # convert the data first
586
+ ent_set = set()
587
+ rel_set = set()
588
+ with tqdm(total=sz_test) as _tqdm: # initialize tqdm with the required arguments
589
+ _tqdm.set_description('trans entity/relation ID')
590
+ for batch in self.kg_train_set:
591
+ ent_set.add(batch[0])
592
+ ent_set.add(batch[2])
593
+ rel_set.add(batch[1])
594
+ _tqdm.update(1)
595
+ all_ent, all_rel = list(ent_set), list(rel_set)
596
+ nent, nrel = len(all_ent), len(all_rel)
597
+ ent_dic, rel_dic = {}, {}
598
+ for i in range(nent):
599
+ ent_dic[all_ent[i]] = i
600
+ for i in range(nrel):
601
+ rel_dic[all_rel[i]] = i
602
+ id_format_triple = []
603
+ with tqdm(total=sz_test) as _tqdm:
604
+ _tqdm.set_description('trans triple ID')
605
+ for triple in self.kg_train_set:
606
+ id_format_triple.append((ent_dic[triple[0]], rel_dic[triple[1]], ent_dic[triple[2]]))
607
+ _tqdm.update(1)
608
+
609
+ # pdb.set_trace()
610
+ # generate and save the entity embeddings
611
+ ent_dataset = KGDataset(all_ent)
612
+ rel_dataset = KGDataset(all_rel)
613
+
614
+ ent_dataloader = DataLoader(
615
+ ent_dataset,
616
+ batch_size=self.args.batch_size * 32,
617
+ num_workers=self.args.workers,
618
+ persistent_workers=True,
619
+ shuffle=False
620
+ )
621
+ rel_dataloader = DataLoader(
622
+ rel_dataset,
623
+ batch_size=self.args.batch_size * 32,
624
+ num_workers=self.args.workers,
625
+ persistent_workers=True,
626
+ shuffle=False
627
+ )
628
+
629
+ sz_test = len(ent_dataloader) + len(rel_dataloader)
630
+ with tqdm(total=sz_test) as _tqdm:
631
+ ent_emb = []
632
+ rel_emb = []
633
+ step = 0
634
+ _tqdm.set_description('get the ent embedding')
635
+ with torch.no_grad():
636
+ for batch in ent_dataloader:
637
+ batch = self.tokenizer.batch_encode_plus(
638
+ batch,
639
+ padding='max_length',
640
+ max_length=self.args.maxlength,
641
+ truncation=True,
642
+ return_tensors="pt",
643
+ return_token_type_ids=False,
644
+ return_attention_mask=True,
645
+ add_special_tokens=False
646
+ )
647
+
648
+ batch_emb = self.model.cls_embedding(batch, tp=self.args.plm_emb_type)
649
+ batch_emb = self.ke_model.get_embedding(batch_emb, is_ent=True)
650
+
651
+ ent_emb.append(batch_emb.cpu())
652
+ _tqdm.update(1)
653
+ step += 1
654
+ torch.cuda.empty_cache()
655
+ _tqdm.set_description(f'ENT emb: [{step}/{sz_test}]')
656
+
657
+ _tqdm.set_description('get the rel embedding')
658
+ for batch in rel_dataloader:
659
+ batch = self.tokenizer.batch_encode_plus(
660
+ batch,
661
+ padding='max_length',
662
+ max_length=self.args.maxlength,
663
+ truncation=True,
664
+ return_tensors="pt",
665
+ return_token_type_ids=False,
666
+ return_attention_mask=True,
667
+ add_special_tokens=False
668
+ )
669
+ batch_emb = self.model.cls_embedding(batch, tp=self.args.plm_emb_type)
670
+ batch_emb = self.ke_model.get_embedding(batch_emb, is_ent=False)
671
+ # batch_emb = self.model.get_embedding(batch, is_ent=False)
672
+ rel_emb.append(batch_emb.cpu())
673
+ _tqdm.update(1)
674
+ step += 1
675
+ torch.cuda.empty_cache()
676
+ _tqdm.set_description(f'REL emb: [{step}/{sz_test}]')
677
+
678
+ all_ent_emb = torch.cat(ent_emb).cuda()
679
+ all_rel_emb = torch.cat(rel_emb).cuda()
680
+ # embedding:emb_output
681
+ # dim 256
682
+ kge_model_for_test = KGEModel(nentity=len(all_ent), nrelation=len(all_rel), hidden_dim=self.args.ke_dim,
683
+ gamma=self.args.ke_margin, entity_embedding=all_ent_emb, relation_embedding=all_rel_emb).cuda()
684
+ if self.args.ke_test_num > 0:
685
+ test_triples = id_format_triple[:self.args.ke_test_num]
686
+ else:
687
+ test_triples = id_format_triple
688
+ with torch.no_grad():
689
+ metrics = kge_model_for_test.test_step(test_triples=test_triples, all_true_triples=id_format_triple, args=self.args, nentity=len(all_ent), nrelation=len(all_rel))
690
+ # pdb.set_trace()
691
+ print(f"result:{metrics}")
692
+
693
+ def OD_test(self):
694
+ # data_path = self.args.data_path
695
+ # with open(osp.join(data_path, f'{self.args.order_test_name}.json'), "r") as fp:
696
+ # data = json.load(fp)
697
+ self.od_model.eval()
698
+ test_log = Loss_log()
699
+ test_log.acc_init()
700
+ sz_test = len(self.train_dataloader)
701
+ all_emb_ent = []
702
+ with tqdm(total=sz_test) as _tqdm: # initialize tqdm with the required arguments
703
+ for step, batch in enumerate(self.train_dataloader):
704
+ with torch.no_grad():
705
+ emb = self.model.cls_embedding(batch[0], tp=self.args.plm_emb_type)
706
+ out_emb = self.od_model.encode(emb)
707
+ emb_cpu = out_emb.cpu()
708
+ all_emb_ent.append(emb_cpu)
709
+ order_score = self.od_model.predict(emb)
710
+ token_right = self.od_model.right_caculate(order_score, batch[1], threshold=self.args.order_threshold)
711
+ test_log.update_token(batch[1].shape[0], [token_right])
712
+ _tqdm.update(1)
713
+ _tqdm.set_description(f'Test | step [{step}/{sz_test}] Acc: {test_log.get_token_acc()}')
714
+
715
+ emb_output = torch.cat(all_emb_ent)
716
+ data_path = self.args.data_path
717
+ save_path = osp.join(data_path, 'downstream_task', 'output')
718
+ os.makedirs(save_path, exist_ok=True)
719
+ save_dir = osp.join(save_path, f'ratio{self.args.train_ratio}_{emb_output.shape[0]}emb_{self.args.model_name.replace("DistributedDataParallel", "")}.pt')
720
+ torch.save(emb_output, save_dir)
721
+ print(f"save {emb_output.shape[0]} embeddings done...")
722
+
723
+ @ torch.no_grad()
724
+ def test(self, path_gen=None):
725
+ test_log = Loss_log()
726
+ self.model.eval()
727
+ if not (self.args.mask_test or self.args.embed_gen or self.args.ke_test or len(self.args.order_test_name) > 0):
728
+ return
729
+ if self.args.mask_test:
730
+ self.mask_test(test_log)
731
+ if self.args.embed_gen:
732
+ self.emb_generate(path_gen)
733
+ if self.args.ke_test:
734
+ self.KGE_test()
735
+ if len(self.args.order_test_name) > 0:
736
+ self.OD_test()
737
+
738
+ def _load_model(self, model, name):
739
+ if model is None:
740
+ return None
741
+ # not trained yet
742
+ _name = name if name[:3] not in ['od_', 'ke_'] else name[3:]
743
+ save_path = osp.join(self.args.data_path, 'save', _name)
744
+ save_name = osp.join(save_path, f'{name}.pkl')
745
+ if not osp.exists(save_path) or not osp.exists(save_name):
746
+ return model.cuda()
747
+ # load the model
748
+ if 'Distribute' in self.args.model_name:
749
+ model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(os.path.join(save_name), map_location=self.args.device).items()})
750
+ else:
751
+ model.load_state_dict(torch.load(save_name, map_location=self.args.device))
752
+ model.cuda()
753
+ if self.rank == 0:
754
+ print(f"loading model [{name}.pkl] done!")
755
+
756
+ return model
757
+
758
+ def _save_model(self, stage=''):
759
+ model_name = type(self.model).__name__
760
+ # TODO: path
761
+ save_path = osp.join(self.args.data_path, 'save')
762
+ os.makedirs(save_path, exist_ok=True)
763
+ if self.args.train_strategy == 1:
764
+ save_name = f'{self.args.exp_name}_{self.args.exp_id}_s{self.args.random_seed}{stage}'
765
+ else:
766
+ save_name = f'{self.args.exp_name}_{self.args.exp_id}_s{self.args.random_seed}_{self.args.plm_emb_type}{stage}'
767
+ save_path = osp.join(save_path, save_name)
768
+ os.makedirs(save_path, exist_ok=True)
769
+ # save the pre-trained model
770
+ self._save(self.model, save_path, save_name)
771
+
772
+ # save the downstream models
773
+ save_name_od = f'od_{save_name}'
774
+ self._save(self.od_model, save_path, save_name_od)
775
+ save_name_ke = f'ke_{save_name}'
776
+ self._save(self.ke_model, save_path, save_name_ke)
777
+ return save_path
778
+
779
+ def _save(self, model, save_path, save_name):
780
+ if model is None:
781
+ return
782
+ if self.args.save_model:
783
+ torch.save(model.state_dict(), osp.join(save_path, f'{save_name}.pkl'))
784
+ print(f"saving {save_name} done!")
785
+
786
+ if self.args.save_pretrain and not save_name.startswith('od_') and not save_name.startswith('ke_'):
787
+ self.tokenizer.save_pretrained(osp.join(self.args.plm_path, f'{save_name}'))
788
+ self.model.encoder.save_pretrained(osp.join(self.args.plm_path, f'{save_name}'))
789
+ print(f"saving [pretrained] {save_name} done!")
790
+
791
+
792
+ if __name__ == '__main__':
793
+ cfg = cfg()
794
+ cfg.get_args()
795
+ cfgs = cfg.update_train_configs()
796
+ set_seed(cfgs.random_seed)
797
+ # initialize the environment of each process
798
+ # pdb.set_trace()
799
+ if cfgs.dist and not cfgs.only_test:
800
+ init_distributed_mode(args=cfgs)
801
+ # cfgs.lr *= cfgs.world_size
802
+ # cfgs.ke_lr *= cfgs.world_size
803
+ else:
804
+ # the following call may leak memory when running in parallel, preventing the run from stopping
805
+ torch.multiprocessing.set_sharing_strategy('file_system')
806
+ rank = cfgs.rank
807
+
808
+ writer, logger = None, None
809
+ if rank == 0:
810
+ # when running in parallel, only one process prints
811
+ logger = initialize_exp(cfgs)
812
+ logger_path = get_dump_path(cfgs)
813
+ cfgs.time_stamp = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now())
814
+ comment = f'batch_size={cfgs.batch_size} exp_id={cfgs.exp_id}'
815
+ if not cfgs.no_tensorboard and not cfgs.only_test:
816
+ writer = SummaryWriter(log_dir=os.path.join(logger_path, 'tensorboard', cfgs.time_stamp), comment=comment)
817
+
818
+ cfgs.device = torch.device(cfgs.device)
819
+
820
+ # ----- Begin ----------
821
+ runner = Runner(cfgs, writer, logger, rank)
822
+
823
+ if cfgs.only_test:
824
+ if cfgs.embed_gen:
825
+ # handle the tests that need no generation first
826
+ if cfgs.mask_test or cfgs.ke_test:
827
+ runner.args.embed_gen = 0
828
+ runner.test()
829
+ runner.args.embed_gen = 1
830
+ # gen_dir = ['yht_data_merge', 'yht_data_whole5gc', 'yz_data_whole5gc', 'yz_data_merge', 'zyc_data_merge', 'zyc_data_whole5gc']
831
+ gen_dir = ['yht_serialize_withAttribute', 'yht_serialize_withoutAttr', 'yht_name_serialize', 'zyc_serialize_withAttribute', 'zyc_serialize_withoutAttr', 'zyc_name_serialize',
832
+ 'yz_serialize_withAttribute', 'yz_serialize_withoutAttr', 'yz_name_serialize', 'yz_serialize_net']
833
+ # gen_dir = ['zyc_serialize_withAttribute', 'zyc_normal_serialize', 'zyc_data_whole5gc', 'zyc_data_merge', 'yht_normal_serialize',
834
+ # 'yht_serialize_withAttribute', 'yz_serialize_withAttribute', 'yz_serialize_net', 'yz_normal_serialize']
835
+ runner.args.mask_test, runner.args.ke_test = 0, 0
836
+ for item in gen_dir:
837
+ runner.test(item)
838
+ else:
839
+ runner.test()
840
+ else:
841
+ runner.run()
842
+
843
+ # ----- End ----------
844
+ if not cfgs.no_tensorboard and not cfgs.only_test and rank == 0:
845
+ writer.close()
846
+ logger.info("done!")
847
+
848
+ if cfgs.dist and not cfgs.only_test:
849
+ dist.barrier()
850
+ dist.destroy_process_group()
851
+ # print("shut down...")
KTeleBERT/model/HWBert.py ADDED
@@ -0,0 +1,146 @@
1
+ import os
2
+ import os.path as osp
3
+ import pdb
4
+ import torch
5
+ import torch.nn as nn
6
+ import numpy as np
7
+ from random import *
8
+ import json
9
+ from packaging import version
10
+ import torch.distributed as dist
11
+
12
+ from .Tool_model import AutomaticWeightedLoss
13
+ from .Numeric import AttenNumeric
14
+ from .KE_model import KE_model
15
+ # from modeling_transformer import Transformer
16
+
17
+
18
+ from .bert import BertModel, BertTokenizer, BertForMaskedLM, BertConfig
19
+ import torch.nn.functional as F
20
+
21
+ from copy import deepcopy
22
+ from src.utils import torch_accuracy
23
+ # 4.21.2
24
+
25
+
26
+ def debug(input, kk, begin=None):
27
+ aaa = deepcopy(input[0])
28
+ if begin is None:
29
+ aaa.input_ids = input[0].input_ids[:kk]
30
+ aaa.attention_mask = input[0].attention_mask[:kk]
31
+ aaa.chinese_ref = input[0].chinese_ref[:kk]
32
+ aaa.kpi_ref = input[0].kpi_ref[:kk]
33
+ aaa.labels = input[0].labels[:kk]
34
+ else:
35
+ aaa.input_ids = input[0].input_ids[begin:kk]
36
+ aaa.attention_mask = input[0].attention_mask[begin:kk]
37
+ aaa.chinese_ref = input[0].chinese_ref[begin:kk]
38
+ aaa.kpi_ref = input[0].kpi_ref[begin:kk]
39
+ aaa.labels = input[0].labels[begin:kk]
40
+
41
+ return aaa
42
+
43
+
44
+ class HWBert(nn.Module):
45
+ def __init__(self, args):
46
+ super().__init__()
47
+ self.loss_awl = AutomaticWeightedLoss(args.awl_num, args)
48
+ self.args = args
49
+ self.config = BertConfig()
50
+ model_name = args.model_name
51
+ if args.model_name in ['TeleBert', 'TeleBert2', 'TeleBert3']:
52
+ self.encoder = BertForMaskedLM.from_pretrained(osp.join(args.data_root, 'transformer', model_name))
53
+ # initialize the predictions layer from MacBert
54
+ if args.cls_head_init:
55
+ tmp = BertForMaskedLM.from_pretrained(osp.join(args.data_root, 'transformer', 'MacBert'))
56
+ self.encoder.cls.predictions = tmp.cls.predictions
57
+ else:
58
+ if not osp.exists(osp.join(args.data_root, 'transformer', args.model_name)):
59
+ model_name = 'MacBert'
60
+ self.encoder = BertForMaskedLM.from_pretrained(osp.join(args.data_root, 'transformer', model_name))
61
+ self.numeric_model = AttenNumeric(self.args)
62
+
63
+ # ----------------------- main forward function ----------------------------------
64
+ def forward(self, input):
65
+ mask_loss, kpi_loss, kpi_loss_weight, kpi_loss_dict = self.mask_forward(input)
66
+ mask_loss = mask_loss.loss
67
+ loss_dic = {}
68
+ if not self.args.use_kpi_loss:
69
+ kpi_loss = None
70
+ if kpi_loss is not None:
71
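+ # combine the MLM loss with the (0.3-scaled) KPI numeric loss via learned uncertainty weights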
+ loss_sum = self.loss_awl(mask_loss, 0.3 * kpi_loss)
72
+ loss_dic['kpi_loss'] = kpi_loss.item()
73
+ else:
74
+ loss_sum = self.loss_awl(mask_loss)
75
+ loss_dic['mask_loss'] = mask_loss.item()
76
+ return {
77
+ 'loss': loss_sum,
78
+ 'loss_dic': loss_dic,
79
+ 'loss_weight': self.loss_awl.params.tolist(),
80
+ 'kpi_loss_weight': kpi_loss_weight,
81
+ 'kpi_loss_dict': kpi_loss_dict
82
+ }
83
+
84
+ # loss_sum, loss_dic, self.loss_awl.params.tolist(), kpi_loss_weight, kpi_loss_dict
85
+
86
+ # ----------------------------------------------------------------
87
+ # evaluation code: check whether the masked tokens are predicted correctly
88
+ def mask_prediction(self, inputs, tokenizer_sz, topk=(1,)):
89
+ token_num, token_right, word_num, word_right = None, None, None, None
90
+ outputs, kpi_loss, kpi_loss_weight, kpi_loss_dict = self.mask_forward(inputs)
91
+ inputs = inputs['labels'].view(-1)
92
+ input_list = inputs.tolist()
93
+ # the tokens that were masked
94
+ change_token_index = [i for i, x in enumerate(input_list) if x != -100]
95
+ change_token = torch.tensor(change_token_index)
96
+ inputs_used = inputs[change_token]
97
+ pred = outputs.logits.view(-1, tokenizer_sz)
98
+ pred_used = pred[change_token].cpu()
99
+ # the returned list
100
+ # compute the accuracy
101
+ acc, token_right = torch_accuracy(pred_used, inputs_used, topk)
102
+ # compute the perplexity score
103
+
104
+ token_num = inputs_used.shape[0]
105
+ # TODO: add word_num, word_right
106
+ # token_right:list
107
+ return token_num, token_right, outputs.loss.item()
108
+
109
+ def mask_forward(self, inputs):
110
+ kpi_ref = None
111
+ if 'kpi_ref' in inputs:
112
+ kpi_ref = inputs['kpi_ref']
113
+
114
+ outputs, kpi_loss, kpi_loss_weight, kpi_loss_dict = self.encoder(
115
+ input_ids=inputs['input_ids'].cuda(),
116
+ attention_mask=inputs['attention_mask'].cuda(),
117
+ # token_type_ids=inputs.token_type_ids.cuda(),
118
+ labels=inputs['labels'].cuda(),
119
+ kpi_ref=kpi_ref,
120
+ kpi_model=self.numeric_model
121
+ )
122
+ return outputs, kpi_loss, kpi_loss_weight, kpi_loss_dict
123
+
124
+ # TODO: consider axial (vertical) attention: https://github.com/lucidrains/axial-attention
125
+
126
+ def cls_embedding(self, inputs, tp='cls'):
127
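+ # Pooling options for the sentence embedding: 'cls' takes the final-layer [CLS]
+ # vector; the '*avg' variants mean-pool the non-padding token states of the last
+ # one, two or three layers (or first+last), which can give more stable sentence vectors.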
+ hidden_states = self.encoder(
128
+ input_ids=inputs['input_ids'].cuda(),
129
+ attention_mask=inputs['attention_mask'].cuda(),
130
+ output_hidden_states=True)[0].hidden_states
131
+ if tp == 'cls':
132
+ return hidden_states[-1][:, 0]
133
+ else:
134
+ index_real = torch.tensor(inputs['input_ids'].clone().detach(), dtype=torch.bool)
135
+ res = []
136
+ for i in range(hidden_states[-1].shape[0]):
137
+ if tp == 'last_avg':
138
+ res.append(hidden_states[-1][i][index_real[i]][:-1].mean(dim=0))
139
+ elif tp == 'last2avg':
140
+ res.append((hidden_states[-1][i][index_real[i]][:-1] + hidden_states[-2][i][index_real[i]][:-1]).mean(dim=0))
141
+ elif tp == 'last3avg':
142
+ res.append((hidden_states[-1][i][index_real[i]][:-1] + hidden_states[-2][i][index_real[i]][:-1] + hidden_states[-3][i][index_real[i]][:-1]).mean(dim=0))
143
+ elif tp == 'first_last_avg':
144
+ res.append((hidden_states[-1][i][index_real[i]][:-1] + hidden_states[1][i][index_real[i]][:-1]).mean(dim=0))
145
+
146
+ return torch.stack(res)
KTeleBERT/model/KE_model.py ADDED
@@ -0,0 +1,451 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from sklearn.metrics import average_precision_score
6
+ from tqdm import tqdm
7
+ import pdb
8
+ from torch.utils.data import DataLoader
9
+ from collections import defaultdict
10
+ import os.path as osp
11
+ import json
12
+
13
+
14
+ class KE_model(nn.Module):
15
+ def __init__(self, args):
16
+ super().__init__()
17
+ """
18
+ triple task: mask tail entity, total entity size-class classification
19
+ """
20
+ """
21
+ :param hidden: BERT model output size
22
+ """
23
+ self.args = args
24
+ self.ke_dim = args.ke_dim
25
+
26
+ self.linear_ent = nn.Linear(args.hidden_size, self.ke_dim)
27
+ self.linear_rel = nn.Linear(args.hidden_size, self.ke_dim)
28
+
29
+ self.ke_margin = nn.Parameter(
30
+ torch.Tensor([args.ke_margin]),
31
+ requires_grad=False
32
+ )
33
+
34
+ def forward(self, batch, hw_model):
35
+ batch_triple = batch
36
+ pos_sample = batch_triple["positive_sample"]
37
+ neg_sample = batch_triple["negative_sample"]
38
+ neg_index = batch_triple["neg_index"]
39
+
40
+ # save GPU memory
41
+ all_entity = []
42
+ all_entity_mask = []
43
+ for i in range(3):
44
+ all_entity.append(pos_sample[i]['input_ids'])
45
+ all_entity_mask.append(pos_sample[i]['attention_mask'])
46
+
47
+ all_entity = torch.cat(all_entity)
48
+ all_entity_mask = torch.cat(all_entity_mask)
49
+ entity_data = {'input_ids':all_entity, 'attention_mask':all_entity_mask}
50
+ entity_emb = hw_model.cls_embedding(entity_data, tp=self.args.plm_emb_type)
51
+
52
+ bs = pos_sample[0]['input_ids'].shape[0]
53
+ pos_sample_emb= [entity_emb[:bs], entity_emb[bs:2*bs], entity_emb[2*bs:3*bs]]
54
+ neg_sample_emb = entity_emb[neg_index]
55
+ mode = batch_triple["mode"]
56
+ # pos_score = self.get_score(pos_sample, hw_model)
57
+ # neg_score = self.get_score(pos_sample, hw_model, neg_sample, mode)
58
+ pos_score = self.get_score(pos_sample_emb, hw_model)
59
+ neg_score = self.get_score(pos_sample_emb, hw_model, neg_sample_emb, mode)
60
+ triple_loss = self.adv_loss(pos_score, neg_score, self.args)
61
+
62
+ return triple_loss
63
+
64
+ # pdb.set_trace()
65
+ # return emb.div_(emb.detach().norm(p=1, dim=-1, keepdim=True))
66
+
67
+ # KE loss
68
+ def tri2emb(self, triples, hw_model, negs=None, mode="single"):
69
+ """Get embedding of triples.
70
+ This function get the embeddings of head, relation, and tail
71
+ respectively. each embedding has three dimensions.
72
+ Args:
73
+ triples (tensor): This tensor save triples id, which dimension is
74
+ [triples number, 3].
75
+ negs (tensor, optional): This tenosr store the id of the entity to
76
+ be replaced, which has one dimension. when negs is None, it is
77
+ in the test/eval phase. Defaults to None.
78
+ mode (str, optional): This arg indicates that the negative entity
79
+ will replace the head or tail entity. when it is 'single', it
80
+ means that entity will not be replaced. Defaults to 'single'.
81
+ Returns:
82
+ head_emb: Head entity embedding.
83
+ relation_emb: Relation embedding.
84
+ tail_emb: Tail entity embedding.
85
+ """
86
+
87
+ if mode == "single":
88
+ head_emb = self.get_embedding(triples[0]).unsqueeze(1) # [bs, 1, dim]
89
+ relation_emb = self.get_embedding(triples[1], is_ent=False).unsqueeze(1) # [bs, 1, dim]
90
+ tail_emb = self.get_embedding(triples[2]).unsqueeze(1) # [bs, 1, dim]
91
+
92
+ elif mode == "head-batch" or mode == "head_predict":
93
+ if negs is None: # negs is None only during evaluation, so all entity embeddings are used directly
94
+ # TODO: the KGC test case is not handled for now
95
+ head_emb = self.ent_emb.weight.data.unsqueeze(0) # [1, num_ent, dim]
96
+ else:
97
+ head_emb = self.get_embedding(negs).reshape(-1, self.args.neg_num, self.args.ke_dim) # [bs, num_neg, dim]
98
+ relation_emb = self.get_embedding(triples[1], is_ent=False).unsqueeze(1) # [bs, 1, dim]
99
+ tail_emb = self.get_embedding(triples[2]).unsqueeze(1) # [bs, 1, dim]
100
+
101
+ elif mode == "tail-batch" or mode == "tail_predict":
102
+ head_emb = self.get_embedding(triples[0]).unsqueeze(1) # [bs, 1, dim]
103
+ relation_emb = self.get_embedding(triples[1], is_ent=False).unsqueeze(1) # [bs, 1, dim]
104
+ if negs is None:
105
+ tail_emb = self.ent_emb.weight.data.unsqueeze(0) # [1, num_ent, dim]
106
+ else:
107
+ # pdb.set_trace()
108
+ tail_emb = self.get_embedding(negs).reshape(-1, self.args.neg_num, self.args.ke_dim) # [bs, num_neg, dim]
109
+
110
+ return head_emb, relation_emb, tail_emb
111
+
112
+ def get_embedding(self, inputs, is_ent=True):
113
+ # pdb.set_trace()
114
+ if is_ent:
115
+ return self.linear_ent(inputs)
116
+ else:
117
+ return self.linear_rel(inputs)
118
+
119
+ def score_func(self, head_emb, relation_emb, tail_emb):
120
+ """Calculating the score of triples.
121
+
122
+ The formula for calculating the score is :math:`\gamma - ||h + r - t||_1`
123
+ Args:
124
+ head_emb: The head entity embedding.
125
+ relation_emb: The relation embedding.
126
+ tail_emb: The tail entity embedding.
127
+ mode: Choose head-predict or tail-predict.
128
+ Returns:
129
+ score: The score of triples.
130
+ """
131
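+ # TransE-style scoring: score = margin - ||h + r - t||_1, so higher scores mean more plausible triples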
+ score = (head_emb + relation_emb) - tail_emb
132
+ # pdb.set_trace()
133
+ score = self.ke_margin.item() - torch.norm(score, p=1, dim=-1)
134
+ return score
135
+
136
+ def get_score(self, triples, hw_model, negs=None, mode='single'):
137
+ """The functions used in the training phase
138
+
139
+ Args:
140
+ triples: The triples ids, as (h, r, t), shape:[batch_size, 3].
141
+ negs: Negative samples, defaults to None.
142
+ mode: Choose head-predict or tail-predict, Defaults to 'single'.
143
+
144
+ Returns:
145
+ score: The score of triples.
146
+ """
147
+ head_emb, relation_emb, tail_emb = self.tri2emb(triples, hw_model, negs, mode)
148
+ score = self.score_func(head_emb, relation_emb, tail_emb)
149
+
150
+ return score
151
+
152
+ def adv_loss(self, pos_score, neg_score, args):
153
+ """Negative sampling loss with self-adversarial training. In math:
154
+
155
+ L=-\log \sigma\left(\gamma-d_{r}(\mathbf{h}, \mathbf{t})\right)-\sum_{i=1}^{n} p\left(h_{i}^{\prime}, r, t_{i}^{\prime}\right) \log \sigma\left(d_{r}\left(\mathbf{h}_{i}^{\prime}, \mathbf{t}_{i}^{\prime}\right)-\gamma\right)
156
+
157
+ Args:
158
+ pos_score: The score of positive samples.
159
+ neg_score: The score of negative samples.
160
+ subsampling_weight: The weight for correcting pos_score and neg_score.
161
+
162
+ Returns:
163
+ loss: The training loss for back propagation.
164
+ """
165
+ neg_score = (F.softmax(neg_score * args.adv_temp, dim=1).detach()
166
+ * F.logsigmoid(-neg_score)).sum(dim=1) # shape:[bs]
167
+ pos_score = F.logsigmoid(pos_score).view(neg_score.shape[0]) # shape:[bs]
168
+ positive_sample_loss = - pos_score.mean()
169
+ negative_sample_loss = - neg_score.mean()
170
+ loss = (positive_sample_loss + negative_sample_loss) / 2
171
+ return loss
172
+
173
+
174
+ class KGEModel(nn.Module):
175
+ def __init__(self, nentity, nrelation, hidden_dim, gamma, entity_embedding, relation_embedding):
176
+ super(KGEModel, self).__init__()
177
+ self.nentity = nentity
178
+ self.nrelation = nrelation
179
+ self.hidden_dim = hidden_dim
180
+
181
+ self.gamma = nn.Parameter(
182
+ torch.Tensor([gamma]),
183
+ requires_grad=False
184
+ )
185
+ self.entity_embedding = entity_embedding
186
+ self.relation_embedding = relation_embedding
187
+
188
+ assert self.relation_embedding.shape[0] == nrelation
189
+ assert self.entity_embedding.shape[0] == nentity
190
+
191
+ def forward(self, sample, mode='single'):
192
+ '''
193
+ Forward function that calculate the score of a batch of triples.
194
+ In the 'single' mode, sample is a batch of triple.
195
+ In the 'head-batch' or 'tail-batch' mode, sample consists two part.
196
+ The first part is usually the positive sample.
197
+ And the second part is the entities in the negative samples.
198
+ Because negative samples and positive samples usually share two elements
199
+ in their triple ((head, relation) or (relation, tail)).
200
+ '''
201
+
202
+ if mode == 'single':
203
+ batch_size, negative_sample_size = sample.size(0), 1
204
+
205
+ head = torch.index_select(
206
+ self.entity_embedding,
207
+ dim=0,
208
+ index=sample[:, 0]
209
+ ).unsqueeze(1)
210
+
211
+ relation = torch.index_select(
212
+ self.relation_embedding,
213
+ dim=0,
214
+ index=sample[:, 1]
215
+ ).unsqueeze(1)
216
+
217
+ tail = torch.index_select(
218
+ self.entity_embedding,
219
+ dim=0,
220
+ index=sample[:, 2]
221
+ ).unsqueeze(1)
222
+
223
+ elif mode == 'head-batch':
224
+ tail_part, head_part = sample
225
+ batch_size, negative_sample_size = head_part.size(0), head_part.size(1)
226
+
227
+ head = torch.index_select(
228
+ self.entity_embedding,
229
+ dim=0,
230
+ index=head_part.view(-1)
231
+ ).view(batch_size, negative_sample_size, -1)
232
+
233
+ relation = torch.index_select(
234
+ self.relation_embedding,
235
+ dim=0,
236
+ index=tail_part[:, 1]
237
+ ).unsqueeze(1)
238
+
239
+ tail = torch.index_select(
240
+ self.entity_embedding,
241
+ dim=0,
242
+ index=tail_part[:, 2]
243
+ ).unsqueeze(1)
244
+
245
+ elif mode == 'tail-batch':
246
+ head_part, tail_part = sample
247
+ batch_size, negative_sample_size = tail_part.size(0), tail_part.size(1)
248
+
249
+ head = torch.index_select(
250
+ self.entity_embedding,
251
+ dim=0,
252
+ index=head_part[:, 0]
253
+ ).unsqueeze(1)
254
+
255
+ relation = torch.index_select(
256
+ self.relation_embedding,
257
+ dim=0,
258
+ index=head_part[:, 1]
259
+ ).unsqueeze(1)
260
+
261
+ tail = torch.index_select(
262
+ self.entity_embedding,
263
+ dim=0,
264
+ index=tail_part.view(-1)
265
+ ).view(batch_size, negative_sample_size, -1)
266
+
267
+ else:
268
+ raise ValueError('mode %s not supported' % mode)
269
+
270
+ score = self.TransE(head, relation, tail, mode)
271
+
272
+ return score
273
+
274
+ def TransE(self, head, relation, tail, mode):
275
+ if mode == 'head-batch':
276
+ score = head + (relation - tail)
277
+ else:
278
+ score = (head + relation) - tail
279
+
280
+ score = self.gamma.item() - torch.norm(score, p=1, dim=-1)
281
+ return score
282
+
283
+ @torch.no_grad()
284
+ def test_step(self, test_triples, all_true_triples, args, nentity, nrelation):
285
+ '''
286
+ Evaluate the model on test or valid datasets
287
+ '''
288
+ # Otherwise use standard (filtered) MRR, MR, HITS@1, HITS@3, and HITS@10 metrics
289
+ # Prepare dataloader for evaluation
290
+ test_dataloader_head = DataLoader(
291
+ KGTestDataset(
292
+ test_triples,
293
+ all_true_triples,
294
+ nentity,
295
+ nrelation,
296
+ 'head-batch'
297
+ ),
298
+ batch_size=args.batch_size,
299
+ num_workers=args.workers,
300
+ persistent_workers=True,
301
+ collate_fn=KGTestDataset.collate_fn
302
+ )
303
+
304
+ test_dataloader_tail = DataLoader(
305
+ KGTestDataset(
306
+ test_triples,
307
+ all_true_triples,
308
+ nentity,
309
+ nrelation,
310
+ 'tail-batch'
311
+ ),
312
+ batch_size=args.batch_size,
313
+ num_workers=args.workers,
314
+ persistent_workers=True,
315
+ collate_fn=KGTestDataset.collate_fn
316
+ )
317
+
318
+ test_dataset_list = [test_dataloader_head, test_dataloader_tail]
319
+
320
+ logs = []
321
+
322
+ step = 0
323
+ total_steps = sum([len(dataset) for dataset in test_dataset_list])
324
+
325
+ # pdb.set_trace()
326
+ with tqdm(total=total_steps) as _tqdm:
327
+ _tqdm.set_description(f'eval KGC')
328
+ for test_dataset in test_dataset_list:
329
+ for positive_sample, negative_sample, filter_bias, mode in test_dataset:
330
+
331
+ positive_sample = positive_sample.cuda()
332
+ negative_sample = negative_sample.cuda()
333
+ filter_bias = filter_bias.cuda()
334
+
335
+ batch_size = positive_sample.size(0)
336
+
337
+ score = self.forward((positive_sample, negative_sample), mode)
338
+ score += filter_bias
339
+
340
+ # Explicitly sort all the entities to ensure that there is no test exposure bias
341
+ argsort = torch.argsort(score, dim=1, descending=True)
342
+
343
+ if mode == 'head-batch':
344
+ positive_arg = positive_sample[:, 0]
345
+ elif mode == 'tail-batch':
346
+ positive_arg = positive_sample[:, 2]
347
+ else:
348
+ raise ValueError('mode %s not supported' % mode)
349
+
350
+ for i in range(batch_size):
351
+ # Notice that argsort is not ranking
352
+ # ranking = (argsort[i, :] == positive_arg[i]).nonzero()
353
+ ranking = (argsort[i, :] == positive_arg[i]).nonzero(as_tuple=False)
354
+ assert ranking.size(0) == 1
355
+
356
+ # ranking + 1 is the true ranking used in evaluation metrics
357
+ ranking = 1 + ranking.item()
358
+ logs.append({
359
+ 'MRR': 1.0 / ranking,
360
+ 'MR': float(ranking),
361
+ 'HITS@1': 1.0 if ranking <= 1 else 0.0,
362
+ 'HITS@3': 1.0 if ranking <= 3 else 0.0,
363
+ 'HITS@10': 1.0 if ranking <= 10 else 0.0,
364
+ })
365
+
366
+ # if step % args.test_log_steps == 0:
367
+ # logging.info('Evaluating the model... (%d/%d)' % (step, total_steps))
368
+ _tqdm.update(1)
369
+ _tqdm.set_description(f'KGC Eval:')
370
+ step += 1
371
+
372
+ metrics = {}
373
+ for metric in logs[0].keys():
374
+ metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
375
+
376
+ return metrics
377
+
378
+
379
+ # a dataset designed specifically for KGE evaluation
380
+ class KGTestDataset(torch.utils.data.Dataset):
381
+ def __init__(self, triples, all_true_triples, nentity, nrelation, mode, head4rel_tail=None, tail4head_rel=None):
382
+ self.len = len(triples)
383
+ self.triple_set = set(all_true_triples)
384
+ self.triples = triples
385
+
386
+ # computed from the data statistics
387
+ self.nentity = nentity
388
+ self.nrelation = nrelation
389
+ self.mode = mode
390
+
391
+ # head entities for a given (relation, tail) pair
392
+ # print("build head4rel_tail")
393
+ # self.head4rel_tail = self.find_head4rel_tail()
394
+ # print("build tail4head_rel")
395
+ # self.tail4head_rel = self.find_tail4head_rel()
396
+
397
+ def __len__(self):
398
+ return self.len
399
+
400
+ def find_head4rel_tail(self):
401
+ ans = defaultdict(list)
402
+ for (h, r, t) in self.triple_set:
403
+ ans[(r, t)].append(h)
404
+ return ans
405
+
406
+ def find_tail4head_rel(self):
407
+ ans = defaultdict(list)
408
+ for (h, r, t) in self.triple_set:
409
+ ans[(h, r)].append(t)
410
+ return ans
411
+
412
+ def __getitem__(self, idx):
413
+ head, relation, tail = self.triples[idx]
414
+
415
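+ # Filtered evaluation: candidate entities that form a true triple elsewhere in the
+ # KG get a -100 bias so they are pushed below the test triple in the ranking.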
+ if self.mode == 'head-batch':
416
+ tmp = [(0, rand_head) if (rand_head, relation, tail) not in self.triple_set
417
+ else (-100, head) for rand_head in range(self.nentity)]
418
+ tmp[head] = (0, head)
419
+ elif self.mode == 'tail-batch':
420
+ tmp = [(0, rand_tail) if (head, relation, rand_tail) not in self.triple_set
421
+ else (-100, tail) for rand_tail in range(self.nentity)]
422
+ tmp[tail] = (0, tail)
423
+ else:
424
+ raise ValueError('negative batch mode %s not supported' % self.mode)
425
+ # if self.mode == 'head-batch':
426
+ #
427
+ # tmp = [(0, rand_head) if rand_head not in self.head4rel_tail[(relation, tail)]
428
+ # else (-100, head) for rand_head in range(self.nentity)]
429
+ # tmp[head] = (0, head)
430
+ # elif self.mode == 'tail-batch':
431
+ # tmp = [(0, rand_tail) if rand_tail not in self.tail4head_rel[(head, relation)]
432
+ # else (-100, tail) for rand_tail in range(self.nentity)]
433
+ # tmp[tail] = (0, tail)
434
+ # else:
435
+ # raise ValueError('negative batch mode %s not supported' % self.mode)
436
+
437
+ tmp = torch.LongTensor(tmp)
438
+ filter_bias = tmp[:, 0].float()
439
+ negative_sample = tmp[:, 1]
440
+
441
+ positive_sample = torch.LongTensor((head, relation, tail))
442
+
443
+ return positive_sample, negative_sample, filter_bias, self.mode
444
+
445
+ @staticmethod
446
+ def collate_fn(data):
447
+ positive_sample = torch.stack([_[0] for _ in data], dim=0)
448
+ negative_sample = torch.stack([_[1] for _ in data], dim=0)
449
+ filter_bias = torch.stack([_[2] for _ in data], dim=0)
450
+ mode = data[0][3]
451
+ return positive_sample, negative_sample, filter_bias, mode
KTeleBERT/model/Numeric.py ADDED
@@ -0,0 +1,218 @@
1
+ import types
2
+ import torch
3
+ import transformers
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ from torch.nn import CrossEntropyLoss
7
+ import numpy as np
8
+ import pdb
9
+ import math
10
+ from .Tool_model import AutomaticWeightedLoss
11
+ import os.path as osp
12
+ import json
13
+
14
+
15
+ def ortho_penalty(t):
16
+ return ((t @ t.T - torch.eye(t.shape[0]).cuda())**2).sum()
17
+
18
+
19
+ class AttenNumeric(nn.Module):
20
+ def __init__(self, config):
21
+ super(AttenNumeric, self).__init__()
22
+ # ----------- load kpi2id --------------------
23
+ kpi_file_path = osp.join(config.data_path, 'kpi2id.json')
24
+
25
+ with open(kpi_file_path, 'r') as f:
26
+ # pdb.set_trace()
27
+ kpi2id = json.load(f)
28
+ config.num_kpi = 303
29
+ # config.num_kpi = len(kpi2id)
30
+ # -------------------------------
31
+
32
+ self.config = config
33
+ self.fc = nn.Linear(1, config.hidden_size)
34
+ # self.actication = nn.ReLU()
35
+ self.actication = nn.LeakyReLU()
36
+ # self.embedding = nn.Linear(config.hidden_size, self.attention_head_size)
37
+ if config.contrastive_loss:
38
+ self.loss_awl = AutomaticWeightedLoss(3, config)
39
+ else:
40
+ self.loss_awl = AutomaticWeightedLoss(2, config)
41
+ self.encoder = AttNumEncoder(config)
42
+ self.decoder = AttNumDecoder(config)
43
+ self.classifier = NumClassifier(config)
44
+ self.ce_loss = nn.CrossEntropyLoss()
45
+
46
+ def contrastive_loss(self, hidden, kpi):
47
+ # in batch negative
48
+ bs_tmp = hidden.shape[0]
49
+ eye = torch.eye(bs_tmp).cuda()
50
+ hidden = F.normalize(hidden, dim=1)
51
+ # [12,12]
52
+ # subtract the identity matrix so that self-similarity does not affect the result
53
+ hidden_sim = (torch.matmul(hidden, hidden.T) - eye) / 0.07
54
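+ # Contrastive target: for each sample, the positive is the in-batch sample whose
+ # raw KPI value is numerically closest; cross-entropy at temperature 0.07 pulls
+ # their representations together.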
+ kpi = kpi.expand(-1, bs_tmp)
55
+ kpi_sim = torch.abs(kpi - kpi.T) + eye
56
+ kpi_sim = torch.min(kpi_sim, 1)[1]
57
+ sc_loss = self.ce_loss(hidden_sim, kpi_sim)
58
+ return sc_loss
59
+
60
+ def _encode(self, kpi, query):
61
+ kpi_emb = self.actication(self.fc(kpi))
62
+ # name_emb = self.embedding(query)
63
+ hidden, en_loss, scalar_list = self.encoder(kpi_emb, query)
64
+
65
+ # contrastive learning is meaningless with two or fewer samples
66
+ if self.config.contrastive_loss and hidden.shape[0] > 2:
67
+ con_loss = self.contrastive_loss(hidden.squeeze(1), kpi.squeeze(1))
68
+ else:
69
+ con_loss = None
70
+ hidden = self.actication(hidden)
71
+ assert query.shape[0] > 0
72
+ return hidden, en_loss, scalar_list, con_loss
73
+
74
+ def forward(self, kpi, query, kpi_id):
75
+ hidden, en_loss, scalar_list, con_loss = self._encode(kpi, query)
76
+ dec_kpi_score, de_loss = self.decoder(kpi, hidden)
77
+ cls_kpi, cls_loss = self.classifier(hidden, kpi_id)
78
+ if con_loss is not None:
79
+ # 0.001 * con_loss
80
+ loss_sum = self.loss_awl(de_loss, cls_loss, 0.1 * con_loss)
81
+ loss_all = loss_sum + en_loss
82
+ loss_dic = {'cls_loss': cls_loss.item(), 'reg_loss': de_loss.item(), 'orth_loss': en_loss.item(), 'con_loss': con_loss.item()}
83
+ # pdb.set_trace()
84
+ else:
85
+ loss_sum = self.loss_awl(de_loss, cls_loss)
86
+ loss_all = loss_sum + en_loss
87
+ loss_dic = {'cls_loss': cls_loss.item(), 'reg_loss': de_loss.item(), 'orth_loss': en_loss.item()}
88
+
89
+ return dec_kpi_score, cls_kpi, hidden, loss_all, self.loss_awl.params.tolist(), loss_dic, scalar_list
90
+
91
+
92
+ class AttNumEncoder(nn.Module):
93
+ def __init__(self, config):
94
+ super(AttNumEncoder, self).__init__()
95
+ self.num_l_layers = config.l_layers
96
+ self.layer = nn.ModuleList([AttNumLayer(config) for _ in range(self.num_l_layers)])
97
+
98
+ def forward(self, kpi_emb, name_emb):
99
+ loss = 0.
100
+ scalar_list = []
101
+ for layer_module in self.layer:
102
+ kpi_emb, orth_loss, scalar = layer_module(kpi_emb, name_emb)
103
+ loss += orth_loss
104
+ scalar_list.append(scalar)
105
+ return kpi_emb, loss, scalar_list
106
+
107
+
108
+ class AttNumDecoder(nn.Module):
109
+ def __init__(self, config):
110
+ super(AttNumDecoder, self).__init__()
111
+ self.dense_1 = nn.Linear(config.hidden_size, config.hidden_size)
112
+ self.dense_2 = nn.Linear(config.hidden_size, 1)
113
+ self.actication = nn.LeakyReLU()
114
+ self.loss_func = nn.MSELoss(reduction='mean')
115
+
116
+ def forward(self, kpi_label, hidden):
117
+ # reconstruct the (possibly anomalous) KPI value
118
+ pre = self.actication(self.dense_2(self.actication(self.dense_1(hidden))))
119
+ loss = self.loss_func(pre, kpi_label)
120
+ # pdb.set_trace()
121
+ return pre, loss
122
+
123
+
124
+ class NumClassifier(nn.Module):
125
+ def __init__(self, config):
126
+ super(NumClassifier, self).__init__()
127
+ self.dense_1 = nn.Linear(config.hidden_size, int(config.hidden_size / 3))
128
+ self.dense_2 = nn.Linear(int(config.hidden_size / 3), config.num_kpi)
129
+ self.loss_func = nn.CrossEntropyLoss()
130
+ # self.actication = nn.ReLU()
131
+ self.actication = nn.LeakyReLU()
132
+
133
+ def forward(self, hidden, kpi_id):
134
+ hidden = self.actication(self.dense_1(hidden))
135
+ pre = self.actication(self.dense_2(hidden)).squeeze(1)
136
+ loss = self.loss_func(pre, kpi_id)
137
+ return pre, loss
138
+
139
+
140
+ class AttNumLayer(nn.Module):
141
+ def __init__(self, config):
142
+ super(AttNumLayer, self).__init__()
143
+ self.config = config
144
+ # 768 / 8 = 96
145
+ self.num_attention_heads = config.num_attention_heads
146
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads) # 96
147
+ # self.head_size = config.hidden_size
148
+
149
+ # scaler
150
+ self.scalar = nn.Parameter(.3 * torch.ones(1, requires_grad=True))
151
+ self.key = nn.Parameter(torch.empty(self.num_attention_heads, self.attention_head_size))
152
+
153
+ self.dense_down = nn.Linear(config.hidden_size, 128)
154
+ self.dense_up = nn.Linear(128, config.hidden_size)
155
+
156
+ # name embedding
157
+ self.embedding = nn.Linear(config.hidden_size, self.attention_head_size)
158
+ # value projection: num_attention_heads hidden_size-dim value matrices, one per head
159
+ self.value = nn.Linear(config.hidden_size, config.hidden_size * self.num_attention_heads)
160
+
161
+ # add & norm
162
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
163
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
164
+ # 0.1
165
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
166
+
167
+ # for m in self.modules().modules():
168
+ # pdb.set_trace()
169
+
170
+ nn.init.kaiming_normal_(self.key, mode='fan_out', nonlinearity='leaky_relu')
171
+ # nn.init.orthogonal_(self.key)
172
+
173
+ def transpose_for_scores(self, x):
174
+ new_x_shape = x.size()[:-1] + (
175
+ self.num_attention_heads,
176
+ self.config.hidden_size,
177
+ )
178
+ x = x.view(*new_x_shape)
179
+ return x
180
+ # return x.permute(0, 2, 1, 3)
181
+
182
+ def forward(self, kpi_emb, name_emb):
183
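+ # The KPI name embedding acts as the attention query over a learned per-head key
+ # matrix, while the scalar KPI value embedding supplies the values, so the numeric
+ # value is encoded conditioned on which KPI it belongs to.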
+ # [64, 1, 96]
184
+ name_emb = self.embedding(name_emb)
185
+
186
+ mixed_value_layer = self.value(kpi_emb)
187
+
188
+ # [64, 1, 8, 768]
189
+ value_layer = self.transpose_for_scores(mixed_value_layer)
190
+
191
+ # key: [8, 96] self.key.transpose(-1, -2): [96, 8]
192
+ # name_emb: [64, 1, 96]
193
+ attention_scores = torch.matmul(name_emb, self.key.transpose(-1, -2))
194
+ # [64, 1, 8]
195
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
196
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
197
+
198
+ # This is actually dropping out entire tokens to attend to, which might
199
+ # seem a bit unusual, but is taken from the original Transformer paper.
200
+ attention_probs = self.dropout(attention_probs)
201
+ attention_probs = attention_probs.unsqueeze(1)
202
+ # weighted sum over the value vectors
203
+ # [64, 1, 1, 8] * [64, 1, 8, 768] = [64, 1, 1, 768]
204
+ context_layer = torch.matmul(attention_probs, value_layer)
205
+ # context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
206
+ new_context_layer_shape = context_layer.size()[:-2] + (self.config.hidden_size,)
207
+ context_layer = context_layer.view(*new_context_layer_shape)
208
+ # add & norm
209
+ output_emb = self.dense(context_layer)
210
+ output_emb = self.dropout(output_emb)
211
+ output_emb = self.LayerNorm(output_emb + self.scalar * self.dense_up(self.dense_down(kpi_emb)))
212
+ # output_emb = self.LayerNorm(self.LayerNorm(output_emb) + self.scalar * kpi_emb)
213
+ # pdb.set_trace()
214
+ wei = self.value.weight.chunk(8, dim=0)
215
+ orth_loss_value = sum([ortho_penalty(k) for k in wei])
216
+ # 0.01 * ortho_penalty(self.key) + ortho_penalty(self.value.weight)
217
+ orth_loss = 0.0001 * orth_loss_value + 0.0001 * ortho_penalty(self.dense.weight) + 0.01 * ((self.scalar[0])**2).sum()
218
+ return output_emb, orth_loss, self.scalar.tolist()[0]
KTeleBERT/model/OD_model.py ADDED
@@ -0,0 +1,74 @@
1
+ import os
2
+ import os.path as osp
3
+ import pdb
4
+ import torch
5
+ import torch.nn as nn
6
+ import numpy as np
7
+ # from transformers import BertModel, BertTokenizer, BertForMaskedLM
8
+ import json
9
+ from packaging import version
10
+ import torch.distributed as dist
11
+
12
+
13
+ class OD_model(nn.Module):
14
+ def __init__(self, args):
15
+ super().__init__()
16
+ self.args = args
17
+ self.order_num = args.order_num
18
+ if args.od_type == 'linear_cat':
19
+ # self.order_dense_1 = nn.Linear(args.hidden_size * self.order_num, args.hidden_size)
20
+ # self.order_dense_2 = nn.Linear(args.hidden_size, 1)
21
+ self.order_dense_1 = nn.Linear(args.hidden_size * self.order_num, args.hidden_size)
22
+ if self.args.num_od_layer > 0:
23
+ self.layer = nn.ModuleList([OD_Layer_linear(args) for _ in range(args.num_od_layer)])
24
+
25
+ self.order_dense_2 = nn.Linear(args.hidden_size, 1)
26
+
27
+ self.actication = nn.LeakyReLU()
28
+ self.bn = torch.nn.BatchNorm1d(args.hidden_size)
29
+ self.dp = nn.Dropout(p=args.hidden_dropout_prob)
30
+ self.loss_func = nn.BCEWithLogitsLoss()
31
+ # self.loss_func = nn.CrossEntropyLoss()
32
+
33
+ def forward(self, input, labels):
34
+ # split the input into two halves
35
+ # concatenate them along the feature dimension
36
+ loss_dic = {}
37
+ pre = self.predict(input)
38
+ # pdb.set_trace()
39
+ loss = self.loss_func(pre, labels.unsqueeze(1))
40
+ loss_dic['order_loss'] = loss.item()
41
+ return loss, loss_dic
42
+
43
+ def encode(self, input):
44
+ if self.args.num_od_layer > 0:
45
+ for layer_module in self.layer:
46
+ input = layer_module(input)
47
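+ # The batch stacks the two embeddings of each pair along dim 0 (assuming order_num == 2);
+ # split it in half and concatenate feature-wise so each row represents one ordered pair.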
+ inputs = torch.chunk(input, 2, dim=0)
48
+ emb = torch.concat(inputs, dim=1)
49
+ return self.actication(self.order_dense_1(self.dp(emb)))
50
+
51
+ def predict(self, input):
52
+ return self.order_dense_2(self.bn(self.encode(input)))
53
+
54
+ def right_caculate(self, input, labels, threshold=0.5):
55
+ input = input.squeeze(1).tolist()
56
+ labels = labels.tolist()
57
+ right = 0
58
+ for i in range(len(input)):
59
+ if (input[i] >= threshold and labels[i] >= 0.5) or (input[i] < threshold and labels[i] < 0.5):
60
+ right += 1
61
+ return right
62
+
63
+
64
+ class OD_Layer_linear(nn.Module):
65
+ def __init__(self, args):
66
+ super().__init__()
67
+ self.args = args
68
+ self.dense = nn.Linear(args.hidden_size, args.hidden_size)
69
+ self.actication = nn.LeakyReLU()
70
+ self.bn = torch.nn.BatchNorm1d(args.hidden_size)
71
+ self.dropout = nn.Dropout(p=args.hidden_dropout_prob)
72
+
73
+ def forward(self, input):
74
+ return self.actication(self.bn(self.dense(self.dropout(input))))
KTeleBERT/model/Tool_model.py ADDED
@@ -0,0 +1,34 @@
1
+ # -*- coding: UTF-8 -*-
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ # https://github.com/Mikoto10032/AutomaticWeightedLoss/blob/master/AutomaticWeightedLoss.py
7
+
8
+
9
+ class AutomaticWeightedLoss(nn.Module):
10
+ # '''
11
+ # automatically weighted multi-task loss
12
+ # Params:
13
+ # num: int, the number of losses
14
+ # x: multi-task loss
15
+ # Examples:
16
+ # loss1=1
17
+ # loss2=2
18
+ # awl = AutomaticWeightedLoss(2)
19
+ # loss_sum = awl(loss1, loss2)
20
+ # '''
21
+ def __init__(self, num=2, args=None):
22
+ super(AutomaticWeightedLoss, self).__init__()
23
+ if args is None or args.use_awl:
24
+ params = torch.ones(num, requires_grad=True)
25
+ self.params = torch.nn.Parameter(params)
26
+ else:
27
+ params = torch.ones(num, requires_grad=False)
28
+ self.params = torch.nn.Parameter(params, requires_grad=False)
29
+
30
+ def forward(self, *x):
31
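+ # Uncertainty-based multi-task weighting: each loss_i is scaled by 1 / (2 * sigma_i^2)
+ # with a log(1 + sigma_i^2) regularizer, where sigma_i = self.params[i] is learnable.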
+ loss_sum = 0
32
+ for i, loss in enumerate(x):
33
+ loss_sum += 0.5 / (self.params[i] ** 2) * loss + torch.log(1 + self.params[i] ** 2)
34
+ return loss_sum
KTeleBERT/model/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ # from .vector import Vector
2
+ # from .classifier import SimpleClassifier
3
+ # # from .updn import UpDn
4
+ # # from .ban import Ban
5
+
6
+ from .bert import (
7
+ BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
8
+ BertForMaskedLM,
9
+ BertForMultipleChoice,
10
+ BertForNextSentencePrediction,
11
+ BertForPreTraining,
12
+ BertForQuestionAnswering,
13
+ BertForSequenceClassification,
14
+ BertForTokenClassification,
15
+ BertLayer,
16
+ BertLMHeadModel,
17
+ BertModel,
18
+ BertPreTrainedModel,
19
+ load_tf_weights_in_bert,
20
+ )
21
+
22
+ from .bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
23
+ from .bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer
24
+ from .HWBert import HWBert
25
+ from .KE_model import KGEModel, KE_model
26
+ from .OD_model import OD_model
KTeleBERT/model/__pycache__/HWBert.cpython-38.pyc ADDED
Binary file (4.42 kB). View file
 
KTeleBERT/model/__pycache__/KE_model.cpython-38.pyc ADDED
Binary file (12.3 kB). View file
 
KTeleBERT/model/__pycache__/Numeric.cpython-38.pyc ADDED
Binary file (6.88 kB). View file
 
KTeleBERT/model/__pycache__/OD_model.cpython-38.pyc ADDED
Binary file (2.95 kB). View file
 
KTeleBERT/model/__pycache__/Tool_model.cpython-38.pyc ADDED
Binary file (1.01 kB). View file
 
KTeleBERT/model/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (849 Bytes). View file
 
KTeleBERT/model/bert/__init__.py ADDED
@@ -0,0 +1,201 @@
1
+ # flake8: noqa
2
+ # There's no way to ignore "F401 '...' imported but unused" warnings in this
3
+ # module, but to preserve other warnings. So, don't check this module at all.
4
+
5
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ from typing import TYPE_CHECKING
20
+
21
+ from transformers.utils import (
22
+ OptionalDependencyNotAvailable,
23
+ _LazyModule,
24
+ is_flax_available,
25
+ is_tensorflow_text_available,
26
+ is_tf_available,
27
+ is_tokenizers_available,
28
+ is_torch_available,
29
+ )
30
+
31
+
32
+ _import_structure = {
33
+ "configuration_bert": ["BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BertConfig", "BertOnnxConfig"],
34
+ "tokenization_bert": ["BasicTokenizer", "BertTokenizer", "WordpieceTokenizer"],
35
+ }
36
+
37
+ try:
38
+ if not is_tokenizers_available():
39
+ raise OptionalDependencyNotAvailable()
40
+ except OptionalDependencyNotAvailable:
41
+ pass
42
+ else:
43
+ _import_structure["tokenization_bert_fast"] = ["BertTokenizerFast"]
44
+
45
+ try:
46
+ if not is_torch_available():
47
+ raise OptionalDependencyNotAvailable()
48
+ except OptionalDependencyNotAvailable:
49
+ pass
50
+ else:
51
+ _import_structure["modeling_bert"] = [
52
+ "BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
53
+ "BertForMaskedLM",
54
+ "BertForMultipleChoice",
55
+ "BertForNextSentencePrediction",
56
+ "BertForPreTraining",
57
+ "BertForQuestionAnswering",
58
+ "BertForSequenceClassification",
59
+ "BertForTokenClassification",
60
+ "BertLayer",
61
+ "BertLMHeadModel",
62
+ "BertModel",
63
+ "BertPreTrainedModel",
64
+ "load_tf_weights_in_bert",
65
+ ]
66
+
67
+ try:
68
+ if not is_tf_available():
69
+ raise OptionalDependencyNotAvailable()
70
+ except OptionalDependencyNotAvailable:
71
+ pass
72
+ else:
73
+ _import_structure["modeling_tf_bert"] = [
74
+ "TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
75
+ "TFBertEmbeddings",
76
+ "TFBertForMaskedLM",
77
+ "TFBertForMultipleChoice",
78
+ "TFBertForNextSentencePrediction",
79
+ "TFBertForPreTraining",
80
+ "TFBertForQuestionAnswering",
81
+ "TFBertForSequenceClassification",
82
+ "TFBertForTokenClassification",
83
+ "TFBertLMHeadModel",
84
+ "TFBertMainLayer",
85
+ "TFBertModel",
86
+ "TFBertPreTrainedModel",
87
+ ]
88
+ try:
89
+ if not is_tensorflow_text_available():
90
+ raise OptionalDependencyNotAvailable()
91
+ except OptionalDependencyNotAvailable:
92
+ pass
93
+ else:
94
+ _import_structure["tokenization_bert_tf"] = ["TFBertTokenizer"]
95
+
96
+ try:
97
+ if not is_flax_available():
98
+ raise OptionalDependencyNotAvailable()
99
+ except OptionalDependencyNotAvailable:
100
+ pass
101
+ else:
102
+ _import_structure["modeling_flax_bert"] = [
103
+ "FlaxBertForCausalLM",
104
+ "FlaxBertForMaskedLM",
105
+ "FlaxBertForMultipleChoice",
106
+ "FlaxBertForNextSentencePrediction",
107
+ "FlaxBertForPreTraining",
108
+ "FlaxBertForQuestionAnswering",
109
+ "FlaxBertForSequenceClassification",
110
+ "FlaxBertForTokenClassification",
111
+ "FlaxBertModel",
112
+ "FlaxBertPreTrainedModel",
113
+ ]
114
+
115
+ if TYPE_CHECKING:
116
+ from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig, BertOnnxConfig
117
+ from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer
118
+
119
+ try:
120
+ if not is_tokenizers_available():
121
+ raise OptionalDependencyNotAvailable()
122
+ except OptionalDependencyNotAvailable:
123
+ pass
124
+ else:
125
+ from .tokenization_bert_fast import BertTokenizerFast
126
+
127
+ try:
128
+ if not is_torch_available():
129
+ raise OptionalDependencyNotAvailable()
130
+ except OptionalDependencyNotAvailable:
131
+ pass
132
+ else:
133
+ from .modeling_bert import (
134
+ BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
135
+ BertForMaskedLM,
136
+ BertForMultipleChoice,
137
+ BertForNextSentencePrediction,
138
+ BertForPreTraining,
139
+ BertForQuestionAnswering,
140
+ BertForSequenceClassification,
141
+ BertForTokenClassification,
142
+ BertLayer,
143
+ BertLMHeadModel,
144
+ BertModel,
145
+ BertPreTrainedModel,
146
+ load_tf_weights_in_bert,
147
+ )
148
+
149
+ try:
150
+ if not is_tf_available():
151
+ raise OptionalDependencyNotAvailable()
152
+ except OptionalDependencyNotAvailable:
153
+ pass
154
+ else:
155
+ from .modeling_tf_bert import (
156
+ TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
157
+ TFBertEmbeddings,
158
+ TFBertForMaskedLM,
159
+ TFBertForMultipleChoice,
160
+ TFBertForNextSentencePrediction,
161
+ TFBertForPreTraining,
162
+ TFBertForQuestionAnswering,
163
+ TFBertForSequenceClassification,
164
+ TFBertForTokenClassification,
165
+ TFBertLMHeadModel,
166
+ TFBertMainLayer,
167
+ TFBertModel,
168
+ TFBertPreTrainedModel,
169
+ )
170
+
171
+ try:
172
+ if not is_tensorflow_text_available():
173
+ raise OptionalDependencyNotAvailable()
174
+ except OptionalDependencyNotAvailable:
175
+ pass
176
+ else:
177
+ from .tokenization_bert_tf import TFBertTokenizer
178
+
179
+ try:
180
+ if not is_flax_available():
181
+ raise OptionalDependencyNotAvailable()
182
+ except OptionalDependencyNotAvailable:
183
+ pass
184
+ else:
185
+ from .modeling_flax_bert import (
186
+ FlaxBertForCausalLM,
187
+ FlaxBertForMaskedLM,
188
+ FlaxBertForMultipleChoice,
189
+ FlaxBertForNextSentencePrediction,
190
+ FlaxBertForPreTraining,
191
+ FlaxBertForQuestionAnswering,
192
+ FlaxBertForSequenceClassification,
193
+ FlaxBertForTokenClassification,
194
+ FlaxBertModel,
195
+ FlaxBertPreTrainedModel,
196
+ )
197
+
198
+ else:
199
+ import sys
200
+
201
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
KTeleBERT/model/bert/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (3.22 kB). View file
 
KTeleBERT/model/bert/__pycache__/configuration_bert.cpython-38.pyc ADDED
Binary file (8.86 kB). View file
 
KTeleBERT/model/bert/__pycache__/modeling_bert.cpython-38.pyc ADDED
Binary file (57.4 kB). View file
 
KTeleBERT/model/bert/__pycache__/tokenization_bert.cpython-38.pyc ADDED
Binary file (19 kB). View file
 
KTeleBERT/model/bert/configuration_bert.py ADDED
@@ -0,0 +1,191 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ BERT model configuration"""
17
+ from collections import OrderedDict
18
+ from typing import Mapping
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.onnx import OnnxConfig
22
+ from transformers.utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
28
+ "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/config.json",
29
+ "bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/config.json",
30
+ "bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/config.json",
31
+ "bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/config.json",
32
+ "bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/config.json",
33
+ "bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/config.json",
34
+ "bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/config.json",
35
+ "bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/config.json",
36
+ "bert-large-uncased-whole-word-masking": (
37
+ "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/config.json"
38
+ ),
39
+ "bert-large-cased-whole-word-masking": (
40
+ "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/config.json"
41
+ ),
42
+ "bert-large-uncased-whole-word-masking-finetuned-squad": (
43
+ "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/config.json"
44
+ ),
45
+ "bert-large-cased-whole-word-masking-finetuned-squad": (
46
+ "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/config.json"
47
+ ),
48
+ "bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/config.json",
49
+ "bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/config.json",
50
+ "bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/config.json",
51
+ "cl-tohoku/bert-base-japanese": "https://huggingface.co/cl-tohoku/bert-base-japanese/resolve/main/config.json",
52
+ "cl-tohoku/bert-base-japanese-whole-word-masking": (
53
+ "https://huggingface.co/cl-tohoku/bert-base-japanese-whole-word-masking/resolve/main/config.json"
54
+ ),
55
+ "cl-tohoku/bert-base-japanese-char": (
56
+ "https://huggingface.co/cl-tohoku/bert-base-japanese-char/resolve/main/config.json"
57
+ ),
58
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking": (
59
+ "https://huggingface.co/cl-tohoku/bert-base-japanese-char-whole-word-masking/resolve/main/config.json"
60
+ ),
61
+ "TurkuNLP/bert-base-finnish-cased-v1": (
62
+ "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/config.json"
63
+ ),
64
+ "TurkuNLP/bert-base-finnish-uncased-v1": (
65
+ "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/config.json"
66
+ ),
67
+ "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/config.json",
68
+ # See all BERT models at https://huggingface.co/models?filter=bert
69
+ }
70
+
71
+
72
+ class BertConfig(PretrainedConfig):
73
+ r"""
74
+ This is the configuration class to store the configuration of a [`BertModel`] or a [`TFBertModel`]. It is used to
75
+ instantiate a BERT model according to the specified arguments, defining the model architecture. Instantiating a
76
+ configuration with the defaults will yield a similar configuration to that of the BERT
77
+ [bert-base-uncased](https://huggingface.co/bert-base-uncased) architecture.
78
+
79
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
80
+ documentation from [`PretrainedConfig`] for more information.
81
+
82
+
83
+ Args:
84
+ vocab_size (`int`, *optional*, defaults to 30522):
85
+ Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
86
+ `inputs_ids` passed when calling [`BertModel`] or [`TFBertModel`].
87
+ hidden_size (`int`, *optional*, defaults to 768):
88
+ Dimensionality of the encoder layers and the pooler layer.
89
+ num_hidden_layers (`int`, *optional*, defaults to 12):
90
+ Number of hidden layers in the Transformer encoder.
91
+ num_attention_heads (`int`, *optional*, defaults to 12):
92
+ Number of attention heads for each attention layer in the Transformer encoder.
93
+ intermediate_size (`int`, *optional*, defaults to 3072):
94
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
95
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
96
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
97
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
98
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
99
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
100
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
101
+ The dropout ratio for the attention probabilities.
102
+ max_position_embeddings (`int`, *optional*, defaults to 512):
103
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
104
+ just in case (e.g., 512 or 1024 or 2048).
105
+ type_vocab_size (`int`, *optional*, defaults to 2):
106
+ The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or [`TFBertModel`].
107
+ initializer_range (`float`, *optional*, defaults to 0.02):
108
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
109
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
110
+ The epsilon used by the layer normalization layers.
111
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
112
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
113
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
114
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
115
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
116
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
117
+ use_cache (`bool`, *optional*, defaults to `True`):
118
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
119
+ relevant if `config.is_decoder=True`.
120
+ classifier_dropout (`float`, *optional*):
121
+ The dropout ratio for the classification head.
122
+
123
+ Examples:
124
+
125
+ ```python
126
+ >>> from transformers import BertModel, BertConfig
127
+
128
+ >>> # Initializing a BERT bert-base-uncased style configuration
129
+ >>> configuration = BertConfig()
130
+
131
+ >>> # Initializing a model from the bert-base-uncased style configuration
132
+ >>> model = BertModel(configuration)
133
+
134
+ >>> # Accessing the model configuration
135
+ >>> configuration = model.config
136
+ ```"""
137
+ model_type = "bert"
138
+
139
+ def __init__(
140
+ self,
141
+ vocab_size=30522,
142
+ hidden_size=768,
143
+ num_hidden_layers=12,
144
+ num_attention_heads=12,
145
+ intermediate_size=3072,
146
+ hidden_act="gelu",
147
+ hidden_dropout_prob=0.1,
148
+ attention_probs_dropout_prob=0.1,
149
+ max_position_embeddings=512,
150
+ type_vocab_size=2,
151
+ initializer_range=0.02,
152
+ layer_norm_eps=1e-12,
153
+ pad_token_id=0,
154
+ position_embedding_type="absolute",
155
+ use_cache=True,
156
+ classifier_dropout=None,
157
+ **kwargs
158
+ ):
159
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
160
+
161
+ self.vocab_size = vocab_size
162
+ self.hidden_size = hidden_size
163
+ self.num_hidden_layers = num_hidden_layers
164
+ self.num_attention_heads = num_attention_heads
165
+ self.hidden_act = hidden_act
166
+ self.intermediate_size = intermediate_size
167
+ self.hidden_dropout_prob = hidden_dropout_prob
168
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
169
+ self.max_position_embeddings = max_position_embeddings
170
+ self.type_vocab_size = type_vocab_size
171
+ self.initializer_range = initializer_range
172
+ self.layer_norm_eps = layer_norm_eps
173
+ self.position_embedding_type = position_embedding_type
174
+ self.use_cache = use_cache
175
+ self.classifier_dropout = classifier_dropout
176
+
177
+
178
+ class BertOnnxConfig(OnnxConfig):
179
+ @property
180
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
181
+ if self.task == "multiple-choice":
182
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
183
+ else:
184
+ dynamic_axis = {0: "batch", 1: "sequence"}
185
+ return OrderedDict(
186
+ [
187
+ ("input_ids", dynamic_axis),
188
+ ("attention_mask", dynamic_axis),
189
+ ("token_type_ids", dynamic_axis),
190
+ ]
191
+ )
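The `inputs` property above is what an ONNX export consults to decide which tensor dimensions are dynamic. A small inspection sketch (illustrative only; the vendored import path is an assumption, and `BertOnnxConfig` is constructed with its default task):

    # Illustrative sketch: build the vendored config classes and look at the
    # dynamic axes an ONNX export of this model would declare.
    from model.bert.configuration_bert import BertConfig, BertOnnxConfig  # assumed path

    onnx_config = BertOnnxConfig(BertConfig())
    print(onnx_config.inputs)
    # OrderedDict([('input_ids', {0: 'batch', 1: 'sequence'}),
    #              ('attention_mask', {0: 'batch', 1: 'sequence'}),
    #              ('token_type_ids', {0: 'batch', 1: 'sequence'})])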
KTeleBERT/model/bert/modeling_bert.py ADDED
@@ -0,0 +1,2010 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model."""
17
+
18
+ import pdb
19
+ import math
20
+ import os
21
+ import warnings
22
+ from dataclasses import dataclass
23
+ from typing import List, Optional, Tuple, Union
24
+
25
+ import torch
26
+ import torch.utils.checkpoint
27
+ from packaging import version
28
+ from torch import nn
29
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
30
+
31
+ from transformers.activations import ACT2FN
32
+
33
+ from transformers.modeling_outputs import (
34
+ BaseModelOutputWithPastAndCrossAttentions,
35
+ BaseModelOutputWithPoolingAndCrossAttentions,
36
+ CausalLMOutputWithCrossAttentions,
37
+ MaskedLMOutput,
38
+ MultipleChoiceModelOutput,
39
+ NextSentencePredictorOutput,
40
+ QuestionAnsweringModelOutput,
41
+ SequenceClassifierOutput,
42
+ TokenClassifierOutput,
43
+ )
44
+ from transformers.modeling_utils import PreTrainedModel
45
+ from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
46
+ from transformers.utils import (
47
+ ModelOutput,
48
+ add_code_sample_docstrings,
49
+ add_start_docstrings,
50
+ add_start_docstrings_to_model_forward,
51
+ logging,
52
+ replace_return_docstrings,
53
+ )
54
+ from .configuration_bert import BertConfig
55
+
56
+
57
+ logger = logging.get_logger(__name__)
58
+
59
+ _CHECKPOINT_FOR_DOC = "bert-base-uncased"
60
+ _CONFIG_FOR_DOC = "BertConfig"
61
+ _TOKENIZER_FOR_DOC = "BertTokenizer"
62
+
63
+ # TokenClassification docstring
64
+ _CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english"
65
+ _TOKEN_CLASS_EXPECTED_OUTPUT = (
66
+ "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] "
67
+ )
68
+ _TOKEN_CLASS_EXPECTED_LOSS = 0.01
69
+
70
+ # QuestionAnswering docstring
71
+ _CHECKPOINT_FOR_QA = "deepset/bert-base-cased-squad2"
72
+ _QA_EXPECTED_OUTPUT = "'a nice puppet'"
73
+ _QA_EXPECTED_LOSS = 7.41
74
+ _QA_TARGET_START_INDEX = 14
75
+ _QA_TARGET_END_INDEX = 15
76
+
77
+ # SequenceClassification docstring
78
+ _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-polarity"
79
+ _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
80
+ _SEQ_CLASS_EXPECTED_LOSS = 0.01
81
+
82
+
83
+ BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
84
+ "bert-base-uncased",
85
+ "bert-large-uncased",
86
+ "bert-base-cased",
87
+ "bert-large-cased",
88
+ "bert-base-multilingual-uncased",
89
+ "bert-base-multilingual-cased",
90
+ "bert-base-chinese",
91
+ "bert-base-german-cased",
92
+ "bert-large-uncased-whole-word-masking",
93
+ "bert-large-cased-whole-word-masking",
94
+ "bert-large-uncased-whole-word-masking-finetuned-squad",
95
+ "bert-large-cased-whole-word-masking-finetuned-squad",
96
+ "bert-base-cased-finetuned-mrpc",
97
+ "bert-base-german-dbmdz-cased",
98
+ "bert-base-german-dbmdz-uncased",
99
+ "cl-tohoku/bert-base-japanese",
100
+ "cl-tohoku/bert-base-japanese-whole-word-masking",
101
+ "cl-tohoku/bert-base-japanese-char",
102
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking",
103
+ "TurkuNLP/bert-base-finnish-cased-v1",
104
+ "TurkuNLP/bert-base-finnish-uncased-v1",
105
+ "wietsedv/bert-base-dutch-cased",
106
+ # See all BERT models at https://huggingface.co/models?filter=bert
107
+ ]
108
+
109
+
110
+ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
111
+ """Load tf checkpoints in a pytorch model."""
112
+ try:
113
+ import re
114
+
115
+ import numpy as np
116
+ import tensorflow as tf
117
+ except ImportError:
118
+ logger.error(
119
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
120
+ "https://www.tensorflow.org/install/ for installation instructions."
121
+ )
122
+ raise
123
+ tf_path = os.path.abspath(tf_checkpoint_path)
124
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
125
+ # Load weights from TF model
126
+ init_vars = tf.train.list_variables(tf_path)
127
+ names = []
128
+ arrays = []
129
+ for name, shape in init_vars:
130
+ logger.info(f"Loading TF weight {name} with shape {shape}")
131
+ array = tf.train.load_variable(tf_path, name)
132
+ names.append(name)
133
+ arrays.append(array)
134
+
135
+ for name, array in zip(names, arrays):
136
+ name = name.split("/")
137
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
138
+ # which are not required for using the pretrained model
139
+ if any(
140
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
141
+ for n in name
142
+ ):
143
+ logger.info(f"Skipping {'/'.join(name)}")
144
+ continue
145
+ pointer = model
146
+ for m_name in name:
147
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
148
+ scope_names = re.split(r"_(\d+)", m_name)
149
+ else:
150
+ scope_names = [m_name]
151
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
152
+ pointer = getattr(pointer, "weight")
153
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
154
+ pointer = getattr(pointer, "bias")
155
+ elif scope_names[0] == "output_weights":
156
+ pointer = getattr(pointer, "weight")
157
+ elif scope_names[0] == "squad":
158
+ pointer = getattr(pointer, "classifier")
159
+ else:
160
+ try:
161
+ pointer = getattr(pointer, scope_names[0])
162
+ except AttributeError:
163
+ logger.info(f"Skipping {'/'.join(name)}")
164
+ continue
165
+ if len(scope_names) >= 2:
166
+ num = int(scope_names[1])
167
+ pointer = pointer[num]
168
+ if m_name[-11:] == "_embeddings":
169
+ pointer = getattr(pointer, "weight")
170
+ elif m_name == "kernel":
171
+ array = np.transpose(array)
172
+ try:
173
+ if pointer.shape != array.shape:
174
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
175
+ except ValueError as e:
176
+ e.args += (pointer.shape, array.shape)
177
+ raise
178
+ logger.info(f"Initialize PyTorch weight {name}")
179
+ pointer.data = torch.from_numpy(array)
180
+ return model
181
+
182
+
183
+ class BertEmbeddings(nn.Module):
184
+ """Construct the embeddings from word, position and token_type embeddings."""
185
+
186
+ def __init__(self, config):
187
+ super().__init__()
188
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
189
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
190
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
191
+
192
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
193
+ # any TensorFlow checkpoint file
194
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
195
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
196
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
197
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
198
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
199
+ if version.parse(torch.__version__) > version.parse("1.6.0"):
200
+ self.register_buffer(
201
+ "token_type_ids",
202
+ torch.zeros(self.position_ids.size(), dtype=torch.long),
203
+ persistent=False,
204
+ )
205
+
206
+ def forward(
207
+ self,
208
+ input_ids: Optional[torch.LongTensor] = None,
209
+ token_type_ids: Optional[torch.LongTensor] = None,
210
+ position_ids: Optional[torch.LongTensor] = None,
211
+ inputs_embeds: Optional[torch.FloatTensor] = None,
212
+ past_key_values_length: int = 0,
213
+ kpi_ref = None, # positions where KPI numeric values are substituted, plus the referenced KPI name, numeric value, and category
214
+ kpi_model = None,
215
+ ) -> torch.Tensor:
216
+ if input_ids is not None:
217
+ input_shape = input_ids.size()
218
+ else:
219
+ input_shape = inputs_embeds.size()[:-1]
220
+
221
+ seq_length = input_shape[1]
222
+
223
+ if position_ids is None:
224
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
225
+
226
+ # Set token_type_ids to the all-zeros buffer registered in the constructor, which is usually the case
227
+ # when it is auto-generated. The registered buffer helps users trace the model without passing token_type_ids,
228
+ # and solves issue #5664
229
+ if token_type_ids is None:
230
+ if hasattr(self, "token_type_ids"):
231
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
232
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
233
+ token_type_ids = buffered_token_type_ids_expanded
234
+ else:
235
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
236
+
237
+
238
+ if inputs_embeds is None:
239
+ inputs_embeds = self.word_embeddings(input_ids)
240
+
241
+ # pdb.set_trace()
242
+ # TODO: take the KPI name embedding, pool it, feed it into the numeric encoding model, and use the resulting feature to replace the original vector at the (masked) target position;
243
+ # no new embedding table entry is created; the generated embedding is read and used directly
244
+ en_loss, scalar_list, con_loss, numeric_input, kpi_input = None, None, None, None, None
245
+ if kpi_ref is not None:
246
+ max_len = inputs_embeds.shape[1]
247
+ # generate the numeric value embeddings
248
+ numeric_list = []
249
+ kpi_emb_list = []
250
+ kpi_id_list = []
251
+ for i in range(len(kpi_ref)):
252
+ if len(kpi_ref[i])>0:
253
+ for item in kpi_ref[i]:
254
+ # the [NUM] token may have been truncated away
255
+ if item[2]>=max_len:
256
+ continue
257
+ numeric_list.append(item[4])
258
+ kpi_id_list.append(item[3])
259
+ # requires_grad=True
260
+ kpi_name_embedding = torch.mean(inputs_embeds[i][item[0]:item[1]+1], dim=0)
261
+ kpi_emb_list.append(kpi_name_embedding)
262
+ # it is possible that no KPI appears at all
263
+ if len(kpi_emb_list)>0:
264
+ kpi_emb = torch.stack(kpi_emb_list).unsqueeze(1)
265
+
266
+ # , dtype=torch.float64
267
+ numeric_input = torch.Tensor(numeric_list).unsqueeze(1).unsqueeze(1).cuda()
268
+ kpi_input = torch.tensor(kpi_id_list, dtype=torch.long).cuda()
269
+ # pdb.set_trace()
270
+ hidden, en_loss, scalar_list, con_loss = kpi_model._encode(numeric_input, kpi_emb)
271
+ # substitute the [NUM] embeddings
272
+ key = 0
273
+ for i in range(len(kpi_ref)):
274
+ if len(kpi_ref[i])>0:
275
+ for item in kpi_ref[i]:
276
+ if item[2]>=max_len:
277
+ continue
278
+ # (x, y) position of the [NUM] token
279
+ inputs_embeds[i,item[2]] = hidden[key][0]
280
+ key += 1
281
+ assert key == hidden.shape[0]
282
+
283
+
284
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
285
+
286
+ embeddings = inputs_embeds + token_type_embeddings
287
+ if self.position_embedding_type == "absolute":
288
+ position_embeddings = self.position_embeddings(position_ids)
289
+ embeddings += position_embeddings
290
+ embeddings = self.LayerNorm(embeddings)
291
+ embeddings = self.dropout(embeddings)
292
+ # restructured output (KPI-related terms are returned alongside the embeddings)
293
+ return embeddings, en_loss, scalar_list, con_loss, numeric_input, kpi_input
294
+ # return embeddings
295
+
296
+
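To make the KPI hook in `BertEmbeddings.forward` above easier to follow: each entry of `kpi_ref[i]` is indexed as (name start, name end, [NUM] position, KPI id, numeric value), the KPI name tokens are mean-pooled, and the embedding produced by `kpi_model._encode` overwrites the `[NUM]` placeholder's word embedding. A minimal, self-contained sketch of that substitution pattern (the `ToyKpiEncoder` below is a stand-in, not the project's actual numeric encoder):

    import torch

    # Stand-in for the numeric encoder: fuses a scalar KPI value with the pooled
    # KPI-name embedding and returns a replacement vector. Illustrative only.
    class ToyKpiEncoder(torch.nn.Module):
        def __init__(self, hidden=768):
            super().__init__()
            self.proj = torch.nn.Linear(1 + hidden, hidden)

        def _encode(self, numeric_input, kpi_emb):
            # numeric_input: (n, 1, 1), kpi_emb: (n, 1, hidden)
            return self.proj(torch.cat([numeric_input, kpi_emb], dim=-1))

    inputs_embeds = torch.randn(2, 16, 768)  # (batch, seq_len, hidden)
    # One KPI in sample 0: name spans tokens 3..5, [NUM] sits at token 7,
    # KPI id 12, observed value 0.37; sample 1 carries no KPI.
    kpi_ref = [[(3, 5, 7, 12, 0.37)], []]

    encoder = ToyKpiEncoder()
    for i, refs in enumerate(kpi_ref):
        for start, end, num_pos, kpi_id, value in refs:
            name_emb = inputs_embeds[i, start:end + 1].mean(dim=0)   # pool the KPI name
            hidden = encoder._encode(torch.tensor([[[value]]]),
                                     name_emb.view(1, 1, -1))
            inputs_embeds[i, num_pos] = hidden[0, 0]                 # overwrite the [NUM] slot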
297
+ class BertSelfAttention(nn.Module):
298
+ def __init__(self, config, position_embedding_type=None):
299
+ super().__init__()
300
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
301
+ raise ValueError(
302
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
303
+ f"heads ({config.num_attention_heads})"
304
+ )
305
+
306
+ self.num_attention_heads = config.num_attention_heads
307
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
308
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
309
+
310
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
311
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
312
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
313
+
314
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
315
+ self.position_embedding_type = position_embedding_type or getattr(
316
+ config, "position_embedding_type", "absolute"
317
+ )
318
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
319
+ self.max_position_embeddings = config.max_position_embeddings
320
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
321
+
322
+ self.is_decoder = config.is_decoder
323
+
324
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
325
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
326
+ x = x.view(new_x_shape)
327
+ return x.permute(0, 2, 1, 3)
328
+
329
+ def forward(
330
+ self,
331
+ hidden_states: torch.Tensor,
332
+ attention_mask: Optional[torch.FloatTensor] = None,
333
+ head_mask: Optional[torch.FloatTensor] = None,
334
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
335
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
336
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
337
+ output_attentions: Optional[bool] = False,
338
+ ) -> Tuple[torch.Tensor]:
339
+ mixed_query_layer = self.query(hidden_states)
340
+
341
+ # If this is instantiated as a cross-attention module, the keys
342
+ # and values come from an encoder; the attention mask needs to be
343
+ # such that the encoder's padding tokens are not attended to.
344
+ is_cross_attention = encoder_hidden_states is not None
345
+
346
+ if is_cross_attention and past_key_value is not None:
347
+ # reuse k,v, cross_attentions
348
+ key_layer = past_key_value[0]
349
+ value_layer = past_key_value[1]
350
+ attention_mask = encoder_attention_mask
351
+ elif is_cross_attention:
352
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
353
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
354
+ attention_mask = encoder_attention_mask
355
+ elif past_key_value is not None:
356
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
357
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
358
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
359
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
360
+ else:
361
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
362
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
363
+
364
+ query_layer = self.transpose_for_scores(mixed_query_layer)
365
+
366
+ if self.is_decoder:
367
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
368
+ # Further calls to cross_attention layer can then reuse all cross-attention
369
+ # key/value_states (first "if" case)
370
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
371
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
372
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
373
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
374
+ past_key_value = (key_layer, value_layer)
375
+
376
+ # Take the dot product between "query" and "key" to get the raw attention scores.
377
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
378
+
379
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
380
+ seq_length = hidden_states.size()[1]
381
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
382
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
383
+ distance = position_ids_l - position_ids_r
384
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
385
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
386
+
387
+ if self.position_embedding_type == "relative_key":
388
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
389
+ attention_scores = attention_scores + relative_position_scores
390
+ elif self.position_embedding_type == "relative_key_query":
391
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
392
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
393
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
394
+
395
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
396
+ if attention_mask is not None:
397
+ # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
398
+ attention_scores = attention_scores + attention_mask
399
+
400
+ # Normalize the attention scores to probabilities.
401
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
402
+
403
+ # This is actually dropping out entire tokens to attend to, which might
404
+ # seem a bit unusual, but is taken from the original Transformer paper.
405
+ attention_probs = self.dropout(attention_probs)
406
+
407
+ # Mask heads if we want to
408
+ if head_mask is not None:
409
+ attention_probs = attention_probs * head_mask
410
+
411
+ context_layer = torch.matmul(attention_probs, value_layer)
412
+
413
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
414
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
415
+ context_layer = context_layer.view(new_context_layer_shape)
416
+
417
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
418
+
419
+ if self.is_decoder:
420
+ outputs = outputs + (past_key_value,)
421
+ return outputs
422
+
423
+
424
+ class BertSelfOutput(nn.Module):
425
+ def __init__(self, config):
426
+ super().__init__()
427
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
428
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
429
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
430
+
431
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
432
+ hidden_states = self.dense(hidden_states)
433
+ hidden_states = self.dropout(hidden_states)
434
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
435
+ return hidden_states
436
+
437
+
438
+ class BertAttention(nn.Module):
439
+ def __init__(self, config, position_embedding_type=None):
440
+ super().__init__()
441
+ self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
442
+ self.output = BertSelfOutput(config)
443
+ self.pruned_heads = set()
444
+
445
+ def prune_heads(self, heads):
446
+ if len(heads) == 0:
447
+ return
448
+ heads, index = find_pruneable_heads_and_indices(
449
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
450
+ )
451
+
452
+ # Prune linear layers
453
+ self.self.query = prune_linear_layer(self.self.query, index)
454
+ self.self.key = prune_linear_layer(self.self.key, index)
455
+ self.self.value = prune_linear_layer(self.self.value, index)
456
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
457
+
458
+ # Update hyper params and store pruned heads
459
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
460
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
461
+ self.pruned_heads = self.pruned_heads.union(heads)
462
+
463
+ def forward(
464
+ self,
465
+ hidden_states: torch.Tensor,
466
+ attention_mask: Optional[torch.FloatTensor] = None,
467
+ head_mask: Optional[torch.FloatTensor] = None,
468
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
469
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
470
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
471
+ output_attentions: Optional[bool] = False,
472
+ ) -> Tuple[torch.Tensor]:
473
+ self_outputs = self.self(
474
+ hidden_states,
475
+ attention_mask,
476
+ head_mask,
477
+ encoder_hidden_states,
478
+ encoder_attention_mask,
479
+ past_key_value,
480
+ output_attentions,
481
+ )
482
+ attention_output = self.output(self_outputs[0], hidden_states)
483
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
484
+ return outputs
485
+
486
+
487
+ class BertIntermediate(nn.Module):
488
+ def __init__(self, config):
489
+ super().__init__()
490
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
491
+ if isinstance(config.hidden_act, str):
492
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
493
+ else:
494
+ self.intermediate_act_fn = config.hidden_act
495
+
496
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
497
+ hidden_states = self.dense(hidden_states)
498
+ hidden_states = self.intermediate_act_fn(hidden_states)
499
+ return hidden_states
500
+
501
+
502
+ class BertOutput(nn.Module):
503
+ def __init__(self, config):
504
+ super().__init__()
505
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
506
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
507
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
508
+
509
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
510
+ hidden_states = self.dense(hidden_states)
511
+ hidden_states = self.dropout(hidden_states)
512
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
513
+ return hidden_states
514
+
515
+
516
+ class BertLayer(nn.Module):
517
+ def __init__(self, config):
518
+ super().__init__()
519
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
520
+ self.seq_len_dim = 1
521
+ self.attention = BertAttention(config)
522
+ self.is_decoder = config.is_decoder
523
+ self.add_cross_attention = config.add_cross_attention
524
+ if self.add_cross_attention:
525
+ if not self.is_decoder:
526
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
527
+ self.crossattention = BertAttention(config, position_embedding_type="absolute")
528
+ self.intermediate = BertIntermediate(config)
529
+ self.output = BertOutput(config)
530
+
531
+ def forward(
532
+ self,
533
+ hidden_states: torch.Tensor,
534
+ attention_mask: Optional[torch.FloatTensor] = None,
535
+ head_mask: Optional[torch.FloatTensor] = None,
536
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
537
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
538
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
539
+ output_attentions: Optional[bool] = False,
540
+ ) -> Tuple[torch.Tensor]:
541
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
542
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
543
+ self_attention_outputs = self.attention(
544
+ hidden_states,
545
+ attention_mask,
546
+ head_mask,
547
+ output_attentions=output_attentions,
548
+ past_key_value=self_attn_past_key_value,
549
+ )
550
+ attention_output = self_attention_outputs[0]
551
+
552
+ # if decoder, the last output is tuple of self-attn cache
553
+ if self.is_decoder:
554
+ outputs = self_attention_outputs[1:-1]
555
+ present_key_value = self_attention_outputs[-1]
556
+ else:
557
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
558
+
559
+ cross_attn_present_key_value = None
560
+ if self.is_decoder and encoder_hidden_states is not None:
561
+ if not hasattr(self, "crossattention"):
562
+ raise ValueError(
563
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
564
+ " by setting `config.add_cross_attention=True`"
565
+ )
566
+
567
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
568
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
569
+ cross_attention_outputs = self.crossattention(
570
+ attention_output,
571
+ attention_mask,
572
+ head_mask,
573
+ encoder_hidden_states,
574
+ encoder_attention_mask,
575
+ cross_attn_past_key_value,
576
+ output_attentions,
577
+ )
578
+ attention_output = cross_attention_outputs[0]
579
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
580
+
581
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
582
+ cross_attn_present_key_value = cross_attention_outputs[-1]
583
+ present_key_value = present_key_value + cross_attn_present_key_value
584
+
585
+ layer_output = apply_chunking_to_forward(
586
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
587
+ )
588
+ outputs = (layer_output,) + outputs
589
+
590
+ # if decoder, return the attn key/values as the last output
591
+ if self.is_decoder:
592
+ outputs = outputs + (present_key_value,)
593
+
594
+ return outputs
595
+
596
+ def feed_forward_chunk(self, attention_output):
597
+ intermediate_output = self.intermediate(attention_output)
598
+ layer_output = self.output(intermediate_output, attention_output)
599
+ return layer_output
600
+
601
+
602
+ class BertEncoder(nn.Module):
603
+ def __init__(self, config):
604
+ super().__init__()
605
+ self.config = config
606
+ self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
607
+ self.gradient_checkpointing = False
608
+
609
+ def forward(
610
+ self,
611
+ hidden_states: torch.Tensor,
612
+ attention_mask: Optional[torch.FloatTensor] = None,
613
+ head_mask: Optional[torch.FloatTensor] = None,
614
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
615
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
616
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
617
+ use_cache: Optional[bool] = None,
618
+ output_attentions: Optional[bool] = False,
619
+ output_hidden_states: Optional[bool] = False,
620
+ return_dict: Optional[bool] = True,
621
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
622
+ all_hidden_states = () if output_hidden_states else None
623
+ all_self_attentions = () if output_attentions else None
624
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
625
+
626
+ next_decoder_cache = () if use_cache else None
627
+ for i, layer_module in enumerate(self.layer):
628
+ if output_hidden_states:
629
+ all_hidden_states = all_hidden_states + (hidden_states,)
630
+
631
+ layer_head_mask = head_mask[i] if head_mask is not None else None
632
+ past_key_value = past_key_values[i] if past_key_values is not None else None
633
+
634
+ if self.gradient_checkpointing and self.training:
635
+
636
+ if use_cache:
637
+ logger.warning(
638
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
639
+ )
640
+ use_cache = False
641
+
642
+ def create_custom_forward(module):
643
+ def custom_forward(*inputs):
644
+ return module(*inputs, past_key_value, output_attentions)
645
+
646
+ return custom_forward
647
+
648
+ layer_outputs = torch.utils.checkpoint.checkpoint(
649
+ create_custom_forward(layer_module),
650
+ hidden_states,
651
+ attention_mask,
652
+ layer_head_mask,
653
+ encoder_hidden_states,
654
+ encoder_attention_mask,
655
+ )
656
+ else:
657
+ layer_outputs = layer_module(
658
+ hidden_states,
659
+ attention_mask,
660
+ layer_head_mask,
661
+ encoder_hidden_states,
662
+ encoder_attention_mask,
663
+ past_key_value,
664
+ output_attentions,
665
+ )
666
+
667
+ hidden_states = layer_outputs[0]
668
+ if use_cache:
669
+ next_decoder_cache += (layer_outputs[-1],)
670
+ if output_attentions:
671
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
672
+ if self.config.add_cross_attention:
673
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
674
+
675
+ if output_hidden_states:
676
+ all_hidden_states = all_hidden_states + (hidden_states,)
677
+
678
+ if not return_dict:
679
+ return tuple(
680
+ v
681
+ for v in [
682
+ hidden_states,
683
+ next_decoder_cache,
684
+ all_hidden_states,
685
+ all_self_attentions,
686
+ all_cross_attentions,
687
+ ]
688
+ if v is not None
689
+ )
690
+ return BaseModelOutputWithPastAndCrossAttentions(
691
+ last_hidden_state=hidden_states,
692
+ past_key_values=next_decoder_cache,
693
+ hidden_states=all_hidden_states,
694
+ attentions=all_self_attentions,
695
+ cross_attentions=all_cross_attentions,
696
+ )
697
+
698
+
699
+ class BertPooler(nn.Module):
700
+ def __init__(self, config):
701
+ super().__init__()
702
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
703
+ self.activation = nn.Tanh()
704
+
705
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
706
+ # We "pool" the model by simply taking the hidden state corresponding
707
+ # to the first token.
708
+ first_token_tensor = hidden_states[:, 0]
709
+ pooled_output = self.dense(first_token_tensor)
710
+ pooled_output = self.activation(pooled_output)
711
+ return pooled_output
712
+
713
+
714
+ class BertPredictionHeadTransform(nn.Module):
715
+ def __init__(self, config):
716
+ super().__init__()
717
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
718
+ if isinstance(config.hidden_act, str):
719
+ self.transform_act_fn = ACT2FN[config.hidden_act]
720
+ else:
721
+ self.transform_act_fn = config.hidden_act
722
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
723
+
724
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
725
+ hidden_states = self.dense(hidden_states)
726
+ hidden_states = self.transform_act_fn(hidden_states)
727
+ hidden_states = self.LayerNorm(hidden_states)
728
+ return hidden_states
729
+
730
+
731
+ class BertLMPredictionHead(nn.Module):
732
+ def __init__(self, config):
733
+ super().__init__()
734
+ self.transform = BertPredictionHeadTransform(config)
735
+
736
+ # The output weights are the same as the input embeddings, but there is
737
+ # an output-only bias for each token.
738
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
739
+
740
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
741
+
742
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
743
+ self.decoder.bias = self.bias
744
+
745
+ def forward(self, hidden_states):
746
+ hidden_states = self.transform(hidden_states)
747
+ hidden_states = self.decoder(hidden_states)
748
+ return hidden_states
749
+
750
+
751
+ class BertOnlyMLMHead(nn.Module):
752
+ def __init__(self, config):
753
+ super().__init__()
754
+ self.predictions = BertLMPredictionHead(config)
755
+
756
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
757
+ prediction_scores = self.predictions(sequence_output)
758
+ return prediction_scores
759
+
760
+
761
+ class BertOnlyNSPHead(nn.Module):
762
+ def __init__(self, config):
763
+ super().__init__()
764
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
765
+
766
+ def forward(self, pooled_output):
767
+ seq_relationship_score = self.seq_relationship(pooled_output)
768
+ return seq_relationship_score
769
+
770
+
771
+ class BertPreTrainingHeads(nn.Module):
772
+ def __init__(self, config):
773
+ super().__init__()
774
+ self.predictions = BertLMPredictionHead(config)
775
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
776
+
777
+ def forward(self, sequence_output, pooled_output):
778
+ prediction_scores = self.predictions(sequence_output)
779
+ seq_relationship_score = self.seq_relationship(pooled_output)
780
+ return prediction_scores, seq_relationship_score
781
+
782
+
783
+ class BertPreTrainedModel(PreTrainedModel):
784
+ """
785
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
786
+ models.
787
+ """
788
+
789
+ config_class = BertConfig
790
+ load_tf_weights = load_tf_weights_in_bert
791
+ base_model_prefix = "bert"
792
+ supports_gradient_checkpointing = True
793
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
794
+
795
+ def _init_weights(self, module):
796
+ """Initialize the weights"""
797
+ if isinstance(module, nn.Linear):
798
+ # Slightly different from the TF version which uses truncated_normal for initialization
799
+ # cf https://github.com/pytorch/pytorch/pull/5617
800
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
801
+ if module.bias is not None:
802
+ module.bias.data.zero_()
803
+ elif isinstance(module, nn.Embedding):
804
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
805
+ if module.padding_idx is not None:
806
+ module.weight.data[module.padding_idx].zero_()
807
+ elif isinstance(module, nn.LayerNorm):
808
+ module.bias.data.zero_()
809
+ module.weight.data.fill_(1.0)
810
+
811
+ def _set_gradient_checkpointing(self, module, value=False):
812
+ if isinstance(module, BertEncoder):
813
+ module.gradient_checkpointing = value
814
+
815
+
816
+ @dataclass
817
+ class BertForPreTrainingOutput(ModelOutput):
818
+ """
819
+ Output type of [`BertForPreTraining`].
820
+
821
+ Args:
822
+ loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
823
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction
824
+ (classification) loss.
825
+ prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
826
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
827
+ seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
828
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
829
+ before SoftMax).
830
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
831
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
832
+ shape `(batch_size, sequence_length, hidden_size)`.
833
+
834
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
835
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
836
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
837
+ sequence_length)`.
838
+
839
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
840
+ heads.
841
+ """
842
+
843
+ loss: Optional[torch.FloatTensor] = None
844
+ prediction_logits: torch.FloatTensor = None
845
+ seq_relationship_logits: torch.FloatTensor = None
846
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
847
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
848
+
849
+
850
+ BERT_START_DOCSTRING = r"""
851
+
852
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
853
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
854
+ etc.)
855
+
856
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
857
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
858
+ and behavior.
859
+
860
+ Parameters:
861
+ config ([`BertConfig`]): Model configuration class with all the parameters of the model.
862
+ Initializing with a config file does not load the weights associated with the model, only the
863
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
864
+ """
865
+
866
+ BERT_INPUTS_DOCSTRING = r"""
867
+ Args:
868
+ input_ids (`torch.LongTensor` of shape `({0})`):
869
+ Indices of input sequence tokens in the vocabulary.
870
+
871
+ Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and
872
+ [`PreTrainedTokenizer.__call__`] for details.
873
+
874
+ [What are input IDs?](../glossary#input-ids)
875
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
876
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
877
+
878
+ - 1 for tokens that are **not masked**,
879
+ - 0 for tokens that are **masked**.
880
+
881
+ [What are attention masks?](../glossary#attention-mask)
882
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
883
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
884
+ 1]`:
885
+
886
+ - 0 corresponds to a *sentence A* token,
887
+ - 1 corresponds to a *sentence B* token.
888
+
889
+ [What are token type IDs?](../glossary#token-type-ids)
890
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
891
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
892
+ config.max_position_embeddings - 1]`.
893
+
894
+ [What are position IDs?](../glossary#position-ids)
895
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
896
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
897
+
898
+ - 1 indicates the head is **not masked**,
899
+ - 0 indicates the head is **masked**.
900
+
901
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
902
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
903
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
904
+ model's internal embedding lookup matrix.
905
+ output_attentions (`bool`, *optional*):
906
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
907
+ tensors for more detail.
908
+ output_hidden_states (`bool`, *optional*):
909
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
910
+ more detail.
911
+ return_dict (`bool`, *optional*):
912
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
913
+ """
914
+
915
+
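The docstring above lists the standard BERT inputs; for reference, a short sketch of how they are typically produced by a tokenizer and passed to the model (uses the stock Hugging Face classes and the public `bert-base-chinese` checkpoint purely for illustration):

    import torch
    from transformers import BertTokenizer, BertModel  # stock classes, for illustration

    tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
    model = BertModel.from_pretrained("bert-base-chinese")

    enc = tokenizer(["a short test sentence"], padding=True, return_tensors="pt")
    with torch.no_grad():
        out = model(
            input_ids=enc["input_ids"],            # token indices in the vocabulary
            attention_mask=enc["attention_mask"],  # 1 = real token, 0 = padding
            token_type_ids=enc["token_type_ids"],  # sentence A / sentence B segment ids
        )
    print(out.last_hidden_state.shape)  # (batch, seq_len, hidden)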
916
+ @add_start_docstrings(
917
+ "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
918
+ BERT_START_DOCSTRING,
919
+ )
920
+ class BertModel(BertPreTrainedModel):
921
+ """
922
+
923
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
924
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
925
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
926
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
927
+
928
+ To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
929
+ to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
930
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
931
+ """
932
+
933
+ def __init__(self, config, add_pooling_layer=True):
934
+ super().__init__(config)
935
+ self.config = config
936
+
937
+ self.embeddings = BertEmbeddings(config)
938
+ self.encoder = BertEncoder(config)
939
+
940
+ self.pooler = BertPooler(config) if add_pooling_layer else None
941
+
942
+ # Initialize weights and apply final processing
943
+ self.post_init()
944
+
945
+ def get_input_embeddings(self):
946
+ return self.embeddings.word_embeddings
947
+
948
+ def set_input_embeddings(self, value):
949
+ self.embeddings.word_embeddings = value
950
+
951
+ def _prune_heads(self, heads_to_prune):
952
+ """
953
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
954
+ class PreTrainedModel
955
+ """
956
+ for layer, heads in heads_to_prune.items():
957
+ self.encoder.layer[layer].attention.prune_heads(heads)
958
+
959
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
960
+ @add_code_sample_docstrings(
961
+ processor_class=_TOKENIZER_FOR_DOC,
962
+ checkpoint=_CHECKPOINT_FOR_DOC,
963
+ output_type=BaseModelOutputWithPoolingAndCrossAttentions,
964
+ config_class=_CONFIG_FOR_DOC,
965
+ )
966
+ def forward(
967
+ self,
968
+ input_ids: Optional[torch.Tensor] = None,
969
+ attention_mask: Optional[torch.Tensor] = None,
970
+ token_type_ids: Optional[torch.Tensor] = None,
971
+ position_ids: Optional[torch.Tensor] = None,
972
+ head_mask: Optional[torch.Tensor] = None,
973
+ inputs_embeds: Optional[torch.Tensor] = None,
974
+ encoder_hidden_states: Optional[torch.Tensor] = None,
975
+ encoder_attention_mask: Optional[torch.Tensor] = None,
976
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
977
+ use_cache: Optional[bool] = None,
978
+ output_attentions: Optional[bool] = None,
979
+ output_hidden_states: Optional[bool] = None,
980
+ return_dict: Optional[bool] = None,
981
+ kpi_ref = None, # positions where KPI numeric values are substituted, plus the referenced KPI name, numeric value, and category
982
+ kpi_model = None, # the KPI numeric encoder model passed in
983
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
984
+ r"""
985
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
986
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
987
+ the model is configured as a decoder.
988
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
989
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
990
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
991
+
992
+ - 1 for tokens that are **not masked**,
993
+ - 0 for tokens that are **masked**.
994
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
995
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
996
+
997
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
998
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
999
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1000
+ use_cache (`bool`, *optional*):
1001
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1002
+ `past_key_values`).
1003
+ """
1004
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1005
+ output_hidden_states = (
1006
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1007
+ )
1008
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1009
+
1010
+ if self.config.is_decoder:
1011
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1012
+ else:
1013
+ use_cache = False
1014
+
1015
+ if input_ids is not None and inputs_embeds is not None:
1016
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1017
+ elif input_ids is not None:
1018
+ input_shape = input_ids.size()
1019
+ elif inputs_embeds is not None:
1020
+ input_shape = inputs_embeds.size()[:-1]
1021
+ else:
1022
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1023
+
1024
+ batch_size, seq_length = input_shape
1025
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1026
+
1027
+ # past_key_values_length
1028
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1029
+
1030
+ if attention_mask is None:
1031
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
1032
+
1033
+ if token_type_ids is None:
1034
+ if hasattr(self.embeddings, "token_type_ids"):
1035
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
1036
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
1037
+ token_type_ids = buffered_token_type_ids_expanded
1038
+ else:
1039
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
1040
+
1041
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1042
+ # ourselves in which case we just need to make it broadcastable to all heads.
1043
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
1044
+
1045
+ # If a 2D or 3D attention mask is provided for the cross-attention
1046
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1047
+ if self.config.is_decoder and encoder_hidden_states is not None:
1048
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
1049
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1050
+ if encoder_attention_mask is None:
1051
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1052
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1053
+ else:
1054
+ encoder_extended_attention_mask = None
1055
+
1056
+ # Prepare head mask if needed
1057
+ # 1.0 in head_mask indicate we keep the head
1058
+ # attention_probs has shape bsz x n_heads x N x N
1059
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1060
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1061
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1062
+
1063
+ embedding_output, en_loss, scalar_list, con_loss, numeric_input, kpi_input = self.embeddings(
1064
+ input_ids=input_ids,
1065
+ position_ids=position_ids,
1066
+ token_type_ids=token_type_ids,
1067
+ inputs_embeds=inputs_embeds,
1068
+ past_key_values_length=past_key_values_length,
1069
+ kpi_ref=kpi_ref, # positions where KPI values were replaced, plus the reference KPI name, value and category
1070
+ kpi_model=kpi_model,
1071
+ )
1072
+
1073
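+ # Assumed semantics of the extra outputs above, inferred from how they are used further down:
+ # en_loss is an embedding-level auxiliary loss, scalar_list holds the raw scalar values found in the
+ # input, con_loss is an optional contrastive loss over KPI embeddings (may be None), numeric_input
+ # holds the numeric values tied to each [NUM] placeholder, and kpi_input holds the KPI ids consumed
+ # by kpi_model.classifier in BertForMaskedLM below.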
+ # the placeholder position information has been passed in
1074
+ # pooling of the embeddings at the KPI start / end positions
1075
+
1076
+ # the numeric embeddings are substituted in by position here
1077
+ # the KPI embedding is used as a supervision signal at the same time
1078
+ #
1079
+ encoder_outputs = self.encoder(
1080
+ embedding_output,
1081
+ attention_mask=extended_attention_mask,
1082
+ head_mask=head_mask,
1083
+ encoder_hidden_states=encoder_hidden_states,
1084
+ encoder_attention_mask=encoder_extended_attention_mask,
1085
+ past_key_values=past_key_values,
1086
+ use_cache=use_cache,
1087
+ output_attentions=output_attentions,
1088
+ output_hidden_states=output_hidden_states,
1089
+ return_dict=return_dict,
1090
+ )
1091
+
1092
+
1093
+ # the regression loss at the numeric-embedding positions is computed from here on
1094
+ sequence_output = encoder_outputs[0]
1095
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1096
+
1097
+ if not return_dict:
1098
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1099
+
1100
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1101
+ last_hidden_state=sequence_output,
1102
+ pooler_output=pooled_output,
1103
+ past_key_values=encoder_outputs.past_key_values,
1104
+ hidden_states=encoder_outputs.hidden_states,
1105
+ attentions=encoder_outputs.attentions,
1106
+ cross_attentions=encoder_outputs.cross_attentions,
1107
+ ), en_loss, scalar_list, con_loss, numeric_input, kpi_input
1108
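+ # Note that, unlike the upstream BertModel, this forward returns the KPI-related auxiliary values
+ # alongside the model output. A caller unpacks it the way BertForMaskedLM.forward below does,
+ # roughly (illustrative sketch):
+ #   outputs, en_loss, scalar_list, con_loss, numeric_input, kpi_input = self.bert(input_ids, ..., kpi_ref=kpi_ref, kpi_model=kpi_model)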
+
1109
+
1110
+ @add_start_docstrings(
1111
+ """
1112
+ Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
1113
+ sentence prediction (classification)` head.
1114
+ """,
1115
+ BERT_START_DOCSTRING,
1116
+ )
1117
+ class BertForPreTraining(BertPreTrainedModel):
1118
+ def __init__(self, config):
1119
+ super().__init__(config)
1120
+
1121
+ self.bert = BertModel(config)
1122
+ self.cls = BertPreTrainingHeads(config)
1123
+
1124
+ # Initialize weights and apply final processing
1125
+ self.post_init()
1126
+
1127
+ def get_output_embeddings(self):
1128
+ return self.cls.predictions.decoder
1129
+
1130
+ def set_output_embeddings(self, new_embeddings):
1131
+ self.cls.predictions.decoder = new_embeddings
1132
+
1133
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1134
+ @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
1135
+ def forward(
1136
+ self,
1137
+ input_ids: Optional[torch.Tensor] = None,
1138
+ attention_mask: Optional[torch.Tensor] = None,
1139
+ token_type_ids: Optional[torch.Tensor] = None,
1140
+ position_ids: Optional[torch.Tensor] = None,
1141
+ head_mask: Optional[torch.Tensor] = None,
1142
+ inputs_embeds: Optional[torch.Tensor] = None,
1143
+ labels: Optional[torch.Tensor] = None,
1144
+ next_sentence_label: Optional[torch.Tensor] = None,
1145
+ output_attentions: Optional[bool] = None,
1146
+ output_hidden_states: Optional[bool] = None,
1147
+ return_dict: Optional[bool] = None,
1148
+ ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]:
1149
+ r"""
1150
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1151
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1152
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
1153
+ the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1154
+ next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1155
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
1156
+ pair (see `input_ids` docstring) Indices should be in `[0, 1]`:
1157
+
1158
+ - 0 indicates sequence B is a continuation of sequence A,
1159
+ - 1 indicates sequence B is a random sequence.
1160
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
1161
+ Used to hide legacy arguments that have been deprecated.
1162
+
1163
+ Returns:
1164
+
1165
+ Example:
1166
+
1167
+ ```python
1168
+ >>> from transformers import BertTokenizer, BertForPreTraining
1169
+ >>> import torch
1170
+
1171
+ >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
1172
+ >>> model = BertForPreTraining.from_pretrained("bert-base-uncased")
1173
+
1174
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1175
+ >>> outputs = model(**inputs)
1176
+
1177
+ >>> prediction_logits = outputs.prediction_logits
1178
+ >>> seq_relationship_logits = outputs.seq_relationship_logits
1179
+ ```
1180
+ """
1181
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1182
+
1183
+ outputs = self.bert(
1184
+ input_ids,
1185
+ attention_mask=attention_mask,
1186
+ token_type_ids=token_type_ids,
1187
+ position_ids=position_ids,
1188
+ head_mask=head_mask,
1189
+ inputs_embeds=inputs_embeds,
1190
+ output_attentions=output_attentions,
1191
+ output_hidden_states=output_hidden_states,
1192
+ return_dict=return_dict,
1193
+ )
1194
+
1195
+ sequence_output, pooled_output = outputs[:2]
1196
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
1197
+
1198
+ total_loss = None
1199
+ if labels is not None and next_sentence_label is not None:
1200
+ loss_fct = CrossEntropyLoss()
1201
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1202
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
1203
+ total_loss = masked_lm_loss + next_sentence_loss
1204
+
1205
+ if not return_dict:
1206
+ output = (prediction_scores, seq_relationship_score) + outputs[2:]
1207
+ return ((total_loss,) + output) if total_loss is not None else output
1208
+
1209
+ return BertForPreTrainingOutput(
1210
+ loss=total_loss,
1211
+ prediction_logits=prediction_scores,
1212
+ seq_relationship_logits=seq_relationship_score,
1213
+ hidden_states=outputs.hidden_states,
1214
+ attentions=outputs.attentions,
1215
+ )
1216
+
1217
+
1218
+ @add_start_docstrings(
1219
+ """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING
1220
+ )
1221
+ class BertLMHeadModel(BertPreTrainedModel):
1222
+
1223
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1224
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1225
+
1226
+ def __init__(self, config):
1227
+ super().__init__(config)
1228
+
1229
+ if not config.is_decoder:
1230
+ logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True`.")
1231
+
1232
+ self.bert = BertModel(config, add_pooling_layer=False)
1233
+ self.cls = BertOnlyMLMHead(config)
1234
+
1235
+ # Initialize weights and apply final processing
1236
+ self.post_init()
1237
+
1238
+ def get_output_embeddings(self):
1239
+ return self.cls.predictions.decoder
1240
+
1241
+ def set_output_embeddings(self, new_embeddings):
1242
+ self.cls.predictions.decoder = new_embeddings
1243
+
1244
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1245
+ @add_code_sample_docstrings(
1246
+ processor_class=_TOKENIZER_FOR_DOC,
1247
+ checkpoint=_CHECKPOINT_FOR_DOC,
1248
+ output_type=CausalLMOutputWithCrossAttentions,
1249
+ config_class=_CONFIG_FOR_DOC,
1250
+ )
1251
+ def forward(
1252
+ self,
1253
+ input_ids: Optional[torch.Tensor] = None,
1254
+ attention_mask: Optional[torch.Tensor] = None,
1255
+ token_type_ids: Optional[torch.Tensor] = None,
1256
+ position_ids: Optional[torch.Tensor] = None,
1257
+ head_mask: Optional[torch.Tensor] = None,
1258
+ inputs_embeds: Optional[torch.Tensor] = None,
1259
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1260
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1261
+ labels: Optional[torch.Tensor] = None,
1262
+ past_key_values: Optional[List[torch.Tensor]] = None,
1263
+ use_cache: Optional[bool] = None,
1264
+ output_attentions: Optional[bool] = None,
1265
+ output_hidden_states: Optional[bool] = None,
1266
+ return_dict: Optional[bool] = None,
1267
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
1268
+ r"""
1269
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1270
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1271
+ the model is configured as a decoder.
1272
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1273
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1274
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1275
+
1276
+ - 1 for tokens that are **not masked**,
1277
+ - 0 for tokens that are **masked**.
1278
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1279
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1280
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
1281
+ ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1282
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1283
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1284
+
1285
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1286
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1287
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1288
+ use_cache (`bool`, *optional*):
1289
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1290
+ `past_key_values`).
1291
+ """
1292
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1293
+ if labels is not None:
1294
+ use_cache = False
1295
+
1296
+ outputs = self.bert(
1297
+ input_ids,
1298
+ attention_mask=attention_mask,
1299
+ token_type_ids=token_type_ids,
1300
+ position_ids=position_ids,
1301
+ head_mask=head_mask,
1302
+ inputs_embeds=inputs_embeds,
1303
+ encoder_hidden_states=encoder_hidden_states,
1304
+ encoder_attention_mask=encoder_attention_mask,
1305
+ past_key_values=past_key_values,
1306
+ use_cache=use_cache,
1307
+ output_attentions=output_attentions,
1308
+ output_hidden_states=output_hidden_states,
1309
+ return_dict=return_dict,
1310
+ )
1311
+
1312
+ sequence_output = outputs[0]
1313
+ prediction_scores = self.cls(sequence_output)
1314
+
1315
+ lm_loss = None
1316
+ if labels is not None:
1317
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1318
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1319
+ labels = labels[:, 1:].contiguous()
1320
+ loss_fct = CrossEntropyLoss()
1321
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1322
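+ # e.g. for input tokens [t0, t1, t2, t3] the logits at positions 0..2 are scored against the
+ # labels t1..t3, which is what the shift above implements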
+
1323
+ if not return_dict:
1324
+ output = (prediction_scores,) + outputs[2:]
1325
+ return ((lm_loss,) + output) if lm_loss is not None else output
1326
+
1327
+ return CausalLMOutputWithCrossAttentions(
1328
+ loss=lm_loss,
1329
+ logits=prediction_scores,
1330
+ past_key_values=outputs.past_key_values,
1331
+ hidden_states=outputs.hidden_states,
1332
+ attentions=outputs.attentions,
1333
+ cross_attentions=outputs.cross_attentions,
1334
+ )
1335
+
1336
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
1337
+ input_shape = input_ids.shape
1338
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1339
+ if attention_mask is None:
1340
+ attention_mask = input_ids.new_ones(input_shape)
1341
+
1342
+ # cut decoder_input_ids if past is used
1343
+ if past is not None:
1344
+ input_ids = input_ids[:, -1:]
1345
+
1346
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
1347
+
1348
+ def _reorder_cache(self, past, beam_idx):
1349
+ reordered_past = ()
1350
+ for layer_past in past:
1351
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1352
+ return reordered_past
1353
+
1354
+
1355
+ @add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING)
1356
+ class BertForMaskedLM(BertPreTrainedModel):
1357
+
1358
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1359
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1360
+
1361
+ def __init__(self, config):
1362
+ super().__init__(config)
1363
+
1364
+ if config.is_decoder:
1365
+ logger.warning(
1366
+ "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
1367
+ "bi-directional self-attention."
1368
+ )
1369
+
1370
+ self.bert = BertModel(config, add_pooling_layer=False)
1371
+
1372
+ # CZ: added pooling
1373
+ # self.bert = BertModel(config)
1374
+ self.cls = BertOnlyMLMHead(config)
1375
+
1376
+ # Initialize weights and apply final processing
1377
+ self.post_init()
1378
+
1379
+ def get_output_embeddings(self):
1380
+ return self.cls.predictions.decoder
1381
+
1382
+ def set_output_embeddings(self, new_embeddings):
1383
+ self.cls.predictions.decoder = new_embeddings
1384
+
1385
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1386
+ @add_code_sample_docstrings(
1387
+ processor_class=_TOKENIZER_FOR_DOC,
1388
+ checkpoint=_CHECKPOINT_FOR_DOC,
1389
+ output_type=MaskedLMOutput,
1390
+ config_class=_CONFIG_FOR_DOC,
1391
+ expected_output="'paris'",
1392
+ expected_loss=0.88,
1393
+ )
1394
+ def forward(
1395
+ self,
1396
+ input_ids: Optional[torch.Tensor] = None,
1397
+ attention_mask: Optional[torch.Tensor] = None,
1398
+ token_type_ids: Optional[torch.Tensor] = None,
1399
+ position_ids: Optional[torch.Tensor] = None,
1400
+ head_mask: Optional[torch.Tensor] = None,
1401
+ inputs_embeds: Optional[torch.Tensor] = None,
1402
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1403
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1404
+ labels: Optional[torch.Tensor] = None,
1405
+ output_attentions: Optional[bool] = None,
1406
+ output_hidden_states: Optional[bool] = None,
1407
+ return_dict: Optional[bool] = None,
1408
+ kpi_ref=None, # positions where KPI values were replaced, plus the reference KPI name, value and category
1409
+ kpi_model=None, # the KPI model passed in
1410
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
1411
+ r"""
1412
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1413
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1414
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1415
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1416
+ """
1417
+
1418
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1419
+
1420
+ outputs, en_loss, scalar_list, con_loss, numeric_input, kpi_input = self.bert(
1421
+ input_ids,
1422
+ attention_mask=attention_mask,
1423
+ token_type_ids=token_type_ids,
1424
+ position_ids=position_ids,
1425
+ head_mask=head_mask,
1426
+ inputs_embeds=inputs_embeds,
1427
+ encoder_hidden_states=encoder_hidden_states,
1428
+ encoder_attention_mask=encoder_attention_mask,
1429
+ output_attentions=output_attentions,
1430
+ output_hidden_states=output_hidden_states,
1431
+ return_dict=return_dict,
1432
+ kpi_ref=kpi_ref, # positions where KPI values were replaced, plus the reference KPI name, value and category
1433
+ kpi_model=kpi_model,
1434
+ )
1435
+
1436
+
1437
+ # the decoder loss and related losses are computed here
1438
+ # the numeric positions are already masked and excluded from the MLM loss, so a separate loss can be computed for them
1439
+ sequence_output = outputs[0]
1440
+
1441
+ # there may be samples with no KPI at all
1442
+ # if kpi_ref is None:
1443
+ # kpi_loss = None
1444
+ # else:
1445
+ # # required by awl
1446
+ # kpi_loss = torch.tensor([0.]).cuda()
1447
+ kpi_loss, kpi_loss_weight, kpi_loss_dict = None, None, None
1448
+ if kpi_input is not None:
1449
+ max_len = sequence_output.shape[1]
1450
+ # build the numeric embeddings
1451
+ kpi_emb_list = []
1452
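+ # kpi_ref is assumed (from the indexing below) to be a per-sample list of entries whose third
+ # field, item[2], is the token position of the corresponding [NUM] placeholder; the remaining
+ # fields carry the reference KPI name / value / category mentioned in the forward signature.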
+ for i in range(len(kpi_ref)):
1453
+ if len(kpi_ref[i])>0:
1454
+ for item in kpi_ref[i]:
1455
+ # the [NUM] placeholder may have been truncated
1456
+ if item[2]>=max_len:
1457
+ continue
1458
+ # requires_grad=True
1459
+ kpi_emb_list.append(sequence_output[i][item[2]])
1460
+
1461
+ # TODO: normalize the KPI contrastive loss, since KPI values fluctuate
1462
+ kpi_emb = torch.stack(kpi_emb_list).unsqueeze(1)
1463
+
1464
+ # numeric_input: the associated numeric values
1465
+ # kpi_input: the associated KPI ids
1466
+ _dec_kpi_score, de_loss = kpi_model.decoder(numeric_input, kpi_emb)
1467
+ # pdb.set_trace()
1468
+ _cls_kpi, cls_loss = kpi_model.classifier(kpi_emb, kpi_input)
1469
+ # pdb.set_trace()
1470
+ # pdb.set_trace()
1471
+ # pre-multiply by a coefficient to reduce its impact
1472
+ if con_loss is not None:
1473
+ kpi_loss = kpi_model.loss_awl(de_loss, 0.2 * cls_loss, 0.2 * con_loss) + 0.5 * en_loss
1474
+ kpi_loss_dict = {'de_loss':de_loss.item(), 'con_loss':con_loss.item(), 'cls_loss':cls_loss.item(), 'en_loss':en_loss.item()}
1475
+ else:
1476
+ kpi_loss = kpi_model.loss_awl(de_loss, 0.1 * cls_loss) + 0.5 * en_loss
1477
+ kpi_loss_dict = {'de_loss':de_loss.item(), 'cls_loss':cls_loss.item(), 'en_loss':en_loss.item()}
1478
+ kpi_loss_weight = kpi_model.loss_awl.params.tolist()
1479
+
1480
+
1481
+ prediction_scores = self.cls(sequence_output)
1482
+
1483
+ masked_lm_loss = None
1484
+ if labels is not None:
1485
+ loss_fct = CrossEntropyLoss() # -100 index = padding token [ignore_index=-100]
1486
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1487
+
1488
+
1489
+ if not return_dict:
1490
+ output = (prediction_scores,) + outputs[2:]
1491
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1492
+
1493
+ # pdb.set_trace()
1494
+ return MaskedLMOutput(
1495
+ loss=masked_lm_loss,
1496
+ logits=prediction_scores,
1497
+ hidden_states=outputs.hidden_states,
1498
+ attentions=outputs.attentions
1499
+ ), kpi_loss, kpi_loss_weight, kpi_loss_dict
1500
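+ # Hedged sketch of how a training step might combine the returned losses (the weighting actually
+ # used by this project may differ):
+ #   mlm_out, kpi_loss, kpi_loss_weight, kpi_loss_dict = model(**batch, labels=labels, kpi_ref=kpi_ref, kpi_model=kpi_model)
+ #   loss = mlm_out.loss if kpi_loss is None else mlm_out.loss + kpi_loss
+ #   loss.backward()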
+
1501
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
1502
+ input_shape = input_ids.shape
1503
+ effective_batch_size = input_shape[0]
1504
+
1505
+ # add a dummy token
1506
+ if self.config.pad_token_id is None:
1507
+ raise ValueError("The PAD token should be defined for generation")
1508
+
1509
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
1510
+ dummy_token = torch.full(
1511
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
1512
+ )
1513
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1514
+
1515
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1516
+
1517
+
1518
+ @add_start_docstrings(
1519
+ """Bert Model with a `next sentence prediction (classification)` head on top.""",
1520
+ BERT_START_DOCSTRING,
1521
+ )
1522
+ class BertForNextSentencePrediction(BertPreTrainedModel):
1523
+ def __init__(self, config):
1524
+ super().__init__(config)
1525
+
1526
+ self.bert = BertModel(config)
1527
+ self.cls = BertOnlyNSPHead(config)
1528
+
1529
+ # Initialize weights and apply final processing
1530
+ self.post_init()
1531
+
1532
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1533
+ @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
1534
+ def forward(
1535
+ self,
1536
+ input_ids: Optional[torch.Tensor] = None,
1537
+ attention_mask: Optional[torch.Tensor] = None,
1538
+ token_type_ids: Optional[torch.Tensor] = None,
1539
+ position_ids: Optional[torch.Tensor] = None,
1540
+ head_mask: Optional[torch.Tensor] = None,
1541
+ inputs_embeds: Optional[torch.Tensor] = None,
1542
+ labels: Optional[torch.Tensor] = None,
1543
+ output_attentions: Optional[bool] = None,
1544
+ output_hidden_states: Optional[bool] = None,
1545
+ return_dict: Optional[bool] = None,
1546
+ **kwargs,
1547
+ ) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
1548
+ r"""
1549
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1550
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
1551
+ (see `input_ids` docstring). Indices should be in `[0, 1]`:
1552
+
1553
+ - 0 indicates sequence B is a continuation of sequence A,
1554
+ - 1 indicates sequence B is a random sequence.
1555
+
1556
+ Returns:
1557
+
1558
+ Example:
1559
+
1560
+ ```python
1561
+ >>> from transformers import BertTokenizer, BertForNextSentencePrediction
1562
+ >>> import torch
1563
+
1564
+ >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
1565
+ >>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased")
1566
+
1567
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
1568
+ >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
1569
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
1570
+
1571
+ >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
1572
+ >>> logits = outputs.logits
1573
+ >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
1574
+ ```
1575
+ """
1576
+
1577
+ if "next_sentence_label" in kwargs:
1578
+ warnings.warn(
1579
+ "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
1580
+ " `labels` instead.",
1581
+ FutureWarning,
1582
+ )
1583
+ labels = kwargs.pop("next_sentence_label")
1584
+
1585
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1586
+
1587
+ outputs = self.bert(
1588
+ input_ids,
1589
+ attention_mask=attention_mask,
1590
+ token_type_ids=token_type_ids,
1591
+ position_ids=position_ids,
1592
+ head_mask=head_mask,
1593
+ inputs_embeds=inputs_embeds,
1594
+ output_attentions=output_attentions,
1595
+ output_hidden_states=output_hidden_states,
1596
+ return_dict=return_dict,
1597
+ )
1598
+
1599
+ pooled_output = outputs[1]
1600
+
1601
+ seq_relationship_scores = self.cls(pooled_output)
1602
+
1603
+ next_sentence_loss = None
1604
+ if labels is not None:
1605
+ loss_fct = CrossEntropyLoss()
1606
+ next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
1607
+
1608
+ if not return_dict:
1609
+ output = (seq_relationship_scores,) + outputs[2:]
1610
+ return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
1611
+
1612
+ return NextSentencePredictorOutput(
1613
+ loss=next_sentence_loss,
1614
+ logits=seq_relationship_scores,
1615
+ hidden_states=outputs.hidden_states,
1616
+ attentions=outputs.attentions,
1617
+ )
1618
+
1619
+
1620
+ @add_start_docstrings(
1621
+ """
1622
+ Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
1623
+ output) e.g. for GLUE tasks.
1624
+ """,
1625
+ BERT_START_DOCSTRING,
1626
+ )
1627
+ class BertForSequenceClassification(BertPreTrainedModel):
1628
+ def __init__(self, config):
1629
+ super().__init__(config)
1630
+ self.num_labels = config.num_labels
1631
+ self.config = config
1632
+
1633
+ self.bert = BertModel(config)
1634
+ classifier_dropout = (
1635
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1636
+ )
1637
+ self.dropout = nn.Dropout(classifier_dropout)
1638
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1639
+
1640
+ # Initialize weights and apply final processing
1641
+ self.post_init()
1642
+
1643
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1644
+ @add_code_sample_docstrings(
1645
+ processor_class=_TOKENIZER_FOR_DOC,
1646
+ checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
1647
+ output_type=SequenceClassifierOutput,
1648
+ config_class=_CONFIG_FOR_DOC,
1649
+ expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
1650
+ expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
1651
+ )
1652
+ def forward(
1653
+ self,
1654
+ input_ids: Optional[torch.Tensor] = None,
1655
+ attention_mask: Optional[torch.Tensor] = None,
1656
+ token_type_ids: Optional[torch.Tensor] = None,
1657
+ position_ids: Optional[torch.Tensor] = None,
1658
+ head_mask: Optional[torch.Tensor] = None,
1659
+ inputs_embeds: Optional[torch.Tensor] = None,
1660
+ labels: Optional[torch.Tensor] = None,
1661
+ output_attentions: Optional[bool] = None,
1662
+ output_hidden_states: Optional[bool] = None,
1663
+ return_dict: Optional[bool] = None,
1664
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1665
+ r"""
1666
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1667
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1668
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1669
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1670
+ """
1671
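+ # Illustrative label shapes for the three problem types handled below (standard transformers behaviour):
+ #   regression (num_labels == 1): labels is a FloatTensor of shape (batch_size,)
+ #   single-label classification: labels is a LongTensor of shape (batch_size,) holding class indices
+ #   multi-label classification: labels is a FloatTensor of shape (batch_size, num_labels) of 0./1. targets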
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1672
+
1673
+ outputs = self.bert(
1674
+ input_ids,
1675
+ attention_mask=attention_mask,
1676
+ token_type_ids=token_type_ids,
1677
+ position_ids=position_ids,
1678
+ head_mask=head_mask,
1679
+ inputs_embeds=inputs_embeds,
1680
+ output_attentions=output_attentions,
1681
+ output_hidden_states=output_hidden_states,
1682
+ return_dict=return_dict,
1683
+ )
1684
+
1685
+ pooled_output = outputs[1]
1686
+
1687
+ pooled_output = self.dropout(pooled_output)
1688
+ logits = self.classifier(pooled_output)
1689
+
1690
+ loss = None
1691
+ if labels is not None:
1692
+ if self.config.problem_type is None:
1693
+ if self.num_labels == 1:
1694
+ self.config.problem_type = "regression"
1695
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1696
+ self.config.problem_type = "single_label_classification"
1697
+ else:
1698
+ self.config.problem_type = "multi_label_classification"
1699
+
1700
+ if self.config.problem_type == "regression":
1701
+ loss_fct = MSELoss()
1702
+ if self.num_labels == 1:
1703
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1704
+ else:
1705
+ loss = loss_fct(logits, labels)
1706
+ elif self.config.problem_type == "single_label_classification":
1707
+ loss_fct = CrossEntropyLoss()
1708
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1709
+ elif self.config.problem_type == "multi_label_classification":
1710
+ loss_fct = BCEWithLogitsLoss()
1711
+ loss = loss_fct(logits, labels)
1712
+ if not return_dict:
1713
+ output = (logits,) + outputs[2:]
1714
+ return ((loss,) + output) if loss is not None else output
1715
+
1716
+ return SequenceClassifierOutput(
1717
+ loss=loss,
1718
+ logits=logits,
1719
+ hidden_states=outputs.hidden_states,
1720
+ attentions=outputs.attentions,
1721
+ )
1722
+
1723
+
1724
+ @add_start_docstrings(
1725
+ """
1726
+ Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1727
+ softmax) e.g. for RocStories/SWAG tasks.
1728
+ """,
1729
+ BERT_START_DOCSTRING,
1730
+ )
1731
+ class BertForMultipleChoice(BertPreTrainedModel):
1732
+ def __init__(self, config):
1733
+ super().__init__(config)
1734
+
1735
+ self.bert = BertModel(config)
1736
+ classifier_dropout = (
1737
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1738
+ )
1739
+ self.dropout = nn.Dropout(classifier_dropout)
1740
+ self.classifier = nn.Linear(config.hidden_size, 1)
1741
+
1742
+ # Initialize weights and apply final processing
1743
+ self.post_init()
1744
+
1745
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
1746
+ @add_code_sample_docstrings(
1747
+ processor_class=_TOKENIZER_FOR_DOC,
1748
+ checkpoint=_CHECKPOINT_FOR_DOC,
1749
+ output_type=MultipleChoiceModelOutput,
1750
+ config_class=_CONFIG_FOR_DOC,
1751
+ )
1752
+ def forward(
1753
+ self,
1754
+ input_ids: Optional[torch.Tensor] = None,
1755
+ attention_mask: Optional[torch.Tensor] = None,
1756
+ token_type_ids: Optional[torch.Tensor] = None,
1757
+ position_ids: Optional[torch.Tensor] = None,
1758
+ head_mask: Optional[torch.Tensor] = None,
1759
+ inputs_embeds: Optional[torch.Tensor] = None,
1760
+ labels: Optional[torch.Tensor] = None,
1761
+ output_attentions: Optional[bool] = None,
1762
+ output_hidden_states: Optional[bool] = None,
1763
+ return_dict: Optional[bool] = None,
1764
+ ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
1765
+ r"""
1766
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1767
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
1768
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
1769
+ `input_ids` above)
1770
+ """
1771
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1772
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1773
+
1774
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1775
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1776
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1777
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1778
+ inputs_embeds = (
1779
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1780
+ if inputs_embeds is not None
1781
+ else None
1782
+ )
1783
+
1784
+ outputs = self.bert(
1785
+ input_ids,
1786
+ attention_mask=attention_mask,
1787
+ token_type_ids=token_type_ids,
1788
+ position_ids=position_ids,
1789
+ head_mask=head_mask,
1790
+ inputs_embeds=inputs_embeds,
1791
+ output_attentions=output_attentions,
1792
+ output_hidden_states=output_hidden_states,
1793
+ return_dict=return_dict,
1794
+ )
1795
+
1796
+ pooled_output = outputs[1]
1797
+
1798
+ pooled_output = self.dropout(pooled_output)
1799
+ logits = self.classifier(pooled_output)
1800
+ reshaped_logits = logits.view(-1, num_choices)
1801
+
1802
+ loss = None
1803
+ if labels is not None:
1804
+ loss_fct = CrossEntropyLoss()
1805
+ loss = loss_fct(reshaped_logits, labels)
1806
+
1807
+ if not return_dict:
1808
+ output = (reshaped_logits,) + outputs[2:]
1809
+ return ((loss,) + output) if loss is not None else output
1810
+
1811
+ return MultipleChoiceModelOutput(
1812
+ loss=loss,
1813
+ logits=reshaped_logits,
1814
+ hidden_states=outputs.hidden_states,
1815
+ attentions=outputs.attentions,
1816
+ )
1817
+
1818
+
1819
+ @add_start_docstrings(
1820
+ """
1821
+ Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1822
+ Named-Entity-Recognition (NER) tasks.
1823
+ """,
1824
+ BERT_START_DOCSTRING,
1825
+ )
1826
+ class BertForTokenClassification(BertPreTrainedModel):
1827
+
1828
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1829
+
1830
+ def __init__(self, config):
1831
+ super().__init__(config)
1832
+ self.num_labels = config.num_labels
1833
+
1834
+ self.bert = BertModel(config, add_pooling_layer=False)
1835
+ classifier_dropout = (
1836
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1837
+ )
1838
+ self.dropout = nn.Dropout(classifier_dropout)
1839
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1840
+
1841
+ # Initialize weights and apply final processing
1842
+ self.post_init()
1843
+
1844
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1845
+ @add_code_sample_docstrings(
1846
+ processor_class=_TOKENIZER_FOR_DOC,
1847
+ checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION,
1848
+ output_type=TokenClassifierOutput,
1849
+ config_class=_CONFIG_FOR_DOC,
1850
+ expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT,
1851
+ expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,
1852
+ )
1853
+ def forward(
1854
+ self,
1855
+ input_ids: Optional[torch.Tensor] = None,
1856
+ attention_mask: Optional[torch.Tensor] = None,
1857
+ token_type_ids: Optional[torch.Tensor] = None,
1858
+ position_ids: Optional[torch.Tensor] = None,
1859
+ head_mask: Optional[torch.Tensor] = None,
1860
+ inputs_embeds: Optional[torch.Tensor] = None,
1861
+ labels: Optional[torch.Tensor] = None,
1862
+ output_attentions: Optional[bool] = None,
1863
+ output_hidden_states: Optional[bool] = None,
1864
+ return_dict: Optional[bool] = None,
1865
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1866
+ r"""
1867
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1868
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1869
+ """
1870
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1871
+
1872
+ outputs = self.bert(
1873
+ input_ids,
1874
+ attention_mask=attention_mask,
1875
+ token_type_ids=token_type_ids,
1876
+ position_ids=position_ids,
1877
+ head_mask=head_mask,
1878
+ inputs_embeds=inputs_embeds,
1879
+ output_attentions=output_attentions,
1880
+ output_hidden_states=output_hidden_states,
1881
+ return_dict=return_dict,
1882
+ )
1883
+
1884
+ sequence_output = outputs[0]
1885
+
1886
+ sequence_output = self.dropout(sequence_output)
1887
+ logits = self.classifier(sequence_output)
1888
+
1889
+ loss = None
1890
+ if labels is not None:
1891
+ loss_fct = CrossEntropyLoss()
1892
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1893
+
1894
+ if not return_dict:
1895
+ output = (logits,) + outputs[2:]
1896
+ return ((loss,) + output) if loss is not None else output
1897
+
1898
+ return TokenClassifierOutput(
1899
+ loss=loss,
1900
+ logits=logits,
1901
+ hidden_states=outputs.hidden_states,
1902
+ attentions=outputs.attentions,
1903
+ )
1904
+
1905
+
1906
+ @add_start_docstrings(
1907
+ """
1908
+ Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
1909
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1910
+ """,
1911
+ BERT_START_DOCSTRING,
1912
+ )
1913
+ class BertForQuestionAnswering(BertPreTrainedModel):
1914
+
1915
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1916
+
1917
+ def __init__(self, config):
1918
+ super().__init__(config)
1919
+ self.num_labels = config.num_labels
1920
+
1921
+ self.bert = BertModel(config, add_pooling_layer=False)
1922
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1923
+
1924
+ # Initialize weights and apply final processing
1925
+ self.post_init()
1926
+
1927
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1928
+ @add_code_sample_docstrings(
1929
+ processor_class=_TOKENIZER_FOR_DOC,
1930
+ checkpoint=_CHECKPOINT_FOR_QA,
1931
+ output_type=QuestionAnsweringModelOutput,
1932
+ config_class=_CONFIG_FOR_DOC,
1933
+ qa_target_start_index=_QA_TARGET_START_INDEX,
1934
+ qa_target_end_index=_QA_TARGET_END_INDEX,
1935
+ expected_output=_QA_EXPECTED_OUTPUT,
1936
+ expected_loss=_QA_EXPECTED_LOSS,
1937
+ )
1938
+ def forward(
1939
+ self,
1940
+ input_ids: Optional[torch.Tensor] = None,
1941
+ attention_mask: Optional[torch.Tensor] = None,
1942
+ token_type_ids: Optional[torch.Tensor] = None,
1943
+ position_ids: Optional[torch.Tensor] = None,
1944
+ head_mask: Optional[torch.Tensor] = None,
1945
+ inputs_embeds: Optional[torch.Tensor] = None,
1946
+ start_positions: Optional[torch.Tensor] = None,
1947
+ end_positions: Optional[torch.Tensor] = None,
1948
+ output_attentions: Optional[bool] = None,
1949
+ output_hidden_states: Optional[bool] = None,
1950
+ return_dict: Optional[bool] = None,
1951
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
1952
+ r"""
1953
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1954
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1955
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1956
+ are not taken into account for computing the loss.
1957
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1958
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1959
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1960
+ are not taken into account for computing the loss.
1961
+ """
1962
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1963
+
1964
+ outputs = self.bert(
1965
+ input_ids,
1966
+ attention_mask=attention_mask,
1967
+ token_type_ids=token_type_ids,
1968
+ position_ids=position_ids,
1969
+ head_mask=head_mask,
1970
+ inputs_embeds=inputs_embeds,
1971
+ output_attentions=output_attentions,
1972
+ output_hidden_states=output_hidden_states,
1973
+ return_dict=return_dict,
1974
+ )
1975
+
1976
+ sequence_output = outputs[0]
1977
+
1978
+ logits = self.qa_outputs(sequence_output)
1979
+ start_logits, end_logits = logits.split(1, dim=-1)
1980
+ start_logits = start_logits.squeeze(-1).contiguous()
1981
+ end_logits = end_logits.squeeze(-1).contiguous()
1982
+
1983
+ total_loss = None
1984
+ if start_positions is not None and end_positions is not None:
1985
+ # If we are on multi-GPU, split add a dimension
1986
+ if len(start_positions.size()) > 1:
1987
+ start_positions = start_positions.squeeze(-1)
1988
+ if len(end_positions.size()) > 1:
1989
+ end_positions = end_positions.squeeze(-1)
1990
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1991
+ ignored_index = start_logits.size(1)
1992
+ start_positions = start_positions.clamp(0, ignored_index)
1993
+ end_positions = end_positions.clamp(0, ignored_index)
1994
+
1995
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1996
+ start_loss = loss_fct(start_logits, start_positions)
1997
+ end_loss = loss_fct(end_logits, end_positions)
1998
+ total_loss = (start_loss + end_loss) / 2
1999
+
2000
+ if not return_dict:
2001
+ output = (start_logits, end_logits) + outputs[2:]
2002
+ return ((total_loss,) + output) if total_loss is not None else output
2003
+
2004
+ return QuestionAnsweringModelOutput(
2005
+ loss=total_loss,
2006
+ start_logits=start_logits,
2007
+ end_logits=end_logits,
2008
+ hidden_states=outputs.hidden_states,
2009
+ attentions=outputs.attentions,
2010
+ )
KTeleBERT/model/bert/tokenization_bert.py ADDED
@@ -0,0 +1,574 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for Bert."""
16
+
17
+
18
+ import collections
19
+ import os
20
+ import unicodedata
21
+ from typing import List, Optional, Tuple
22
+
23
+ from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
24
+ from transformers.utils import logging
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
30
+
31
+ PRETRAINED_VOCAB_FILES_MAP = {
32
+ "vocab_file": {
33
+ "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
34
+ "bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt",
35
+ "bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt",
36
+ "bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt",
37
+ "bert-base-multilingual-uncased": (
38
+ "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/vocab.txt"
39
+ ),
40
+ "bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt",
41
+ "bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt",
42
+ "bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/vocab.txt",
43
+ "bert-large-uncased-whole-word-masking": (
44
+ "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt"
45
+ ),
46
+ "bert-large-cased-whole-word-masking": (
47
+ "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/vocab.txt"
48
+ ),
49
+ "bert-large-uncased-whole-word-masking-finetuned-squad": (
50
+ "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt"
51
+ ),
52
+ "bert-large-cased-whole-word-masking-finetuned-squad": (
53
+ "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt"
54
+ ),
55
+ "bert-base-cased-finetuned-mrpc": (
56
+ "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt"
57
+ ),
58
+ "bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/vocab.txt",
59
+ "bert-base-german-dbmdz-uncased": (
60
+ "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt"
61
+ ),
62
+ "TurkuNLP/bert-base-finnish-cased-v1": (
63
+ "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt"
64
+ ),
65
+ "TurkuNLP/bert-base-finnish-uncased-v1": (
66
+ "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt"
67
+ ),
68
+ "wietsedv/bert-base-dutch-cased": (
69
+ "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt"
70
+ ),
71
+ }
72
+ }
73
+
74
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
75
+ "bert-base-uncased": 512,
76
+ "bert-large-uncased": 512,
77
+ "bert-base-cased": 512,
78
+ "bert-large-cased": 512,
79
+ "bert-base-multilingual-uncased": 512,
80
+ "bert-base-multilingual-cased": 512,
81
+ "bert-base-chinese": 512,
82
+ "bert-base-german-cased": 512,
83
+ "bert-large-uncased-whole-word-masking": 512,
84
+ "bert-large-cased-whole-word-masking": 512,
85
+ "bert-large-uncased-whole-word-masking-finetuned-squad": 512,
86
+ "bert-large-cased-whole-word-masking-finetuned-squad": 512,
87
+ "bert-base-cased-finetuned-mrpc": 512,
88
+ "bert-base-german-dbmdz-cased": 512,
89
+ "bert-base-german-dbmdz-uncased": 512,
90
+ "TurkuNLP/bert-base-finnish-cased-v1": 512,
91
+ "TurkuNLP/bert-base-finnish-uncased-v1": 512,
92
+ "wietsedv/bert-base-dutch-cased": 512,
93
+ }
94
+
95
+ PRETRAINED_INIT_CONFIGURATION = {
96
+ "bert-base-uncased": {"do_lower_case": True},
97
+ "bert-large-uncased": {"do_lower_case": True},
98
+ "bert-base-cased": {"do_lower_case": False},
99
+ "bert-large-cased": {"do_lower_case": False},
100
+ "bert-base-multilingual-uncased": {"do_lower_case": True},
101
+ "bert-base-multilingual-cased": {"do_lower_case": False},
102
+ "bert-base-chinese": {"do_lower_case": False},
103
+ "bert-base-german-cased": {"do_lower_case": False},
104
+ "bert-large-uncased-whole-word-masking": {"do_lower_case": True},
105
+ "bert-large-cased-whole-word-masking": {"do_lower_case": False},
106
+ "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
107
+ "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
108
+ "bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
109
+ "bert-base-german-dbmdz-cased": {"do_lower_case": False},
110
+ "bert-base-german-dbmdz-uncased": {"do_lower_case": True},
111
+ "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
112
+ "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
113
+ "wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
114
+ }
115
+
116
+
117
+ def load_vocab(vocab_file):
118
+ """Loads a vocabulary file into a dictionary."""
119
+ vocab = collections.OrderedDict()
120
+ with open(vocab_file, "r", encoding="utf-8") as reader:
121
+ tokens = reader.readlines()
122
+ for index, token in enumerate(tokens):
123
+ token = token.rstrip("\n")
124
+ vocab[token] = index
125
+ return vocab
126
+
127
+
128
+ def whitespace_tokenize(text):
129
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
130
+ text = text.strip()
131
+ if not text:
132
+ return []
133
+ tokens = text.split()
134
+ return tokens
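+ # e.g. whitespace_tokenize("  a  b ") returns ["a", "b"]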
135
+
136
+
137
+ class BertTokenizer(PreTrainedTokenizer):
138
+ r"""
139
+ Construct a BERT tokenizer. Based on WordPiece.
140
+
141
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
142
+ this superclass for more information regarding those methods.
143
+
144
+ Args:
145
+ vocab_file (`str`):
146
+ File containing the vocabulary.
147
+ do_lower_case (`bool`, *optional*, defaults to `True`):
148
+ Whether or not to lowercase the input when tokenizing.
149
+ do_basic_tokenize (`bool`, *optional*, defaults to `True`):
150
+ Whether or not to do basic tokenization before WordPiece.
151
+ never_split (`Iterable`, *optional*):
152
+ Collection of tokens which will never be split during tokenization. Only has an effect when
153
+ `do_basic_tokenize=True`
154
+ unk_token (`str`, *optional*, defaults to `"[UNK]"`):
155
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
156
+ token instead.
157
+ sep_token (`str`, *optional*, defaults to `"[SEP]"`):
158
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
159
+ sequence classification or for a text and a question for question answering. It is also used as the last
160
+ token of a sequence built with special tokens.
161
+ pad_token (`str`, *optional*, defaults to `"[PAD]"`):
162
+ The token used for padding, for example when batching sequences of different lengths.
163
+ cls_token (`str`, *optional*, defaults to `"[CLS]"`):
164
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
165
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
166
+ mask_token (`str`, *optional*, defaults to `"[MASK]"`):
167
+ The token used for masking values. This is the token used when training this model with masked language
168
+ modeling. This is the token which the model will try to predict.
169
+ tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
170
+ Whether or not to tokenize Chinese characters.
171
+
172
+ This should likely be deactivated for Japanese (see this
173
+ [issue](https://github.com/huggingface/transformers/issues/328)).
174
+ strip_accents (`bool`, *optional*):
175
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
176
+ value for `lowercase` (as in the original BERT).
177
+ """
178
+
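+ # Hedged usage sketch (assuming the upstream BertTokenizer interface is unchanged; the file name is illustrative):
+ #   tokenizer = BertTokenizer(vocab_file="vocab.txt")
+ #   tokens = tokenizer.tokenize("some text")  # basic tokenization followed by WordPiece; output depends on the vocab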
179
+ vocab_files_names = VOCAB_FILES_NAMES
180
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
181
+ pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
182
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
183
+
184
+ def __init__(
185
+ self,
186
+ vocab_file,
187
+ do_lower_case=True,
188
+ do_basic_tokenize=True,
189
+ never_split=None,
190
+ unk_token="[UNK]",
191
+ sep_token="[SEP]",
192
+ pad_token="[PAD]",
193
+ cls_token="[CLS]",
194
+ mask_token="[MASK]",
195
+ tokenize_chinese_chars=True,
196
+ strip_accents=None,
197
+ **kwargs
198
+ ):
199
+ super().__init__(
200
+ do_lower_case=do_lower_case,
201
+ do_basic_tokenize=do_basic_tokenize,
202
+ never_split=never_split,
203
+ unk_token=unk_token,
204
+ sep_token=sep_token,
205
+ pad_token=pad_token,
206
+ cls_token=cls_token,
207
+ mask_token=mask_token,
208
+ tokenize_chinese_chars=tokenize_chinese_chars,
209
+ strip_accents=strip_accents,
210
+ **kwargs,
211
+ )
212
+
213
+ if not os.path.isfile(vocab_file):
214
+ raise ValueError(
215
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
216
+ " model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
217
+ )
218
+ self.vocab = load_vocab(vocab_file)
219
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
220
+ self.do_basic_tokenize = do_basic_tokenize
221
+ if do_basic_tokenize:
222
+ self.basic_tokenizer = BasicTokenizer(
223
+ do_lower_case=do_lower_case,
224
+ never_split=never_split,
225
+ tokenize_chinese_chars=tokenize_chinese_chars,
226
+ strip_accents=strip_accents,
227
+ )
228
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
229
+
230
+ @property
231
+ def do_lower_case(self):
232
+ return self.basic_tokenizer.do_lower_case
233
+
234
+ @property
235
+ def vocab_size(self):
236
+ return len(self.vocab)
237
+
238
+ def get_vocab(self):
239
+ return dict(self.vocab, **self.added_tokens_encoder)
240
+
241
+ def _tokenize(self, text):
242
+ split_tokens = []
243
+ if self.do_basic_tokenize:
244
+ for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
245
+
246
+ # If the token is part of the never_split set
247
+ if token in self.basic_tokenizer.never_split:
248
+ split_tokens.append(token)
249
+ else:
250
+ split_tokens += self.wordpiece_tokenizer.tokenize(token)
251
+ else:
252
+ split_tokens = self.wordpiece_tokenizer.tokenize(text)
253
+ return split_tokens
254
+
255
+ def _convert_token_to_id(self, token):
256
+ """Converts a token (str) in an id using the vocab."""
257
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
258
+
259
+ def _convert_id_to_token(self, index):
260
+ """Converts an index (integer) in a token (str) using the vocab."""
261
+ return self.ids_to_tokens.get(index, self.unk_token)
262
+
263
+ def convert_tokens_to_string(self, tokens):
264
+ """Converts a sequence of tokens (string) in a single string."""
265
+ out_string = " ".join(tokens).replace(" ##", "").strip()
266
+ return out_string
267
+
268
+ def build_inputs_with_special_tokens(
269
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
270
+ ) -> List[int]:
271
+ """
272
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
273
+ adding special tokens. A BERT sequence has the following format:
274
+
275
+ - single sequence: `[CLS] X [SEP]`
276
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
277
+
278
+ Args:
279
+ token_ids_0 (`List[int]`):
280
+ List of IDs to which the special tokens will be added.
281
+ token_ids_1 (`List[int]`, *optional*):
282
+ Optional second list of IDs for sequence pairs.
283
+
284
+ Returns:
285
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
286
+ """
287
+ if token_ids_1 is None:
288
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
289
+ cls = [self.cls_token_id]
290
+ sep = [self.sep_token_id]
291
+ return cls + token_ids_0 + sep + token_ids_1 + sep
292
+
293
+ def get_special_tokens_mask(
294
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
295
+ ) -> List[int]:
296
+ """
297
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
298
+ special tokens using the tokenizer `prepare_for_model` method.
299
+
300
+ Args:
301
+ token_ids_0 (`List[int]`):
302
+ List of IDs.
303
+ token_ids_1 (`List[int]`, *optional*):
304
+ Optional second list of IDs for sequence pairs.
305
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
306
+ Whether or not the token list is already formatted with special tokens for the model.
307
+
308
+ Returns:
309
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
310
+ """
311
+
312
+ if already_has_special_tokens:
313
+ return super().get_special_tokens_mask(
314
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
315
+ )
316
+
317
+ if token_ids_1 is not None:
318
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
319
+ return [1] + ([0] * len(token_ids_0)) + [1]
320
+
321
+ def create_token_type_ids_from_sequences(
322
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
323
+ ) -> List[int]:
324
+ """
325
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
326
+ pair mask has the following format:
327
+
328
+ ```
329
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
330
+ | first sequence | second sequence |
331
+ ```
332
+
333
+ If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
334
+
335
+ Args:
336
+ token_ids_0 (`List[int]`):
337
+ List of IDs.
338
+ token_ids_1 (`List[int]`, *optional*):
339
+ Optional second list of IDs for sequence pairs.
340
+
341
+ Returns:
342
+ `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
343
+ """
344
+ sep = [self.sep_token_id]
345
+ cls = [self.cls_token_id]
346
+ if token_ids_1 is None:
347
+ return len(cls + token_ids_0 + sep) * [0]
348
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
349
+
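For reference, a minimal standalone sketch of how the two helpers above line up for a sequence pair; the [CLS]/[SEP] ids used here (101/102) are illustrative assumptions, not values read from this repository's vocab file.

# Illustrative only: mimics build_inputs_with_special_tokens and
# create_token_type_ids_from_sequences for a toy sequence pair.
cls_id, sep_id = 101, 102          # assumed ids for [CLS] and [SEP]
token_ids_0 = [7, 8, 9]            # toy ids for sentence A
token_ids_1 = [15, 16]             # toy ids for sentence B

input_ids = [cls_id] + token_ids_0 + [sep_id] + token_ids_1 + [sep_id]
token_type_ids = [0] * (len(token_ids_0) + 2) + [1] * (len(token_ids_1) + 1)

print(input_ids)        # [101, 7, 8, 9, 102, 15, 16, 102]
print(token_type_ids)   # [0, 0, 0, 0, 0, 1, 1, 1]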
350
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
351
+ index = 0
352
+ if os.path.isdir(save_directory):
353
+ vocab_file = os.path.join(
354
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
355
+ )
356
+ else:
357
+ vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
358
+ with open(vocab_file, "w", encoding="utf-8") as writer:
359
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
360
+ if index != token_index:
361
+ logger.warning(
362
+ f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
363
+ " Please check that the vocabulary is not corrupted!"
364
+ )
365
+ index = token_index
366
+ writer.write(token + "\n")
367
+ index += 1
368
+ return (vocab_file,)
369
+
370
+
371
+ class BasicTokenizer(object):
372
+ """
373
+ Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
374
+
375
+ Args:
376
+ do_lower_case (`bool`, *optional*, defaults to `True`):
377
+ Whether or not to lowercase the input when tokenizing.
378
+ never_split (`Iterable`, *optional*):
379
+ Collection of tokens which will never be split during tokenization. Only has an effect when
380
+ `do_basic_tokenize=True`
381
+ tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
382
+ Whether or not to tokenize Chinese characters.
383
+
384
+ This should likely be deactivated for Japanese (see this
385
+ [issue](https://github.com/huggingface/transformers/issues/328)).
386
+ strip_accents: (`bool`, *optional*):
387
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
388
+ value for `lowercase` (as in the original BERT).
389
+ """
390
+
391
+ def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
392
+ if never_split is None:
393
+ never_split = []
394
+ self.do_lower_case = do_lower_case
395
+ self.never_split = set(never_split)
396
+ self.tokenize_chinese_chars = tokenize_chinese_chars
397
+ self.strip_accents = strip_accents
398
+
399
+ def tokenize(self, text, never_split=None):
400
+ """
401
+ Basic Tokenization of a piece of text. Splits on whitespace only; for sub-word tokenization, see
402
+ WordPieceTokenizer.
403
+
404
+ Args:
405
+ never_split (`List[str]`, *optional*)
406
+ Kept for backward compatibility purposes. Now implemented directly at the base class level (see
407
+ [`PreTrainedTokenizer.tokenize`]). List of tokens not to split.
408
+ """
409
+ # union() returns a new set by concatenating the two sets.
410
+ never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
411
+ text = self._clean_text(text)
412
+
413
+ # This was added on November 1st, 2018 for the multilingual and Chinese
414
+ # models. This is also applied to the English models now, but it doesn't
415
+ # matter since the English models were not trained on any Chinese data
416
+ # and generally don't have any Chinese data in them (there are Chinese
417
+ # characters in the vocabulary because Wikipedia does have some Chinese
418
+ # words in the English Wikipedia.).
419
+ if self.tokenize_chinese_chars:
420
+ text = self._tokenize_chinese_chars(text)
421
+ orig_tokens = whitespace_tokenize(text)
422
+ split_tokens = []
423
+ for token in orig_tokens:
424
+ if token not in never_split:
425
+ if self.do_lower_case:
426
+ token = token.lower()
427
+ if self.strip_accents is not False:
428
+ token = self._run_strip_accents(token)
429
+ elif self.strip_accents:
430
+ token = self._run_strip_accents(token)
431
+ split_tokens.extend(self._run_split_on_punc(token, never_split))
432
+
433
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
434
+ return output_tokens
435
+
436
+ def _run_strip_accents(self, text):
437
+ """Strips accents from a piece of text."""
438
+ text = unicodedata.normalize("NFD", text)
439
+ output = []
440
+ for char in text:
441
+ cat = unicodedata.category(char)
442
+ if cat == "Mn":
443
+ continue
444
+ output.append(char)
445
+ return "".join(output)
446
+
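As a quick standalone illustration of the NFD-normalize-then-drop-combining-marks approach used above (a sketch, not code from this file):

import unicodedata

def strip_accents(text):
    # Decompose characters, then drop combining marks (Unicode category "Mn").
    return "".join(c for c in unicodedata.normalize("NFD", text) if unicodedata.category(c) != "Mn")

print(strip_accents("déjà vu"))  # -> "deja vu"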
447
+ def _run_split_on_punc(self, text, never_split=None):
448
+ """Splits punctuation on a piece of text."""
449
+ if never_split is not None and text in never_split:
450
+ return [text]
451
+ chars = list(text)
452
+ i = 0
453
+ start_new_word = True
454
+ output = []
455
+ while i < len(chars):
456
+ char = chars[i]
457
+ if _is_punctuation(char):
458
+ output.append([char])
459
+ start_new_word = True
460
+ else:
461
+ if start_new_word:
462
+ output.append([])
463
+ start_new_word = False
464
+ output[-1].append(char)
465
+ i += 1
466
+
467
+ return ["".join(x) for x in output]
468
+
469
+ def _tokenize_chinese_chars(self, text):
470
+ """Adds whitespace around any CJK character."""
471
+ output = []
472
+ for char in text:
473
+ cp = ord(char)
474
+ if self._is_chinese_char(cp):
475
+ output.append(" ")
476
+ output.append(char)
477
+ output.append(" ")
478
+ else:
479
+ output.append(char)
480
+ return "".join(output)
481
+
482
+ def _is_chinese_char(self, cp):
483
+ """Checks whether CP is the codepoint of a CJK character."""
484
+ # This defines a "chinese character" as anything in the CJK Unicode block:
485
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
486
+ #
487
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
488
+ # despite its name. The modern Korean Hangul alphabet is a different block,
489
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
490
+ # space-separated words, so they are not treated specially and handled
491
+ # like all of the other languages.
492
+ if (
493
+ (cp >= 0x4E00 and cp <= 0x9FFF)
494
+ or (cp >= 0x3400 and cp <= 0x4DBF) #
495
+ or (cp >= 0x20000 and cp <= 0x2A6DF) #
496
+ or (cp >= 0x2A700 and cp <= 0x2B73F) #
497
+ or (cp >= 0x2B740 and cp <= 0x2B81F) #
498
+ or (cp >= 0x2B820 and cp <= 0x2CEAF) #
499
+ or (cp >= 0xF900 and cp <= 0xFAFF)
500
+ or (cp >= 0x2F800 and cp <= 0x2FA1F) #
501
+ ): #
502
+ return True
503
+
504
+ return False
505
+
506
+ def _clean_text(self, text):
507
+ """Performs invalid character removal and whitespace cleanup on text."""
508
+ output = []
509
+ for char in text:
510
+ cp = ord(char)
511
+ if cp == 0 or cp == 0xFFFD or _is_control(char):
512
+ continue
513
+ if _is_whitespace(char):
514
+ output.append(" ")
515
+ else:
516
+ output.append(char)
517
+ return "".join(output)
518
+
519
+
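To see why padding spaces around CJK characters matters for this project's mixed Chinese/English log text, here is a small standalone sketch of the same idea; it re-implements the logic for the main CJK Unified Ideographs block only (the class above also covers the extension blocks).

def pad_cjk(text):
    out = []
    for ch in text:
        if 0x4E00 <= ord(ch) <= 0x9FFF:   # main CJK block only, for brevity
            out.extend([" ", ch, " "])
        else:
            out.append(ch)
    return "".join(out)

print(pad_cjk("告警ALM-1234"))  # -> " 告  警 ALM-1234", so whitespace tokenization yields 告 / 警 / ALM-1234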
520
+ class WordpieceTokenizer(object):
521
+ """Runs WordPiece tokenization."""
522
+
523
+ def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
524
+ self.vocab = vocab
525
+ self.unk_token = unk_token
526
+ self.max_input_chars_per_word = max_input_chars_per_word
527
+
528
+ def tokenize(self, text):
529
+ """
530
+ Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
531
+ tokenization using the given vocabulary.
532
+
533
+ For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.
534
+
535
+ Args:
536
+ text: A single token or whitespace separated tokens. This should have
537
+ already been passed through *BasicTokenizer*.
538
+
539
+ Returns:
540
+ A list of wordpiece tokens.
541
+ """
542
+
543
+ output_tokens = []
544
+ for token in whitespace_tokenize(text):
545
+ chars = list(token)
546
+ if len(chars) > self.max_input_chars_per_word:
547
+ output_tokens.append(self.unk_token)
548
+ continue
549
+
550
+ is_bad = False
551
+ start = 0
552
+ sub_tokens = []
553
+ while start < len(chars):
554
+ end = len(chars)
555
+ cur_substr = None
556
+ while start < end:
557
+ substr = "".join(chars[start:end])
558
+ if start > 0:
559
+ substr = "##" + substr
560
+ if substr in self.vocab:
561
+ cur_substr = substr
562
+ break
563
+ end -= 1
564
+ if cur_substr is None:
565
+ is_bad = True
566
+ break
567
+ sub_tokens.append(cur_substr)
568
+ start = end
569
+
570
+ if is_bad:
571
+ output_tokens.append(self.unk_token)
572
+ else:
573
+ output_tokens.extend(sub_tokens)
574
+ return output_tokens
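A compact standalone illustration of the greedy longest-match-first loop implemented above; the toy vocabulary is an assumption for the demo (the real one comes from the model's vocab.txt, not reproduced here).

def wordpiece(token, vocab, unk="[UNK]"):
    pieces, start = [], 0
    while start < len(token):
        end, cur = len(token), None
        while start < end:
            sub = token[start:end] if start == 0 else "##" + token[start:end]
            if sub in vocab:
                cur = sub
                break
            end -= 1
        if cur is None:          # no piece matched: the whole token becomes [UNK]
            return [unk]
        pieces.append(cur)
        start = end
    return pieces

toy_vocab = {"un", "##aff", "##able", "##a", "##ff"}
print(wordpiece("unaffable", toy_vocab))  # -> ['un', '##aff', '##able']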
KTeleBERT/requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers==4.12.2
2
+ tqdm
3
+ # torch is pinned below (>=1.10.0)
4
+ ltp
5
+ ltp-core
6
+ ltp-extension
7
+ cycle
8
+ torch>=1.10.0
9
+ easydict
10
+ # `re` is part of the Python standard library and needs no pip install
KTeleBERT/run.sh ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python -m torch.distributed.launch --nproc_per_node=4 main.py --LLRD 1 \
2
+ --eval_step 10 \
3
+ --save_model 1 \
4
+ --mask_stratege wwm \
5
+ --batch_size 64 \
6
+ --batch_size_ke 64 \
7
+ --exp_name Fine_tune_2 \
8
+ --exp_id v01 \
9
+ --workers 8 \
10
+ --use_NumEmb 1 \
11
+ --seq_data_name Seq_data_RuAlmEntKpiTbwDoc \
12
+ --maxlength 256 \
13
+ --lr 4e-5 \
14
+ --ke_lr 8e-5 \
15
+ --train_strategy 2 \
16
+ --model_name TeleBert2 \
17
+ --train_ratio 1 \
18
+ --save_pretrain 0 \
19
+ --dist 1 \
20
+ --accumulation_steps 8 \
21
+ --accumulation_steps_ke 6 \
22
+ --special_token_mask 0 \
23
+ --freeze_layer 0 \
24
+ --ernie_stratege -1 \
25
+ --mlm_probability_increase curve \
26
+ --use_kpi_loss 1 \
27
+ --mlm_probability 0.4 \
28
+ --use_awl 1 \
29
+ --cls_head_init 1 \
30
+ --emb_init 0 \
31
+ --final_mlm_probability 0.4 \
32
+ --ke_dim 256 \
33
+ --plm_emb_type cls \
34
+ --train_together 0 \
35
+
KTeleBERT/run_get_ref.sh ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python get_chinese_ref.py --batch_size 50 \
2
+ --deal_numeric 1 \
3
+ --seq_data_name Seq_data_large \
4
+ --read_cws 0 \
5
+ # --seq_data_name Seq_data_base \
6
+ # --read_cws 1 \
7
+
8
+ # python get_chinese_ref.py --batch_size 150
9
+
10
+ # python get_chinese_ref.py --batch_size 200
11
+
12
+ # python get_chinese_ref.py --batch_size 250
13
+
14
+ # python get_chinese_ref.py --batch_size 300
15
+
16
+ # python main.py --LLRD 1 \
17
+ # --eval_step 10 \
18
+ # --epoch 20 \
19
+ # --save_model 1 \
20
+ # --mask_stratege wwm \
21
+ # --batch_size 50 \
22
+ # --use_NumEmb 1 \
KTeleBERT/special_token_pre_emb.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.utils import add_special_token
2
+ import os.path as osp
3
+ import numpy as np
4
+ import random
5
+ import torch
6
+ from easydict import EasyDict as edict
7
+ import argparse
8
+ import pdb
9
+ import json
10
+ from model import BertTokenizer
11
+ from collections import Counter
12
+ from tqdm import tqdm
13
+ from time import time
14
+ from numpy import mean
15
+ import math
16
+
17
+ from transformers import BertModel
18
+
19
+
20
+ class cfg():
21
+ def __init__(self):
22
+ self.this_dir = osp.dirname(__file__)
23
+ # change
24
+ self.data_root = osp.abspath(osp.join(self.this_dir, '..', '..', 'data', ''))
25
+
26
+ def get_args(self):
27
+ parser = argparse.ArgumentParser()
28
+ # seq_data_name = "Seq_data_tiny_831"
29
+ parser.add_argument("--data_path", default="huawei", type=str, help="Experiment path")
30
+ parser.add_argument("--update_model_name", default='MacBert', type=str, help="MacBert")
31
+ parser.add_argument("--pretrained_model_name", default='TeleBert', type=str, help="TeleBert")
32
+ parser.add_argument("--read_cws", default=0, type=int, help="whether to read the pre-computed CWS (Chinese word segmentation) file")
33
+ self.cfg = parser.parse_args()
34
+
35
+ def update_train_configs(self):
36
+ # TODO: update some dynamic variable
37
+ self.cfg.data_root = self.data_root
38
+ self.cfg.data_path = osp.join(self.data_root, self.cfg.data_path)
39
+
40
+ return self.cfg
41
+
42
+
43
+ if __name__ == '__main__':
44
+ '''
45
+ Purpose: build the chinese_ref file and refresh the train/test files (sequence text data only)
46
+ '''
47
+ cfg = cfg()
48
+ cfg.get_args()
49
+ cfgs = cfg.update_train_configs()
50
+
51
+ # the tokenizer to be updated, i.e. the one the new tokens are added to
52
+ path = osp.join(cfgs.data_root, 'transformer', cfgs.update_model_name)
53
+ assert osp.exists(path)
54
+ tokenizer = BertTokenizer.from_pretrained(path, do_lower_case=True)
55
+ tokenizer, special_token, norm_token = add_special_token(tokenizer)
56
+ added_vocab = tokenizer.get_added_vocab()
57
+ vocb_path = osp.join(cfgs.data_path, 'added_vocab.json')
58
+
59
+ with open(vocb_path, 'w') as fp:
60
+ json.dump(added_vocab, fp, ensure_ascii=False)
61
+
62
+ vocb_description = osp.join(cfgs.data_path, 'vocab_descrip.json')
63
+ vocb_descrip = None
64
+
65
+ vocb_descrip = {
66
+ "alm": "alarm",
67
+ "ran": "ran 无线接入网",
68
+ "mml": "MML 人机语言命令",
69
+ "nf": "NF 独立网络服务",
70
+ "apn": "APN 接入点名称",
71
+ "pgw": "PGW 数据管理子系统模块",
72
+ "lst": "LST 查询命令",
73
+ "qos": "QoS 定制服务质量",
74
+ "ipv": "IPV 互联网通讯协议版本",
75
+ "ims": "IMS IP多模态子系统",
76
+ "gtp": "GTP GPRS隧道协议",
77
+ "pdp": "PDP 分组数据协议",
78
+ "hss": "HSS HTTP Smooth Stream",
79
+ "[ALM]": "alarm 告警 标记",
80
+ "[KPI]": "kpi 关键性能指标 标记",
81
+ "[LOC]": "location 事件发生位置 标记",
82
+ "[EOS]": "end of the sentence 文档结尾 标记",
83
+ "[ENT]": "实体标记",
84
+ "[ATTR]": "属性标记",
85
+ "[NUM]": "数值标记",
86
+ "[REL]": "关系标记",
87
+ "[DOC]": "文档标记"
88
+ }
89
+
90
+ # if osp.exists(vocb_description):
91
+ # with open(vocb_description, 'r') as fp:
92
+ # vocb_descrip = json.load(added_vocab)
93
+
94
+ # the model used to produce the embeddings
95
+ path = osp.join(cfgs.data_root, 'transformer', cfgs.pretrained_model_name)
96
+ assert osp.exists(path)
97
+ pre_tokenizer = BertTokenizer.from_pretrained(path, do_lower_case=True)
98
+ model = BertModel.from_pretrained(path)
99
+
100
+ print("use the vocb_description")
101
+ key_to_emb = {}
102
+ for key in added_vocab.keys():
103
+ if vocb_description is not None:
104
+ if key in vocb_description:
105
+ # some keys are embedded via their description instead of the raw token
106
+ key_tokens = pre_tokenizer(vocb_description[key], return_tensors='pt')
107
+ else:
108
+ key_tokens = pre_tokenizer(key, return_tensors='pt')
109
+ else:
110
+ key_tokens = pre_tokenizer(key, return_tensors='pt')
111
+
112
+ hidden_state = model(**key_tokens, output_hidden_states=True).hidden_states
113
+ # pdb.set_trace()
114
+ key_to_emb[key] = hidden_state[-1][:, 1:-1, :].mean(dim=1)
115
+
116
+ emb_path = osp.join(cfgs.data_path, 'added_vocab_embedding.pt')
117
+
118
+ torch.save(key_to_emb, emb_path)
119
+ print(f'save to {emb_path}')
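The key step above initializes each added token's embedding as the mean of the last-layer hidden states of its description, excluding the [CLS]/[SEP] positions. A tensor-only sketch of that pooling (the shapes, including the 768 hidden size, are assumptions for the demo; no pretrained model is loaded):

import torch

last_hidden = torch.randn(1, 6, 768)                    # [batch, seq_len, hidden]: [CLS] + 4 description tokens + [SEP]
token_embedding = last_hidden[:, 1:-1, :].mean(dim=1)   # average over the 4 description tokens only
print(token_embedding.shape)                            # torch.Size([1, 768])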
KTeleBERT/src/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
KTeleBERT/src/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (137 Bytes). View file
 
KTeleBERT/src/__pycache__/data.cpython-38.pyc ADDED
Binary file (16 kB). View file
 
KTeleBERT/src/__pycache__/distributed_utils.cpython-38.pyc ADDED
Binary file (2.15 kB). View file
 
KTeleBERT/src/__pycache__/utils.cpython-38.pyc ADDED
Binary file (16 kB). View file
 
KTeleBERT/src/data.py ADDED
@@ -0,0 +1,651 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import random
3
+ import json
4
+ import numpy as np
5
+ import pdb
6
+ import os.path as osp
7
+ from model import BertTokenizer
8
+ import torch.distributed as dist
9
+
10
+
11
+ class SeqDataset(torch.utils.data.Dataset):
12
+ def __init__(self, data, chi_ref=None, kpi_ref=None):
13
+ self.data = data
14
+ self.chi_ref = chi_ref
15
+ self.kpi_ref = kpi_ref
16
+
17
+ def __len__(self):
18
+ return len(self.data)
19
+
20
+ def __getitem__(self, index):
21
+ sample = self.data[index]
22
+ if self.chi_ref is not None:
23
+ chi_ref = self.chi_ref[index]
24
+ else:
25
+ chi_ref = None
26
+
27
+ if self.kpi_ref is not None:
28
+ kpi_ref = self.kpi_ref[index]
29
+ else:
30
+ kpi_ref = None
31
+
32
+ return sample, chi_ref, kpi_ref
33
+
34
+
35
+ class OrderDataset(torch.utils.data.Dataset):
36
+ def __init__(self, data, kpi_ref=None):
37
+ self.data = data
38
+ self.kpi_ref = kpi_ref
39
+
40
+ def __len__(self):
41
+ return len(self.data)
42
+
43
+ def __getitem__(self, index):
44
+ sample = self.data[index]
45
+ if self.kpi_ref is not None:
46
+ kpi_ref = self.kpi_ref[index]
47
+ else:
48
+ kpi_ref = None
49
+
50
+ return sample, kpi_ref
51
+
52
+
53
+ class KGDataset(torch.utils.data.Dataset):
54
+ def __init__(self, data):
55
+ self.data = data
56
+ self.len = len(self.data)
57
+
58
+ def __len__(self):
59
+ return self.len
60
+
61
+ def __getitem__(self, index):
62
+
63
+ sample = self.data[index]
64
+ return sample
65
+
66
+ # TODO: refactor DataCollatorForLanguageModeling
67
+
68
+
69
+ class Collator_base(object):
70
+ # TODO: define the collator, modeled on Lako
71
+ # handles masking and padding
72
+ def __init__(self, args, tokenizer, special_token=None):
73
+ self.tokenizer = tokenizer
74
+ if special_token is None:
75
+ self.special_token = ['[SEP]', '[MASK]', '[ALM]', '[KPI]', '[CLS]', '[LOC]', '[EOS]', '[ENT]', '[ATTR]', '[NUM]', '[REL]', '|', '[DOC]']
76
+ else:
77
+ self.special_token = special_token
78
+
79
+ self.text_maxlength = args.maxlength
80
+ self.mlm_probability = args.mlm_probability
81
+ self.args = args
82
+ if self.args.special_token_mask:
83
+ self.special_token = ['|', '[NUM]']
84
+
85
+ if not self.args.only_test and self.args.use_mlm_task:
86
+ if args.mask_stratege == 'rand':
87
+ self.mask_func = self.torch_mask_tokens
88
+ else:
89
+ if args.mask_stratege == 'wwm':
90
+ # special_word must be enabled, because wwm here relies on word segmentation
91
+ if args.rank == 0:
92
+ print("use word-level Mask ...")
93
+ assert args.add_special_word == 1
94
+ self.mask_func = self.wwm_mask_tokens
95
+ else: # domain
96
+ if args.rank == 0:
97
+ print("use token-level Mask ...")
98
+ self.mask_func = self.domain_mask_tokens
99
+
100
+ def __call__(self, batch):
101
+ # pull the numeric values out of the batch and replace them with a special token
102
+ # the numeric values and their positions are passed in separately as lists
103
+ # during training the numeric values are inserted directly at the embedding positions
104
+ # numeric values never take part in masking
105
+ # for wwm, the chinese_ref can be fed in together with the batch
106
+ kpi_ref = None
107
+ if self.args.use_NumEmb:
108
+ kpi_ref = [item[2] for item in batch]
109
+ # if self.args.mask_stratege != 'rand':
110
+ chinese_ref = [item[1] for item in batch]
111
+ batch = [item[0] for item in batch]
112
+ # at this point the batch contains more than plain strings
113
+ batch = self.tokenizer.batch_encode_plus(
114
+ batch,
115
+ padding='max_length',
116
+ max_length=self.text_maxlength,
117
+ truncation=True,
118
+ return_tensors="pt",
119
+ return_token_type_ids=False,
120
+ return_attention_mask=True,
121
+ add_special_tokens=False
122
+ )
123
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
124
+ # self.torch_mask_tokens
125
+
126
+ # if batch["input_ids"].shape[1] != 128:
127
+ # pdb.set_trace()
128
+ if chinese_ref is not None:
129
+ batch["chinese_ref"] = chinese_ref
130
+ if kpi_ref is not None:
131
+ batch["kpi_ref"] = kpi_ref
132
+
133
+ # 训练需要 mask
134
+
135
+ if not self.args.only_test and self.args.use_mlm_task:
136
+ batch["input_ids"], batch["labels"] = self.mask_func(
137
+ batch, special_tokens_mask=special_tokens_mask
138
+ )
139
+ else:
140
+ # not in training mode
141
+ # and MLM is not used for training
142
+ labels = batch["input_ids"].clone()
143
+ if self.tokenizer.pad_token_id is not None:
144
+ labels[labels == self.tokenizer.pad_token_id] = -100
145
+ batch["labels"] = labels
146
+
147
+ return batch
148
+
149
+ def torch_mask_tokens(self, inputs, special_tokens_mask=None):
150
+ """
151
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
152
+ """
153
+ if "input_ids" in inputs:
154
+ inputs = inputs["input_ids"]
155
+ labels = inputs.clone()
156
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
157
+ probability_matrix = torch.full(labels.shape, self.mlm_probability)
158
+ if special_tokens_mask is None:
159
+ special_tokens_mask = [
160
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
161
+ ]
162
+ special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
163
+ else:
164
+ special_tokens_mask = special_tokens_mask.bool()
165
+ # pdb.set_trace()
166
+
167
+ probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
168
+ masked_indices = torch.bernoulli(probability_matrix).bool()
169
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
170
+
171
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
172
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
173
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
174
+
175
+ # 10% of the time, we replace masked input tokens with random word
176
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
177
+ random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
178
+ inputs[indices_random] = random_words[indices_random]
179
+
180
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
181
+ return inputs, labels
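A self-contained sketch of the same 80% / 10% / 10% scheme on toy ids; the vocabulary size, mask id and the 15% base rate are assumptions for the demo, not values from this project's config.

import torch

vocab_size, mask_id, mlm_p = 1000, 103, 0.15
inputs = torch.randint(5, vocab_size, (2, 8))
labels = inputs.clone()

masked = torch.bernoulli(torch.full(labels.shape, mlm_p)).bool()
labels[~masked] = -100                                            # loss only on masked positions

replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked
inputs[replaced] = mask_id                                        # 80%: [MASK]

random_idx = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked & ~replaced
inputs[random_idx] = torch.randint(vocab_size, labels.shape)[random_idx]  # 10%: random token
# the remaining ~10% of masked positions keep their original token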
182
+
183
+ def wwm_mask_tokens(self, inputs, special_tokens_mask=None):
184
+ mask_labels = []
185
+ ref_tokens = inputs["chinese_ref"]
186
+ input_ids = inputs["input_ids"]
187
+ sz = len(input_ids)
188
+
189
+ # 把input id 先恢复到token
190
+ for i in range(sz):
191
+ # 这里的主体是读入的ref,但是可能存在max_len不统一的情况
192
+ mask_labels.append(self._whole_word_mask(ref_tokens[i]))
193
+
194
+ batch_mask = _torch_collate_batch(mask_labels, self.tokenizer, self.text_maxlength, pad_to_multiple_of=None)
195
+ inputs, labels = self.torch_mask_tokens_4wwm(input_ids, batch_mask)
196
+ return inputs, labels
197
+
198
+ # input_tokens: List[str]
199
+ def _whole_word_mask(self, input_tokens, max_predictions=512):
200
+ """
201
+ Get 0/1 labels for masked tokens with whole word mask proxy
202
+ """
203
+ assert isinstance(self.tokenizer, (BertTokenizer))
204
+ # the input comes in [..., ..., ..., ...] form
205
+ cand_indexes = []
206
+ cand_token = []
207
+
208
+ for i, token in enumerate(input_tokens):
209
+ if i >= self.text_maxlength - 1:
210
+ # 不能超过最大值,截断一下
211
+ break
212
+ if token.lower() in self.special_token:
213
+ # special token 的词不应该被mask
214
+ continue
215
+ if len(cand_indexes) >= 1 and token.startswith("##"):
216
+ cand_indexes[-1].append(i)
217
+ cand_token.append(i)
218
+ else:
219
+ cand_indexes.append([i])
220
+ cand_token.append(i)
221
+
222
+ random.shuffle(cand_indexes)
223
+ # originally this was computed over input_tokens
224
+ # but there are many special tokens here, so they were removed beforehand
225
+ # the 15% is therefore taken over the non-special tokens; +2 adds back the CLS and SEP flags
226
+ num_to_predict = min(max_predictions, max(1, int(round((len(cand_token) + 2) * self.mlm_probability))))
227
+ masked_lms = []
228
+ covered_indexes = set()
229
+ for index_set in cand_indexes:
230
+ # 到达长度了结束
231
+ if len(masked_lms) >= num_to_predict:
232
+ break
233
+ # If adding a whole-word mask would exceed the maximum number of
234
+ # predictions, then just skip this candidate.
235
+ # 不能让其长度大于15%,最多等于
236
+ if len(masked_lms) + len(index_set) > num_to_predict:
237
+ continue
238
+ is_any_index_covered = False
239
+ for index in index_set:
240
+ # 不考虑重叠的token进行mask
241
+ if index in covered_indexes:
242
+ is_any_index_covered = True
243
+ break
244
+ if is_any_index_covered:
245
+ continue
246
+ for index in index_set:
247
+ covered_indexes.add(index)
248
+ masked_lms.append(index)
249
+
250
+ if len(covered_indexes) != len(masked_lms):
251
+ # should not normally happen, since duplicates are avoided above
252
+ raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")
253
+ # must not exceed the maximum length, so truncate
254
+ mask_labels = [1 if i in covered_indexes else 0 for i in range(min(len(input_tokens), self.text_maxlength))]
255
+
256
+ return mask_labels
257
+
258
+ # decide which positions to mask: set 0/1
259
+
260
+ # then call self.torch_mask_tokens
261
+
262
+ #
263
+ pass
264
+
265
+ def torch_mask_tokens_4wwm(self, inputs, mask_labels):
266
+ """
267
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set
268
+ 'mask_labels' means we use whole word mask (wwm), we directly mask idxs according to it's ref.
269
+ """
270
+ # if "input_ids" in inputs:
271
+ # inputs = inputs["input_ids"]
272
+ if self.tokenizer.mask_token is None:
273
+ raise ValueError(
274
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
275
+ " --mlm flag if you want to use this tokenizer."
276
+ )
277
+ labels = inputs.clone()
278
+ # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
279
+
280
+ probability_matrix = mask_labels
281
+
282
+ special_tokens_mask = [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]
283
+
284
+ if len(special_tokens_mask[0]) != probability_matrix.shape[1]:
285
+ print(f"len(special_tokens_mask[0]): {len(special_tokens_mask[0])}")
286
+ print(f"probability_matrix.shape[1]): {probability_matrix.shape[1]}")
287
+ print(f'max len {self.text_maxlength}')
288
+ print(f"pad_token_id: {self.tokenizer.pad_token_id}")
289
+ # if self.args.rank != in_rank:
290
+ if self.args.dist:
291
+ dist.barrier()
292
+ pdb.set_trace()
293
+ else:
294
+ pdb.set_trace()
295
+
296
+ probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
297
+ if self.tokenizer._pad_token is not None:
298
+ padding_mask = labels.eq(self.tokenizer.pad_token_id)
299
+ probability_matrix.masked_fill_(padding_mask, value=0.0)
300
+
301
+ masked_indices = probability_matrix.bool()
302
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
303
+
304
+ # with wwm here, the mask/replace/keep decision is not applied to a whole word as one unit, so words can be split
305
+ # not ideal, but there is no better option here
306
+
307
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
308
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
309
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
310
+
311
+ # 10% of the time, we replace masked input tokens with random word
312
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
313
+ random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
314
+ inputs[indices_random] = random_words[indices_random]
315
+
316
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
317
+ return inputs, labels
318
+
319
+ # TODO: mask by region / cell
320
+
321
+ def domain_mask_tokens(self, inputs, special_tokens_mask=None):
322
+ pass
323
+
324
+
325
+ class Collator_kg(object):
326
+ # TODO: define the collator, modeled on Lako
327
+ # randomly drops part of the attributes
328
+ def __init__(self, args, tokenizer, data):
329
+ self.tokenizer = tokenizer
330
+ self.text_maxlength = args.maxlength
331
+ self.cross_sampling_flag = 0
332
+ # the KE batch size is a quarter of the normal batch size
333
+ self.neg_num = args.neg_num
334
+ # negative samples must not appear in the full triple set
335
+ self.data = data
336
+ self.args = args
337
+
338
+ def __call__(self, batch):
339
+ # encode into token form first to avoid repeated encoding
340
+ outputs = self.sampling(batch)
341
+
342
+ return outputs
343
+
344
+ def sampling(self, data):
345
+ """Filtering out positive samples and selecting some samples randomly as negative samples.
346
+
347
+ Args:
348
+ data: The triples used to be sampled.
349
+
350
+ Returns:
351
+ batch_data: The training data.
352
+ """
353
+ batch_data = {}
354
+ neg_ent_sample = []
355
+
356
+ self.cross_sampling_flag = 1 - self.cross_sampling_flag
357
+
358
+ head_list = []
359
+ rel_list = []
360
+ tail_list = []
361
+ # pdb.set_trace()
362
+ if self.cross_sampling_flag == 0:
363
+ batch_data['mode'] = "head-batch"
364
+ for index, (head, relation, tail) in enumerate(data):
365
+ # in batch negative
366
+ neg_head = self.find_neghead(data, index, relation, tail)
367
+ neg_ent_sample.extend(random.sample(neg_head, self.neg_num))
368
+ head_list.append(head)
369
+ rel_list.append(relation)
370
+ tail_list.append(tail)
371
+ else:
372
+ batch_data['mode'] = "tail-batch"
373
+ for index, (head, relation, tail) in enumerate(data):
374
+ neg_tail = self.find_negtail(data, index, relation, head)
375
+ neg_ent_sample.extend(random.sample(neg_tail, self.neg_num))
376
+
377
+ head_list.append(head)
378
+ rel_list.append(relation)
379
+ tail_list.append(tail)
380
+
381
+ neg_ent_batch = self.batch_tokenizer(neg_ent_sample)
382
+ head_batch = self.batch_tokenizer(head_list)
383
+ rel_batch = self.batch_tokenizer(rel_list)
384
+ tail_batch = self.batch_tokenizer(tail_list)
385
+
386
+ ent_list = head_list + rel_list + tail_list
387
+ ent_dict = {k: v for v, k in enumerate(ent_list)}
388
+ # 用来索引负样本
389
+ neg_index = torch.tensor([ent_dict[i] for i in neg_ent_sample])
390
+ # pos_head_index = torch.tensor(list(range(len(head_list)))
391
+
392
+ batch_data["positive_sample"] = (head_batch, rel_batch, tail_batch)
393
+ batch_data['negative_sample'] = neg_ent_batch
394
+ batch_data['neg_index'] = neg_index
395
+ return batch_data
396
+
397
+ def batch_tokenizer(self, input_list):
398
+ return self.tokenizer.batch_encode_plus(
399
+ input_list,
400
+ padding='max_length',
401
+ max_length=self.text_maxlength,
402
+ truncation=True,
403
+ return_tensors="pt",
404
+ return_token_type_ids=False,
405
+ return_attention_mask=True,
406
+ add_special_tokens=False
407
+ )
408
+
409
+ def find_neghead(self, data, index, rel, ta):
410
+ head_list = []
411
+ for i, (head, relation, tail) in enumerate(data):
412
+ # 负样本不能被包含
413
+ if i != index and [head, rel, ta] not in self.data:
414
+ head_list.append(head)
415
+ # 可能存在负样本不够的情况
416
+ # 自补齐
417
+ while len(head_list) < self.neg_num:
418
+ head_list.extend(random.sample(head_list, min(self.neg_num - len(head_list), len(head_list))))
419
+
420
+ return head_list
421
+
422
+ def find_negtail(self, data, index, rel, he):
423
+ tail_list = []
424
+ for i, (head, relation, tail) in enumerate(data):
425
+ if i != index and [he, rel, tail] not in self.data:
426
+ tail_list.append(tail)
427
+ # 可能存在负样本不够的情况
428
+ # 自补齐
429
+ while len(tail_list) < self.neg_num:
430
+ tail_list.extend(random.sample(tail_list, min(self.neg_num - len(tail_list), len(tail_list))))
431
+ return tail_list
432
+
433
+ # loading the data for the mask-loss part
434
+
435
+
436
+ def load_data(logger, args):
437
+
438
+ data_path = args.data_path
439
+
440
+ data_name = args.seq_data_name
441
+ with open(osp.join(data_path, f'{data_name}_cws.json'), "r") as fp:
442
+ data = json.load(fp)
443
+ if args.rank == 0:
444
+ logger.info(f"[Start] Loading Seq dataset: [{len(data)}]...")
445
+ random.shuffle(data)
446
+
447
+ # data = data[:10000]
448
+ # pdb.set_trace()
449
+ train_test_split = int(args.train_ratio * len(data))
450
+ # random.shuffle(x)
451
+ # 训练/测试期间不应该打乱
452
+ train_data = data[0: train_test_split]
453
+ test_data = data[train_test_split: len(data)]
454
+
455
+ # 测试的时候也可能用到其实 not args.only_test
456
+ if args.use_mlm_task:
457
+ # if args.mask_stratege != 'rand':
458
+ # load the domain vocabulary
459
+ if args.rank == 0:
460
+ print("using the domain words .....")
461
+ domain_file_path = osp.join(args.data_path, f'{data_name}_chinese_ref.json')
462
+ with open(domain_file_path, 'r') as f:
463
+ chinese_ref = json.load(f)
464
+ # train_test_split=len(data)
465
+ chi_ref_train = chinese_ref[:train_test_split]
466
+ chi_ref_eval = chinese_ref[train_test_split:]
467
+ else:
468
+ chi_ref_train = None
469
+ chi_ref_eval = None
470
+
471
+ if args.use_NumEmb:
472
+ if args.rank == 0:
473
+ print("using the kpi and num .....")
474
+
475
+ kpi_file_path = osp.join(args.data_path, f'{data_name}_kpi_ref.json')
476
+ with open(kpi_file_path, 'r') as f:
477
+ kpi_ref = json.load(f)
478
+ kpi_ref_train = kpi_ref[:train_test_split]
479
+ kpi_ref_eval = kpi_ref[train_test_split:]
480
+ else:
481
+ # num_ref_train = None
482
+ # num_ref_eval = None
483
+ kpi_ref_train = None
484
+ kpi_ref_eval = None
485
+
486
+ # pdb.set_trace()
487
+ test_set = None
488
+ train_set = SeqDataset(train_data, chi_ref=chi_ref_train, kpi_ref=kpi_ref_train)
489
+ if len(test_data) > 0:
490
+ test_set = SeqDataset(test_data, chi_ref=chi_ref_eval, kpi_ref=kpi_ref_eval)
491
+ if args.rank == 0:
492
+ logger.info("[End] Loading Seq dataset...")
493
+ return train_set, test_set, train_test_split
494
+
495
+ # loading the data for the triple-loss part
496
+
497
+
498
+ def load_data_kg(logger, args):
499
+ data_path = args.data_path
500
+ if args.rank == 0:
501
+ logger.info("[Start] Loading KG dataset...")
502
+ # # 三元组
503
+ # with open(osp.join(data_path, '5GC_KB/database_triples_831.json'), "r") as f:
504
+ # data = json.load(f)
505
+ # random.shuffle(data)
506
+
507
+ # # # TODO: triple loss这一块还没有测试集
508
+ # train_data = data[0:int(len(data)/args.batch_size)*args.batch_size]
509
+
510
+ # with open(osp.join(data_path, 'KG_data_tiny_831.json'),"w") as fp:
511
+ # json.dump(data[:1000], fp)
512
+ kg_data_name = args.kg_data_name
513
+ with open(osp.join(data_path, f'{kg_data_name}.json'), "r") as fp:
514
+ train_data = json.load(fp)
515
+ # pdb.set_trace()
516
+ # 124169
517
+ # 128482
518
+ # train_data = train_data[:124168]
519
+ # train_data = train_data[:1000]
520
+ train_set = KGDataset(train_data)
521
+ if args.rank == 0:
522
+ logger.info("[End] Loading KG dataset...")
523
+ return train_set, train_data
524
+
525
+
526
+ def _torch_collate_batch(examples, tokenizer, max_length=None, pad_to_multiple_of=None):
527
+ """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
528
+ import numpy as np
529
+ import torch
530
+
531
+ # Tensorize if necessary.
532
+ if isinstance(examples[0], (list, tuple, np.ndarray)):
533
+ examples = [torch.tensor(e, dtype=torch.long) for e in examples]
534
+
535
+ length_of_first = examples[0].size(0)
536
+
537
+ # Check if padding is necessary.
538
+
539
+ # are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
540
+ # if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
541
+ # return torch.stack(examples, dim=0)
542
+
543
+ # If yes, check if we have a `pad_token`.
544
+ if tokenizer._pad_token is None:
545
+ raise ValueError(
546
+ "You are attempting to pad samples but the tokenizer you are using"
547
+ f" ({tokenizer.__class__.__name__}) does not have a pad token."
548
+ )
549
+
550
+ # Creating the full tensor and filling it with our data.
551
+
552
+ if max_length is None:
553
+ pdb.set_trace()
554
+ max_length = max(x.size(0) for x in examples)
555
+
556
+ if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
557
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
558
+ result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
559
+ for i, example in enumerate(examples):
560
+ if tokenizer.padding_side == "right":
561
+ result[i, : example.shape[0]] = example
562
+ else:
563
+ result[i, -example.shape[0]:] = example
564
+
565
+ return result
566
+
567
+
568
+ def load_order_data(logger, args):
569
+ if args.rank == 0:
570
+ logger.info("[Start] Loading Order dataset...")
571
+
572
+ data_path = args.data_path
573
+ if len(args.order_test_name) > 0:
574
+ data_name = args.order_test_name
575
+ else:
576
+ data_name = args.order_data_name
577
+ tmp = osp.join(data_path, f'{data_name}.json')
578
+ if osp.exists(tmp):
579
+ dp = tmp
580
+ else:
581
+ dp = osp.join(data_path, 'downstream_task', f'{data_name}.json')
582
+ assert osp.exists(dp)
583
+ with open(dp, "r") as fp:
584
+ data = json.load(fp)
585
+ # data = data[:2000]
586
+ # pdb.set_trace()
587
+ train_test_split = int(args.train_ratio * len(data))
588
+
589
+ mid_split = int(train_test_split / 2)
590
+ mid = int(len(data) / 2)
591
+ # random.shuffle(x)
592
+ # 训练/测试期间不应该打乱
593
+ # train_data = data[0: train_test_split]
594
+ # test_data = data[train_test_split: len(data)]
595
+
596
+ # test_data = data[0: train_test_split]
597
+ # train_data = data[train_test_split: len(data)]
598
+
599
+ # special split: by default the first half and the second half are symmetric
600
+ test_data = data[0: mid_split] + data[mid: mid + mid_split]
601
+ train_data = data[mid_split: mid] + data[mid + mid_split: len(data)]
602
+
603
+ # pdb.set_trace()
604
+ test_set = None
605
+ train_set = OrderDataset(train_data)
606
+ if len(test_data) > 0:
607
+ test_set = OrderDataset(test_data)
608
+ if args.rank == 0:
609
+ logger.info("[End] Loading Order dataset...")
610
+ return train_set, test_set, train_test_split
611
+
612
+
613
+ class Collator_order(object):
614
+ # takes a batch of data; the order items are merged here and decoupled again later
615
+ def __init__(self, args, tokenizer):
616
+ self.tokenizer = tokenizer
617
+ self.text_maxlength = args.maxlength
618
+ self.args = args
619
+ # number of items contained in each pair
620
+ self.order_num = args.order_num
621
+ self.p_label, self.n_label = smooth_BCE(args.eps)
622
+
623
+ def __call__(self, batch):
624
+ # the input items are stacked in order and split apart again at fixed intervals
625
+ #
626
+ # encode, then return
627
+ output = []
628
+ for item in range(self.order_num):
629
+ output.extend([dat[0][0][item] for dat in batch])
630
+ # label smoothing
631
+
632
+ labels = [1 if dat[0][1][0] == 2 else self.p_label if dat[0][1][0] == 1 else self.n_label for dat in batch]
633
+ batch = self.tokenizer.batch_encode_plus(
634
+ output,
635
+ padding='max_length',
636
+ max_length=self.text_maxlength,
637
+ truncation=True,
638
+ return_tensors="pt",
639
+ return_token_type_ids=False,
640
+ return_attention_mask=True,
641
+ add_special_tokens=False
642
+ )
643
+ # torch.tensor()
644
+ return batch, torch.FloatTensor(labels)
645
+
646
+
647
+ def smooth_BCE(eps=0.1):  # eps is the smoothing factor: labels [1, 0] => [0.95, 0.05]
648
+ # return positive, negative label smoothing BCE targets
649
+ # positive label= y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing
650
+ # y_true=1 label_smoothing=eps=0.1
651
+ return 1.0 - 0.5 * eps, 0.5 * eps
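A quick usage check of the helper above (run within this module): with the default smoothing factor the hard labels 1/0 become soft targets 0.95/0.05.

p_label, n_label = smooth_BCE(0.1)
print(p_label, n_label)  # -> 0.95 0.05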
KTeleBERT/src/distributed_utils.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import torch.distributed as dist
5
+ import pdb
6
+
7
+
8
+ def dist_pdb(rank, in_rank=0):
9
+ if rank != in_rank:
10
+ dist.barrier()
11
+ else:
12
+ pdb.set_trace()
13
+ dist.barrier()
14
+
15
+
16
+ def init_distributed_mode(args):
17
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
18
+ args.rank = int(os.environ["RANK"])
19
+ args.world_size = int(os.environ['WORLD_SIZE'])
20
+ args.gpu = int(os.environ['LOCAL_RANK'])
21
+ elif 'SLURM_PROCID' in os.environ:
22
+ args.rank = int(os.environ['SLURM_PROCID'])
23
+ args.gpu = args.rank % torch.cuda.device_count()
24
+ else:
25
+ print('Not using distributed mode')
26
+ args.distributed = False
27
+ return
28
+
29
+ args.distributed = True
30
+
31
+ torch.cuda.set_device(args.gpu)
32
+ args.dist_backend = 'nccl' # communication backend; NCCL is recommended for NVIDIA GPUs
33
+ print('| distributed init (rank {}): {}'.format(
34
+ args.rank, args.dist_url), flush=True)
35
+ dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
36
+ world_size=args.world_size, rank=args.rank)
37
+ dist.barrier()
38
+
39
+
40
+ def cleanup():
41
+ dist.destroy_process_group()
42
+
43
+
44
+ def is_dist_avail_and_initialized():
45
+ """检查是否支持分布式环境"""
46
+ if not dist.is_available():
47
+ return False
48
+ if not dist.is_initialized():
49
+ return False
50
+ return True
51
+
52
+
53
+ def get_world_size():
54
+ if not is_dist_avail_and_initialized():
55
+ return 1
56
+ return dist.get_world_size()
57
+
58
+
59
+ def get_rank():
60
+ if not is_dist_avail_and_initialized():
61
+ return 0
62
+ return dist.get_rank()
63
+
64
+
65
+ def is_main_process():
66
+ return get_rank() == 0
67
+
68
+
69
+ def reduce_value(value, average=True):
70
+ world_size = get_world_size()
71
+ if world_size < 2: # single-GPU case
72
+ return value
73
+
74
+ with torch.no_grad():
75
+ dist.all_reduce(value)
76
+ if average:
77
+ value /= world_size
78
+
79
+ return value
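reduce_value is a no-op on a single process and averages the tensor across ranks when run under a launcher such as the torch.distributed.launch command used in run.sh. A minimal usage sketch, assuming the repository root is on PYTHONPATH so the module imports as src.distributed_utils:

import torch
from src.distributed_utils import reduce_value

loss = torch.tensor(2.0)
# Without torch.distributed initialised, world_size is 1 and the value comes back unchanged;
# under a 4-process launch it would be the mean of the per-rank values.
print(reduce_value(loss, average=True))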
KTeleBERT/src/utils.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import errno
4
+ import torch
5
+ import sys
6
+ import logging
7
+ import json
8
+ from pathlib import Path
9
+ import torch.distributed as dist
10
+ import csv
11
+ import os.path as osp
12
+ from time import time
13
+ from numpy import mean
14
+ import re
15
+ from transformers import AdamW, get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
16
+ import pdb
17
+ from torch import nn
18
+
19
+
20
+
21
+ # Hugging Face's implementation ships several warmup strategies
22
+ def set_optim(opt, model_list, freeze_part=[], accumulation_step=None):
23
+ # Bert optim
24
+ optimizer_list, scheduler_list, named_parameters = [], [], []
25
+ # cur_model = model.module if hasattr(model, 'module') else model
26
+ for model in model_list:
27
+ model_para = list(model.named_parameters())
28
+ model_para_train, freeze_layer = [], []
29
+ for n, p in model_para:
30
+ if not any(nd in n for nd in freeze_part):
31
+ model_para_train.append((n, p))
32
+ else:
33
+ p.requires_grad = False
34
+ freeze_layer.append((n, p))
35
+ named_parameters.extend(model_para_train)
36
+
37
+ # for name, param in model_list[0].named_parameters():
38
+ # if not param.requires_grad:
39
+ # print(name, param.size())
40
+
41
+ no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
42
+ # numeric_model 也包括到这个部分中
43
+ ke_part = ['ke_model', 'loss_awl', 'numeric_model', 'order']
44
+ if opt.LLRD:
45
+ # 按层次衰减的学习率
46
+ all_name_orig = [n for n, p in named_parameters if not any(nd in n for nd in ke_part)]
47
+
48
+ opt_parameters, all_name = LLRD(opt, named_parameters, no_decay, ke_part)
49
+ remain = list(set(all_name_orig) - set(all_name))
50
+ remain_parameters = [
51
+ {'params': [p for n, p in named_parameters if not any(nd in n for nd in no_decay) and n in remain], "lr": opt.lr, 'weight_decay': opt.weight_decay},
52
+ {'params': [p for n, p in named_parameters if any(nd in n for nd in no_decay) and n in remain], "lr": opt.lr, 'weight_decay': 0.0}
53
+ ]
54
+ opt_parameters.extend(remain_parameters)
55
+ else:
56
+ opt_parameters = [
57
+ {'params': [p for n, p in named_parameters if not any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)], "lr": opt.lr, 'weight_decay': opt.weight_decay},
58
+ {'params': [p for n, p in named_parameters if any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)], "lr": opt.lr, 'weight_decay': 0.0}
59
+ ]
60
+
61
+ ke_parameters = [
62
+ {'params': [p for n, p in named_parameters if not any(nd in n for nd in no_decay) and any(nd in n for nd in ke_part)], "lr": opt.ke_lr, 'weight_decay': opt.weight_decay},
63
+ {'params': [p for n, p in named_parameters if any(nd in n for nd in no_decay) and any(nd in n for nd in ke_part)], "lr": opt.ke_lr, 'weight_decay': 0.0}
64
+ ]
65
+ opt_parameters.extend(ke_parameters)
66
+ optimizer = AdamW(opt_parameters, lr=opt.lr, eps=opt.adam_epsilon)
67
+ if accumulation_step is None:
68
+ accumulation_step = opt.accumulation_steps
69
+ if opt.scheduler == 'linear':
70
+ scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(opt.warmup_steps/accumulation_step), num_training_steps=int(opt.total_steps/accumulation_step))
71
+ else:
72
+ scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=int(opt.warmup_steps/accumulation_step), num_training_steps=int(opt.total_steps/accumulation_step))
73
+
74
+ # ---- 判定所有参数是否被全部优化 ----
75
+ all_para_num = 0
76
+ for paras in opt_parameters:
77
+ all_para_num += len(paras['params'])
78
+ # pdb.set_trace()
79
+ assert len(named_parameters) == all_para_num
80
+ return optimizer, scheduler
81
+
82
+ # LLRD 学习率逐层衰减但
83
+
84
+ def LLRD(opt, named_parameters, no_decay, ke_part =[]):
85
+ opt_parameters = []
86
+ all_name = []
87
+ head_lr = opt.lr * 1.05
88
+ init_lr = opt.lr
89
+ lr = init_lr
90
+
91
+ # === Pooler and regressor ======================================================
92
+ params_0 = [p for n,p in named_parameters if ("pooler" in n or "regressor" in n or "predictions" in n)
93
+ and any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
94
+ params_1 = [p for n,p in named_parameters if ("pooler" in n or "regressor" in n or "predictions" in n)
95
+ and not any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
96
+
97
+ name_0 = [n for n,p in named_parameters if ("pooler" in n or "regressor" in n or "predictions" in n)
98
+ and any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
99
+ name_1 = [n for n,p in named_parameters if ("pooler" in n or "regressor" in n or "predictions" in n)
100
+ and not any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
101
+
102
+ all_name.extend(name_0)
103
+ all_name.extend(name_1)
104
+
105
+ head_params = {"params": params_0, "lr": head_lr, "weight_decay": 0.0}
106
+ opt_parameters.append(head_params)
107
+
108
+ head_params = {"params": params_1, "lr": head_lr, "weight_decay": 0.01}
109
+ opt_parameters.append(head_params)
110
+
111
+ # === 12 Hidden layers ==========================================================
112
+ for layer in range(11,-1,-1):
113
+ params_0 = [p for n,p in named_parameters if f"encoder.layer.{layer}." in n
114
+ and any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
115
+ params_1 = [p for n,p in named_parameters if f"encoder.layer.{layer}." in n
116
+ and not any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
117
+
118
+ layer_params = {"params": params_0, "lr": lr, "weight_decay": 0.0}
119
+ opt_parameters.append(layer_params)
120
+
121
+ layer_params = {"params": params_1, "lr": lr, "weight_decay": 0.01}
122
+ opt_parameters.append(layer_params)
123
+
124
+ name_0 = [n for n,p in named_parameters if f"encoder.layer.{layer}." in n
125
+ and any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
126
+ name_1 = [n for n,p in named_parameters if f"encoder.layer.{layer}." in n
127
+ and not any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
128
+ all_name.extend(name_0)
129
+ all_name.extend(name_1)
130
+
131
+ lr *= 0.95
132
+ # === Embeddings layer ==========================================================
133
+
134
+ params_0 = [p for n,p in named_parameters if ("embeddings" in n )
135
+ and any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
136
+ params_1 = [p for n,p in named_parameters if ("embeddings" in n )
137
+ and not any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
138
+
139
+ embed_params = {"params": params_0, "lr": lr, "weight_decay": 0.0}
140
+ opt_parameters.append(embed_params)
141
+
142
+ embed_params = {"params": params_1, "lr": lr, "weight_decay": 0.01}
143
+ opt_parameters.append(embed_params)
144
+
145
+ name_0 = [n for n,p in named_parameters if ("embeddings" in n )
146
+ and any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
147
+ name_1 = [n for n,p in named_parameters if ("embeddings" in n )
148
+ and not any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
149
+ all_name.extend(name_0)
150
+ all_name.extend(name_1)
151
+ return opt_parameters, all_name
152
+
153
+ class FixedScheduler(torch.optim.lr_scheduler.LambdaLR):
154
+ def __init__(self, optimizer, last_epoch=-1):
155
+ super(FixedScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
156
+
157
+ def lr_lambda(self, step):
158
+ return 1.0
159
+
160
+
161
+ class WarmupLinearScheduler(torch.optim.lr_scheduler.LambdaLR):
162
+ def __init__(self, optimizer, warmup_steps, scheduler_steps, min_ratio, last_epoch=-1):
163
+ self.warmup_steps = warmup_steps
164
+ self.scheduler_steps = scheduler_steps
165
+ self.min_ratio = min_ratio
166
+ # self.fixed_lr = fixed_lr
167
+ super(WarmupLinearScheduler, self).__init__(
168
+ optimizer, self.lr_lambda, last_epoch=last_epoch
169
+ )
170
+
171
+ def lr_lambda(self, step):
172
+ if step < self.warmup_steps:
173
+ return (1 - self.min_ratio) * step / float(max(1, self.warmup_steps)) + self.min_ratio
174
+
175
+ # if self.fixed_lr:
176
+ # return 1.0
177
+
178
+ return max(0.0,
179
+ 1.0 + (self.min_ratio - 1) * (step - self.warmup_steps) / float(max(1.0, self.scheduler_steps - self.warmup_steps)),
180
+ )
181
+
182
+
183
+ class Loss_log():
184
+ def __init__(self):
185
+ self.loss = []
186
+ self.acc = [0.]
187
+ self.flag = 0
188
+ self.token_right_num = []
189
+ self.token_all_num = []
190
+ self.word_right_num = []
191
+ self.word_all_num = []
192
+ # 默认不使用top_k acc
193
+ self.use_top_k_acc = 0
194
+
195
+ def acc_init(self, topn=[1]):
196
+ self.loss = []
197
+ self.token_right_num = []
198
+ self.token_all_num = []
199
+ self.topn = topn
200
+ self.use_top_k_acc = 1
201
+ self.top_k_word_right = {}
202
+ for n in topn:
203
+ self.top_k_word_right[n] = []
204
+
205
+ def time_init(self):
206
+ self.start = time()
207
+ self.last = self.start
208
+ self.time_used_epoch = []
209
+
210
+ def time_cpt(self, step, total_step):
211
+ # 时间统计
212
+ time_used_last_epoch = time() - self.last
213
+ self.time_used_epoch.append(time_used_last_epoch)
214
+ time_used = time() - self.start
215
+ self.last = time()
216
+ h, m, s = time_trans(time_used)
217
+ time_remain = int(total_step - step) * mean(self.time_used_epoch)
218
+ h_r, m_r, s_r = time_trans(time_remain)
219
+
220
+ return h, m, s, h_r, m_r, s_r
221
+
222
+ def get_token_acc(self):
223
+ # 返回list
224
+ if len(self.token_all_num) == 0:
225
+ return 0.
226
+ elif self.use_top_k_acc == 1:
227
+ res = []
228
+ for n in self.topn:
229
+ res.append(round((sum(self.top_k_word_right[n]) / sum(self.token_all_num)) * 100 , 3))
230
+ return res
231
+ else:
232
+ return [sum(self.token_right_num)/sum(self.token_all_num)]
233
+
234
+
235
+ def update_token(self, token_num, token_right):
236
+ # 输入是list文件
237
+ self.token_all_num.append(token_num)
238
+ if isinstance(token_right, list):
239
+ for i, n in enumerate(self.topn):
240
+ self.top_k_word_right[n].append(token_right[i])
241
+ self.token_right_num.append(token_right)
242
+
243
+ def update(self, case):
244
+ self.loss.append(case)
245
+
246
+ def update_acc(self, case):
247
+ self.acc.append(case)
248
+
249
+ def get_loss(self):
250
+ if len(self.loss) == 0:
251
+ return 500.
252
+ return mean(self.loss)
253
+
254
+ def get_acc(self):
255
+ return self.acc[-1]
256
+
257
+ def get_min_loss(self):
258
+ return min(self.loss)
259
+
260
+ def early_stop(self):
261
+ # min_loss = min(self.loss)
262
+ if self.loss[-1] > min(self.loss):
263
+ self.flag += 1
264
+ else:
265
+ self.flag = 0
266
+
267
+ if self.flag > 1000:
268
+ return True
269
+ else:
270
+ return False
271
+
272
+
273
+ def add_special_token(tokenizer, model=None, rank=0, cache_path = None):
274
+ # model: bert layer
275
+ # every time this is updated, all models must be retrained and get_chinese_ref.py must be re-run
276
+ # the main function must call this before the model is loaded
277
+ # ---------------------------------------
278
+ # 不会被mask的 token, 不参与 任何时候的MASK
279
+ special_token = ['[SEP]', '[MASK]', '[ALM]', '[KPI]', '[CLS]', '[LOC]', '[EOS]', '[ENT]', '[ATTR]', '[NUM]', '[REL]', '|', '[DOC]']
280
+
281
+ # ---------------------------------------
282
+ # 会被mask的但是---#不加入#---tokenizer的内容
283
+ # 出现次数多(>10000)但是长度较长(>=4符)
284
+ # 或者是一些难以理解的名词
285
+ # WWM 的主体
286
+ # TODO: 专家检查
287
+ # To Add: 'SGSN', '3GPP', 'Bearer', 'sbim', 'FusionSphere', 'IMSI', 'GGSN', 'RETCODE', 'PCRF', 'PDP', 'GTP', 'OCS', 'HLR', 'FFFF', 'VLR', 'DNN', 'PID', 'CSCF', 'PDN', 'SCTP', 'SPGW', 'TAU', 'PCEF', 'NSA', 'ACL', 'BGP', 'USCDB', 'VoLTE', 'RNC', 'GPRS', 'DRA', 'MOC'
288
+ # 拆分:配置原则,本端规划
289
+ norm_token = ['网元实例', '事件类型', '告警级别', '告警名称', '告警源', '通讯系统', '默认值', '链路故障', '取值范围', '可选必选说明', '数据来源', '用户平面', '配置', '原则', '该参数', '失败次数', '可选参数', 'S1模式', '必选参数', 'IP地址', '响应消息', '成功次数', '测量指标', '用于', '统计周期', '该命令', '上下文', '请求次数', '本端', 'pod', 'amf', 'smf', 'nrf', 'ausf', 'upcf', 'upf', 'udm', 'PDU', 'alias', 'PLMN', 'MML', 'Info_Measure', 'icase', 'Diameter', 'MSISDN', 'RAT', 'RMV', 'PFCP', 'NSSAI', 'CCR', 'HDBNJjs', 'HNGZgd', 'SGSN', '3GPP', 'Bearer', 'sbim', 'FusionSphere', 'IMSI', 'GGSN', 'RETCODE', 'PCRF', 'PDP', 'GTP', 'OCS', 'HLR', 'FFFF', 'VLR', 'DNN', 'PID', 'CSCF', 'PDN', 'SCTP', 'SPGW', 'TAU', 'PCEF', 'NSA', 'ACL', 'BGP', 'USCDB', 'VoLTE', 'RNC', 'GPRS', 'DRA', 'MOC', '告警', '网元', '对端', '信令', '话单', '操作', '风险', '等级', '下发', '流控', '运营商', '寻呼', '漫游', '切片', '报文', '号段', '承载', '批量', '导致', '原因是', '影响', '造成', '引起', '随之', '情况下', '根因', 'trigger']
290
+ # ---------------------------------------
291
+ # , '', '', '', '', '', '', '', '', '', '', ''
292
+ # tokens that can be masked and ---ARE--- added to the tokenizer
293
+ # length <= 3, abbreviations / proper nouns with more than 10000 occurrences
294
+ # held to stricter criteria than norm_token
295
+ # frequent enough to be worth separating into their own tokens
296
+ norm_token_tobe_added = ['pod', 'amf', 'smf', 'nrf', 'ausf', 'upcf', 'upf', 'udm', 'ALM', '告警', '网元', '对端', '信令', '话单', 'RAN', 'MML', 'PGW', 'MME', 'SGW', 'NF', 'APN', 'LST', 'GW', 'QoS', 'IPv', 'PDU', 'IMS', 'EPS', 'GTP', 'PDP', 'LTE', 'HSS']
297
+
298
+ token_tobe_added = []
299
+ # all_token = special_token + norm_token_tobe_added
300
+ all_token = norm_token_tobe_added
301
+ for i in all_token:
302
+ if i not in tokenizer.vocab.keys() and i.lower() not in tokenizer.vocab.keys():
303
+ token_tobe_added.append(i)
304
+
305
+ # tokenizer.add_tokens(special_token, special_tokens=False)
306
+ # tokenizer.add_tokens(norm_token, special_tokens=False)
307
+ tokenizer.add_tokens(token_tobe_added, special_tokens=False)
308
+ special_tokens_dict = {"additional_special_tokens": special_token}
309
+ special_token_ = tokenizer.add_special_tokens(special_tokens_dict)
310
+ if rank == 0:
311
+ print("Added tokens:")
312
+ print(tokenizer.get_added_vocab())
313
+
314
+ # pdb.set_trace()
315
+
316
+ if model is not None:
317
+ # TODO: initialize these embeddings (all newly added ones) from the pretrained TeleBert
318
+ if rank == 0:
319
+ print(f"--------------------------------")
320
+ print(f"-------- orig word embedding shape: {model.get_input_embeddings().weight.shape}")
321
+ sz = model.resize_token_embeddings(len(tokenizer))
322
+ if cache_path is not None:
323
+ # model.cpu()
324
+ token_2_emb = torch.load(cache_path)
325
+ # after the added embeddings are initialized here, the weights need to be re-tied
326
+ token_dic = tokenizer.get_added_vocab()
327
+ id_2_token = {v:k for k,v in token_dic.items()}
328
+ with torch.no_grad():
329
+ for key in id_2_token.keys():
330
+ model.bert.embeddings.word_embeddings.weight[key,:] = nn.Parameter(token_2_emb[id_2_token[key]][0]).cuda()
331
+ # model.get_input_embeddings().weight[key,:] = nn.Parameter(token_2_emb[id_2_token[key]][0]).cuda()
332
+ # model.embedding
333
+ model.bert.tie_weights()
334
+ if rank == 0:
335
+ print(f"-------- resize_token_embeddings into {sz} done!")
336
+ print(f"--------------------------------")
337
+ # the embeddings are replaced here
338
+
339
+ norm_token = list(set(norm_token).union(set(norm_token_tobe_added)))
340
+ return tokenizer, special_token, norm_token
341
+
342
+
343
+ def time_trans(sec):
344
+ m, s = divmod(sec, 60)
345
+ h, m = divmod(m, 60)
346
+ return int(h), int(m), int(s)
347
+
348
+ def torch_accuracy(output, target, topk=(1,)):
349
+ '''
350
+ param output, target: should be torch Variable
351
+ '''
352
+ # assert isinstance(output, torch.cuda.Tensor), 'expecting Torch Tensor'
353
+ # assert isinstance(target, torch.Tensor), 'expecting Torch Tensor'
354
+ # print(type(output))
355
+
356
+ topn = max(topk)
357
+ batch_size = output.size(0)
358
+
359
+ _, pred = output.topk(topn, 1, True, True)  # returns (values, indices); indices are the predicted class ids, 0 being the first class
360
+ pred = pred.t()  # torch.t() transposes, so each row holds the batch's predictions at one rank
361
+
362
+ is_correct = pred.eq(target.view(1, -1).expand_as(pred))
363
+
364
+ ans = []
365
+ ans_num = []
366
+ for i in topk:
367
+ # is_correct_i = is_correct[:i].view(-1).float().sum(0, keepdim=True)
368
+ is_correct_i = is_correct[:i].contiguous().view(-1).float().sum(0, keepdim=True)
369
+ ans_num.append(int(is_correct_i.item()))
370
+ ans.append(is_correct_i.mul_(100.0 / batch_size))
371
+
372
+ return ans, ans_num
373
+
374
+
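For reference, a minimal sketch of exercising torch_accuracy above (hypothetical tensors; assumes it is imported from src/utils.py with the repo's dependencies installed and the working directory set to KTeleBERT):

    import torch
    from src.utils import torch_accuracy

    logits = torch.tensor([[0.10, 0.60, 0.20, 0.10],
                           [0.70, 0.05, 0.15, 0.10],
                           [0.20, 0.25, 0.50, 0.05]])
    target = torch.tensor([1, 2, 2])
    acc, right = torch_accuracy(logits, target, topk=(1, 3))
    # top-1 hits rows 0 and 2; top-3 also covers row 1
    print([round(a.item(), 2) for a in acc], right)   # [66.67, 100.0] [2, 3]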
KTeleBERT/test.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python main.py --only_test 1 \
2
+ --batch_size 150 \
3
+ --use_NumEmb 1 \
4
+ --mask_test 0 \
5
+ --mask_stratege wwm \
6
+ --model_name model_name_vXX \
7
+ --ke_test 0 \
8
+ --embed_gen 1 \
9
+ --train_ratio 0 \
10
+ --ke_dim 256 \
11
+ --plm_emb_type cls \
12
+
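The same evaluation run can also be launched from Python; a hedged sketch that simply mirrors the flags in test.sh (the model_name_vXX placeholder is kept as-is):

    import subprocess

    subprocess.run([
        'python', 'main.py',
        '--only_test', '1', '--batch_size', '150', '--use_NumEmb', '1',
        '--mask_test', '0', '--mask_stratege', 'wwm',
        '--model_name', 'model_name_vXX',
        '--ke_test', '0', '--embed_gen', '1', '--train_ratio', '0',
        '--ke_dim', '256', '--plm_emb_type', 'cls',
    ], check=True)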
KTeleBERT/torchlight/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .logger import initialize_exp, get_dump_path
2
+ from .metric import Metric, Top_K_Metric
3
+ from .module import LSTM4VarLenSeq
4
+ from .vocab import (PAD_TOKEN, UNK_TOKEN, BOS_TOKEN, EOS_TOKEN,
5
+ DefaultLookupDict,
6
+ Vocabulary)
7
+ from .utils import (invert_dict,
8
+ personal_display_settings,
9
+ set_seed,
10
+ normalize,
11
+ snapshot,
12
+ show_params,
13
+ longest_substring,
14
+ pad,
15
+ to_cuda,
16
+ get_code_version,
17
+ cat_ragged_tensors,
18
+ topk_accuracy,
19
+ get_total_trainable_params)
20
+
KTeleBERT/torchlight/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (800 Bytes). View file
 
KTeleBERT/torchlight/__pycache__/logger.cpython-38.pyc ADDED
Binary file (3.81 kB). View file
 
KTeleBERT/torchlight/__pycache__/metric.cpython-38.pyc ADDED
Binary file (3.82 kB). View file
 
KTeleBERT/torchlight/__pycache__/module.cpython-38.pyc ADDED
Binary file (3.45 kB). View file
 
KTeleBERT/torchlight/__pycache__/utils.cpython-38.pyc ADDED
Binary file (6.57 kB). View file
 
KTeleBERT/torchlight/__pycache__/vocab.cpython-38.pyc ADDED
Binary file (5.5 kB). View file
 
KTeleBERT/torchlight/logger.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import sys
4
+ import time
5
+ import json
6
+ import torch
7
+ import pickle
8
+ import random
9
+ import getpass
10
+ import logging
11
+ import argparse
12
+ import subprocess
13
+ import numpy as np
14
+ from datetime import timedelta, date
15
+ from .utils import get_code_version
16
+
17
+
18
+ class LogFormatter():
19
+
20
+ def __init__(self):
21
+ self.start_time = time.time()
22
+
23
+ def format(self, record):
24
+ elapsed_seconds = round(record.created - self.start_time)
25
+
26
+ prefix = "%s - %s - %s" % (
27
+ record.levelname,
28
+ time.strftime('%x %X'),
29
+ timedelta(seconds=elapsed_seconds)
30
+ )
31
+ message = record.getMessage()
32
+ message = message.replace('\n', '\n' + ' ' * (len(prefix) + 3))
33
+ return "%s - %s" % (prefix, message) if message else ''
34
+
35
+
36
+ def create_logger(filepath, rank):
37
+ """
38
+ Create a logger.
39
+ Use a different log file for each process.
40
+ """
41
+ # create log formatter
42
+ log_formatter = LogFormatter()
43
+
44
+ # create file handler and set level to debug
45
+ if filepath is not None:
46
+ if rank > 0:
47
+ filepath = '%s-%i' % (filepath, rank)
48
+ file_handler = logging.FileHandler(filepath, "a", encoding='utf-8')
49
+ file_handler.setLevel(logging.DEBUG)
50
+ file_handler.setFormatter(log_formatter)
51
+
52
+ # create console handler and set level to info
53
+ console_handler = logging.StreamHandler()
54
+ console_handler.setLevel(logging.INFO)
55
+ console_handler.setFormatter(log_formatter)
56
+
57
+ # create logger and set level to debug
58
+ logger = logging.getLogger()
59
+ logger.handlers = []
60
+ logger.setLevel(logging.DEBUG)
61
+ logger.propagate = False
62
+ if filepath is not None:
63
+ logger.addHandler(file_handler)
64
+ logger.addHandler(console_handler)
65
+
66
+ # reset logger elapsed time
67
+ def reset_time():
68
+ log_formatter.start_time = time.time()
69
+ logger.reset_time = reset_time
70
+
71
+ return logger
72
+
73
+
74
+ def initialize_exp(params):
75
+ """
76
+ Initialize the experiment:
77
+ - dump parameters
78
+ - create a logger
79
+ """
80
+ # dump parameters
81
+ exp_folder = get_dump_path(params)
82
+ json.dump(vars(params), open(os.path.join(exp_folder, 'params.pkl'), 'w'), indent=4)
83
+
84
+ # get running command
85
+ command = ["python", sys.argv[0]]
86
+ for x in sys.argv[1:]:
87
+ if x.startswith('--'):
88
+ assert '"' not in x and "'" not in x
89
+ command.append(x)
90
+ else:
91
+ assert "'" not in x
92
+ if re.match('^[a-zA-Z0-9_]+$', x):
93
+ command.append("%s" % x)
94
+ else:
95
+ command.append("'%s'" % x)
96
+ command = ' '.join(command)
97
+ params.command = command + ' --exp_id "%s"' % params.exp_id
98
+
99
+ # check experiment name
100
+ assert len(params.exp_name.strip()) > 0
101
+
102
+ # create a logger
103
+ logger = create_logger(os.path.join(exp_folder, 'train.log'), rank=getattr(params, 'global_rank', 0))
104
+ logger.info("============ Initialized logger ============")
105
+ # logger.info("\n".join("%s: %s" % (k, str(v))
106
+ # for k, v in sorted(dict(vars(params)).items())))
107
+ # text = f'# Git Version: {get_code_version()} #'
108
+ # logger.info("\n".join(['=' * 24, text, '=' * 24]))
109
+ logger.info("The experiment will be stored in %s\n" % exp_folder)
110
+ logger.info("Running command: %s" % command)
111
+ logger.info("")
112
+ return logger
113
+
114
+
115
+ def get_dump_path(params):
116
+ """
117
+ Create a directory to store the experiment.
118
+ """
119
+ assert len(params.exp_name) > 0
120
+ assert not params.dump_path in ('', None), \
121
+ 'Please choose your favorite destination for dump.'
122
+ dump_path = params.dump_path
123
+
124
+ # create the sweep path if it does not exist
125
+ when = date.today().strftime('%m%d-')
126
+ sweep_path = os.path.join(dump_path, when + params.exp_name)
127
+ if not os.path.exists(sweep_path):
128
+ subprocess.Popen("mkdir -p %s" % sweep_path, shell=True).wait()
129
+
130
+ # create an random ID for the job if it is not given in the parameters.
131
+ if params.exp_id == '':
132
+ chars = 'abcdefghijklmnopqrstuvwxyz0123456789'
133
+ while True:
134
+ exp_id = ''.join(random.choice(chars) for _ in range(10))
135
+ if not os.path.isdir(os.path.join(sweep_path, exp_id)):
136
+ break
137
+ params.exp_id = exp_id
138
+
139
+ # create the dump folder / update parameters
140
+ exp_folder = os.path.join(sweep_path, params.exp_id)
141
+ if not os.path.isdir(exp_folder):
142
+ subprocess.Popen("mkdir -p %s" % exp_folder, shell=True).wait()
143
+ return exp_folder
144
+
145
+
146
+ if __name__ == '__main__':
147
+ pass
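A minimal usage sketch for the logger utilities above (the exp_name / exp_id / dump_path fields are the ones get_dump_path reads; the concrete values here are illustrative):

    from argparse import Namespace
    from torchlight import initialize_exp, get_dump_path

    params = Namespace(exp_name='ktelebert_demo', exp_id='', dump_path='dump')
    logger = initialize_exp(params)   # creates dump/<MMDD>-ktelebert_demo/<random id>/train.log
    logger.info('experiment folder: %s', get_dump_path(params))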
KTeleBERT/torchlight/metric.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from abc import ABC, ABCMeta, abstractclassmethod
2
+ import torch
3
+ import numpy as np
4
+ from abc import ABC, abstractmethod, ABCMeta
5
+
6
+ class Metric(metaclass=ABCMeta):
7
+ """
8
+ - reset() in the beginning of every epoch.
9
+ - update_per_batch() after every batch.
10
+ - update_per_epoch() after every epoch.
11
+ """
12
+
13
+ @abstractmethod
14
+ def __init__(self):
15
+ pass
16
+
17
+ @abstractmethod
18
+ def reset(self):
19
+ pass
20
+
21
+ @abstractmethod
22
+ def update_per_batch(self, output):
23
+ pass
24
+
25
+ @abstractmethod
26
+ def update_per_epoch(self):
27
+ pass
28
+
29
+ class Top_K_Metric(Metric):
30
+ """
31
+ Stores accuracy (score), MR/MRR and loss statistics
32
+ """
33
+ def __init__(self, topnum=[1,3,10]):
34
+ super().__init__()
35
+ # assert len(topnum) == 3
36
+ self.topnum = topnum
37
+ self.k_num = len(self.topnum)
38
+ self.reset()
39
+
40
+ def reset(self):
41
+ self.total_loss = 0
42
+ self.correct_list = [0] * self.k_num
43
+ self.acc_list = [0] * self.k_num
44
+ self.acc_all = 0
45
+ self.num_examples = 0
46
+ self.num_epoch = 0
47
+
48
+ self.mrr = 0
49
+ self.mr = 0
50
+ self.mrr_all = 0
51
+ self.mr_all = 0
52
+
53
+ def update_per_batch(self, loss, ans, pred):
54
+ self.total_loss += loss
55
+ self.num_epoch += 1
56
+ self.top_k_list = self.batch_accuracy(pred, ans)
57
+ self.num_examples += self.top_k_list[0].shape[0]
58
+ for i in range(self.k_num):
59
+ self.correct_list[i] += self.top_k_list[i].sum().item()
60
+
61
+ # mrr
62
+ mrr_tmp, mr_tmp = self.batch_mr_mrr(pred, ans)
63
+ self.mrr_all += mrr_tmp.sum().item()
64
+ self.mr_all += mr_tmp.sum().item()
65
+
66
+
67
+
68
+ def update_per_epoch(self):
69
+ for i in range(self.k_num):
70
+ self.acc_list[i] = 100 * (self.correct_list[i] / self.num_examples)
71
+
72
+ self.mr = self.mr_all / self.num_examples
73
+ self.mrr = self.mrr_all / self.num_examples
74
+ self.total_loss = self.total_loss / self.num_epoch
75
+ self.acc_all = sum(self.acc_list)
76
+
77
+
78
+ def batch_accuracy(self, predicted, true):
79
+ """ Compute the accuracies for a batch of predictions and answers """
80
+ if len(true.shape) == 3:
81
+ true = true[0]
82
+ _, ok = predicted.topk(max(self.topnum), dim=1)
83
+ agreeing_all = torch.zeros([predicted.shape[0], 1], dtype=torch.float).cuda()
84
+ top_k_list = [0] * self.k_num
85
+ for i in range(max(self.topnum)):
86
+ tmp = ok[:, i].reshape(-1, 1)
87
+ agreeing_all += true.gather(dim=1, index=tmp)
88
+ for k in range(self.k_num):
89
+ if i == self.topnum[k] - 1:
90
+ top_k_list[k] = (agreeing_all * 0.3).clamp(max=1)
91
+ break
92
+
93
+ return top_k_list
94
+
95
+
96
+
97
+ def batch_mr_mrr(self, predicted, true):
98
+ if len(true.shape) == 3:
99
+ true = true[0]
100
+
101
+ # compute the rank of the true answer
102
+ top_rank = predicted.shape[1]
103
+ batch_size = predicted.shape[0]
104
+ _, predict_ans_rank = predicted.topk(top_rank, dim=1)  # indices of candidates ranked by score, batch_size * top_rank
105
+ _, real_ans = true.topk(1, dim=1)  # the ground-truth answer index: batch_size * 1
106
+
107
+ # expand to matching shape
108
+ real_ans = real_ans.expand(batch_size, top_rank)
109
+ ans_different = torch.abs(predict_ans_rank - real_ans)
110
+ # entries equal to 0 mark the positions where the prediction matches the ground truth
111
+ _, real_ans_list = ans_different.topk(top_rank, dim=1)  # the last column now holds the rank of the correct answer among the predictions (the zero entry)
112
+ real_ans_list = real_ans_list + 1.0
113
+ mr = real_ans_list[:,-1].reshape(-1,1).to(torch.float64)
114
+ mrr = 1.0 / mr
115
+ # pdb.set_trace()
116
+
117
+ return mrr,mr
118
+
119
+
120
+ if __name__ == '__main__':
121
+ pass
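A minimal sketch of driving Top_K_Metric for one batch (hypothetical tensors; a CUDA device is assumed because batch_accuracy allocates its accumulator with .cuda(), and the reported scores follow the 0.3-weighted soft accuracy computed there):

    import torch
    from torchlight import Top_K_Metric

    metric = Top_K_Metric(topnum=[1, 3, 10])
    pred = torch.randn(4, 50).cuda()                          # scores for 4 examples over 50 candidates
    ans = torch.zeros(4, 50).cuda()
    ans[torch.arange(4), torch.randint(0, 50, (4,))] = 1.0    # one-hot ground truth
    metric.update_per_batch(loss=0.7, ans=ans, pred=pred)
    metric.update_per_epoch()
    print(metric.acc_list, metric.mr, metric.mrr)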
KTeleBERT/torchlight/module.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Sequence, Union, Callable
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
8
+
9
+ torch.manual_seed(10086)
10
+ # typing, everything in Python is Object.
11
+ tensor_activation = Callable[[torch.Tensor], torch.Tensor]
12
+
13
+
14
+ class LSTM4VarLenSeq(nn.Module):
15
+ def __init__(self, input_size, hidden_size,
16
+ num_layers=1, bias=True, bidirectional=False, init='orthogonal', take_last=True):
17
+ """
18
+ no dropout support
19
+ batch_first is not configurable (always True); the input and output tensors are
20
+ provided as (batch, seq_len, feature).
21
+
22
+ Args:
23
+ input_size:
24
+ hidden_size:
25
+ num_layers:
26
+ bias:
27
+ bidirectional:
28
+ init: ways to init the torch.nn.LSTM parameters,
29
+ supports 'orthogonal' and 'uniform'
30
+ take_last: 'True' if you only want the final hidden state
31
+ otherwise 'False'
32
+ """
33
+ super(LSTM4VarLenSeq, self).__init__()
34
+ self.lstm = nn.LSTM(input_size=input_size,
35
+ hidden_size=hidden_size,
36
+ num_layers=num_layers,
37
+ bias=bias,
38
+ bidirectional=bidirectional)
39
+ self.input_size = input_size
40
+ self.hidden_size = hidden_size
41
+ self.num_layers = num_layers
42
+ self.bias = bias
43
+ self.bidirectional = bidirectional
44
+ self.init = init
45
+ self.take_last = take_last
46
+ self.batch_first = True # Please don't modify this
47
+
48
+ self.init_parameters()
49
+
50
+ def init_parameters(self):
51
+ """orthogonal init yields generally good results than uniform init"""
52
+ if self.init == 'orthogonal':
53
+ gain = 1 # use default value
54
+ for nth in range(self.num_layers * (2 if self.bidirectional else 1)):
55
+ # w_ih, (4 * hidden_size x input_size)
56
+ nn.init.orthogonal_(self.lstm.all_weights[nth][0], gain=gain)
57
+ # w_hh, (4 * hidden_size x hidden_size)
58
+ nn.init.orthogonal_(self.lstm.all_weights[nth][1], gain=gain)
59
+ # b_ih, (4 * hidden_size)
60
+ nn.init.zeros_(self.lstm.all_weights[nth][2])
61
+ # b_hh, (4 * hidden_size)
62
+ nn.init.zeros_(self.lstm.all_weights[nth][3])
63
+ elif self.init == 'uniform':
64
+ k = math.sqrt(1 / self.hidden_size)
65
+ for nth in range(self.num_layers * (2 if self.bidirectional else 1)):
66
+ nn.init.uniform_(self.lstm.all_weights[nth][0], -k, k)
67
+ nn.init.uniform_(self.lstm.all_weights[nth][1], -k, k)
68
+ nn.init.zeros_(self.lstm.all_weights[nth][2])
69
+ nn.init.zeros_(self.lstm.all_weights[nth][3])
70
+ else:
71
+ raise NotImplementedError('Unsupported Initialization')
72
+
73
+ def forward(self, x, x_len, hx=None):
74
+ # 1. Sort x and its corresponding length
75
+ sorted_x_len, sorted_x_idx = torch.sort(x_len, descending=True)
76
+ sorted_x = x[sorted_x_idx]
77
+ # 2. Ready to unsort after LSTM forward pass
78
+ # Note that PyTorch 0.4 has no argsort, but PyTorch 1.0 does.
79
+ _, unsort_x_idx = torch.sort(sorted_x_idx, descending=False)
80
+
81
+ # 3. Pack the sorted version of x and x_len, as required by the API.
82
+ x_emb = pack_padded_sequence(sorted_x, sorted_x_len,
83
+ batch_first=self.batch_first)
84
+
85
+ # 4. Forward lstm
86
+ # output_packed.data.shape is (valid_seq, num_directions * hidden_dim).
87
+ # See doc of torch.nn.LSTM for details.
88
+ out_packed, (hn, cn) = self.lstm(x_emb)
89
+
90
+ # 5. unsort h
91
+ # (num_layers * num_directions, batch, hidden_size) -> (batch, ...)
92
+ hn = hn.permute(1, 0, 2)[unsort_x_idx] # swap the first two dim
93
+ hn = hn.permute(1, 0, 2) # swap the first two again to recover
94
+ if self.take_last:
95
+ return hn.squeeze(0)
96
+ else:
97
+ # unpack: out
98
+ # (batch, max_seq_len, num_directions * hidden_size)
99
+ out, _ = pad_packed_sequence(out_packed,
100
+ batch_first=self.batch_first)
101
+ out = out[unsort_x_idx]
102
+ # unpack: c
103
+ # (num_layers * num_directions, batch, hidden_size) -> (batch, ...)
104
+ cn = cn.permute(1, 0, 2)[unsort_x_idx] # swap the first two dim
105
+ cn = cn.permute(1, 0, 2) # swap the first two again to recover
106
+ return out, (hn, cn)
107
+
108
+
109
+ if __name__ == '__main__':
110
+ # Note that in the future we will import unittest
111
+ # and port the following examples to the test folder.
112
+
113
+ # Unit test for LSTM variable length sequences
114
+ # ================
115
+ net = LSTM4VarLenSeq(200, 100,
116
+ num_layers=3, bias=True, bidirectional=True, init='orthogonal', take_last=False)
117
+
118
+ inputs = torch.tensor([[1, 2, 3, 0],
119
+ [2, 3, 0, 0],
120
+ [2, 4, 3, 0],
121
+ [1, 4, 3, 0],
122
+ [1, 2, 3, 4]])
123
+ embedding = nn.Embedding(num_embeddings=5, embedding_dim=200, padding_idx=0)
124
+ lens = torch.LongTensor([3, 2, 3, 3, 4])
125
+
126
+ input_embed = embedding(inputs)
127
+ output, (h, c) = net(input_embed, lens)
128
+ # 5, 4, 200, batch, seq length, hidden_size * 2 (only last layer)
129
+ print(output.shape)
130
+ # 6, 5, 100, num_layers * num_directions, batch, hidden_size
131
+ print(h.shape)
132
+ # 6, 5, 100, num_layers * num_directions, batch, hidden_size
133
+ print(c.shape)
KTeleBERT/torchlight/utils.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilizations for common usages.
3
+ """
4
+ import os
5
+ import random
6
+ import torch
7
+ import numpy as np
8
+ from difflib import SequenceMatcher
9
+ from unidecode import unidecode
10
+ from datetime import datetime
11
+ from torch.nn.parallel import DataParallel, DistributedDataParallel
12
+
13
+
14
+ def invert_dict(d):
15
+ return {v: k for k, v in d.items()}
16
+
17
+ def personal_display_settings():
18
+ """
19
+ Pandas Doc
20
+ https://pandas.pydata.org/pandas-docs/stable/generated/pandas.set_option.html
21
+ NumPy Doc
22
+ -
23
+ """
24
+ from pandas import set_option
25
+ set_option('display.max_rows', 500)
26
+ set_option('display.max_columns', 500)
27
+ set_option('display.width', 2000)
28
+ set_option('display.max_colwidth', 1000)
29
+ from numpy import set_printoptions
30
+ set_printoptions(suppress=True)
31
+
32
+
33
+ def set_seed(seed):
34
+ """
35
+ Freeze every seed for reproducibility.
36
+ torch.cuda.manual_seed_all is useful when using random generation on GPUs.
37
+ e.g. torch.cuda.FloatTensor(100).uniform_()
38
+ """
39
+ os.environ['PYTHONHASHSEED'] = str(seed)
40
+ random.seed(seed)
41
+ np.random.seed(seed)
42
+ torch.manual_seed(seed)
43
+ torch.cuda.manual_seed(seed)
44
+ torch.cuda.manual_seed_all(seed)
45
+
46
+
47
+
48
+ def normalize(s):
49
+ """
50
+ German and French have vowels that do not exist in English.
51
+ This utilization removes all the non-unicode characters.
52
+ Example:
53
+ āáǎà --> aaaa
54
+ ōóǒò --> oooo
55
+ ēéěè --> eeee
56
+ īíǐì --> iiii
57
+ ūúǔù --> uuuu
58
+ ǖǘǚǜ --> uuuu
59
+
60
+ :param s: unicode string
61
+ :return: unicode string with regular English characters.
62
+ """
63
+ s = s.strip().lower()
64
+ s = unidecode(s)
65
+ return s
66
+
67
+
68
+ def snapshot(model, epoch, save_path):
69
+ """
70
+ Saving models w/ its params.
71
+ Get rid of the ONNX Protocol.
72
+ F-string feature new in Python 3.6+ is used.
73
+ """
74
+ os.makedirs(save_path, exist_ok=True)
75
+ # timestamp = datetime.now().strftime('%m%d_%H%M')
76
+ save_path = os.path.join(save_path, f'{type(model).__name__}_{epoch}_epoch.pkl')
77
+ if isinstance(model, (DataParallel, DistributedDataParallel)):
78
+ torch.save(model.module.state_dict(), save_path)
79
+ else:
80
+ torch.save(model.state_dict(), save_path)
81
+ return save_path
82
+
83
+
84
+ def save_checkpoint(model, optimizer, epoch, path):
85
+ torch.save({
86
+ 'epoch': epoch,
87
+ 'models': model.state_dict(),
88
+ 'optimizer': optimizer.state_dict(),
89
+ }, path)
90
+
91
+
92
+ def load_checkpoint(path, map_location):
93
+ checkpoint = torch.load(path, map_location=map_location)
94
+ return checkpoint
95
+
96
+
97
+ def show_params(model):
98
+ """
99
+ Show models parameters for logging.
100
+ """
101
+ for name, param in model.named_parameters():
102
+ print('%-16s' % name, param.size())
103
+
104
+
105
+ def longest_substring(str1, str2):
106
+ # initialize SequenceMatcher object with input string
107
+ seqMatch = SequenceMatcher(None, str1, str2)
108
+
109
+ # find match of longest sub-string
110
+ # output will be like Match(a=0, b=0, size=5)
111
+ match = seqMatch.find_longest_match(0, len(str1), 0, len(str2))
112
+
113
+ # print longest substring
114
+ return str1[match.a: match.a + match.size] if match.size != 0 else ""
115
+
116
+
117
+ def pad(sent, max_len):
118
+ """
119
+ syntax "[0] * int" only works properly for Python 3.5+
120
+ Note that in testing time, the length of a sentence
121
+ might exceed the pre-defined max_len (of training data).
122
+ """
123
+ length = len(sent)
124
+ return (sent + [0] * (max_len - length))[:max_len] if length < max_len else sent[:max_len]
125
+
126
+
127
+ def to_cuda(*args, device=None):
128
+ """
129
+ Move Tensors to CUDA.
130
+ If no device provided, default to the first card in CUDA_VISIBLE_DEVICES.
131
+ """
132
+ assert all(torch.is_tensor(t) for t in args), \
133
+ 'Only support for tensors, please check if any nn.Module exists.'
134
+ if device is None:
135
+ device = torch.device('cuda:0')
136
+ return [None if x is None else x.to(device) for x in args]
137
+
138
+
139
+ def get_code_version(short_sha=True):
140
+ from subprocess import check_output, STDOUT, CalledProcessError
141
+ try:
142
+ sha = check_output('git rev-parse HEAD', stderr=STDOUT,
143
+ shell=True, encoding='utf-8')
144
+ if short_sha:
145
+ sha = sha[:7]
146
+ return sha
147
+ except CalledProcessError:
148
+ # There was an error - command exited with non-zero code
149
+ pwd = check_output('pwd', stderr=STDOUT, shell=True, encoding='utf-8')
150
+ pwd = os.path.abspath(pwd).strip()
151
+ print(f'Working dir {pwd} is not a git repo.')
152
+
153
+
154
+ def cat_ragged_tensors(left, right):
155
+ assert left.size(0) == right.size(0)
156
+ batch_size = left.size(0)
157
+ max_len = left.size(1) + right.size(1)
158
+
159
+ len_left = (left != 0).sum(dim=1)
160
+ len_right = (right != 0).sum(dim=1)
161
+
162
+ left_seq = left.unbind()
163
+ right_seq = right.unbind()
164
+ # handle zero padding
165
+ output = torch.zeros((batch_size, max_len), dtype=torch.long, device=left.device)
166
+ for i, row_left, row_right, l1, l2 in zip(range(batch_size),
167
+ left_seq, right_seq,
168
+ len_left, len_right):
169
+ l1 = l1.item()
170
+ l2 = l2.item()
171
+ j = l1 + l2
172
+ # concatenate rows of ragged tensors
173
+ row_cat = torch.cat((row_left[:l1], row_right[:l2]))
174
+ # copy to empty tensor
175
+ output[i, :j] = row_cat
176
+ return output
177
+
178
+
179
+ def topk_accuracy(inputs, labels, k=1, largest=True):
180
+ assert len(inputs.size()) == 2
181
+ assert len(labels.size()) == 2
182
+ _, indices = inputs.topk(k=k, largest=largest)
183
+ result = indices - labels  # broadcast
184
+ nonzero_count = (result != 0).sum(dim=1, keepdim=True)
185
+ num_correct = (nonzero_count != result.size(1)).sum().item()
186
+ num_example = inputs.size(0)
187
+ return num_correct, num_example
188
+
189
+
190
+ def get_total_trainable_params(model):
191
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
192
+
193
+
194
+ if __name__ == '__main__':
195
+ print(normalize('ǖǘǚǜ'))
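A small sketch exercising a few of the helpers above (values are illustrative only; assumes the package's requirements, e.g. unidecode, are installed):

    import torch
    from torchlight import set_seed, pad, topk_accuracy

    set_seed(42)
    print(pad([3, 1, 4], max_len=5))           # -> [3, 1, 4, 0, 0]

    logits = torch.tensor([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]])
    labels = torch.tensor([[1], [2]])
    print(topk_accuracy(logits, labels, k=1))  # -> (1, 2): one of two predictions is correct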
KTeleBERT/torchlight/vocab.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ """
3
+ Every NLP task needs a Vocabulary
4
+ Every Vocabulary is built from Instances
5
+ Every Instance is a collection of Fields
6
+ """
7
+
8
+ __all__ = ['DefaultLookupDict', 'Vocabulary']
9
+
10
+ PAD_TOKEN = '<pad>'
11
+ UNK_TOKEN = '<unk>'
12
+ BOS_TOKEN = '<bos>'
13
+ EOS_TOKEN = '<eos>'
14
+ PAD_IDX = 0
15
+ UNK_IDX = 1
16
+
17
+
18
+ class DefaultLookupDict(dict):
19
+ def __init__(self, default):
20
+ super(DefaultLookupDict, self).__init__()
21
+ self._default = default
22
+
23
+ def __getitem__(self, item):
24
+ return self.get(item, self._default)
25
+
26
+
27
+ class Vocabulary:
28
+ """
29
+ Define a vocabulary object that will be used to numericalize a field.
30
+ Attributes:
31
+ token2id: A collections.defaultdict instance mapping token strings to
32
+ numerical identifiers.
33
+ id2token: A list of token strings indexed by their numerical
34
+ identifiers.
35
+ embedding: pretrained vectors.
36
+
37
+ Examples:
38
+ >>> from torchlight.vocab import Vocabulary
39
+ >>> from collections import Counter
40
+ >>> text_data = ['hello', 'world', 'hello', 'nice', 'world', 'hi', 'world']
41
+ >>> vocab = Vocabulary(Counter(text_data))
42
+ """
43
+ def __init__(self, counter, max_size=None, min_freq=1, specials=None):
44
+ """
45
+ Create a Vocabulary given Counter.
46
+ Args:
47
+ counter: collections.Counter object holding the frequencies of
48
+ each value found in the data.
49
+ max_size: The maximum size of the vocabulary, or None for no
50
+ maximum. Default: None.
51
+ min_freq: The minimum frequency needed to include a token in the
52
+ vocabulary. Values less than 1 will be set to 1. Default: 1.
53
+ specials: The list of special tokens except ['<pad>', '<unk>'].
54
+ Possible choices: [CLS] [MASK] [SEP] in BERT or <bos> <eos>
55
+ in Machine Translation.
56
+ """
57
+ min_freq = max(min_freq, 1) # must be positive
58
+
59
+ if specials is None:
60
+ self.specials = [PAD_TOKEN, UNK_TOKEN]
61
+ else:
62
+ assert isinstance(specials, list), "'specials' is of type list"
63
+ self.specials = [PAD_TOKEN, UNK_TOKEN] + specials
64
+
65
+ assert len(set(self.specials)) == len(self.specials), \
66
+ "specials can not contain duplicates."
67
+
68
+ if max_size is not None:
69
+ max_size = len(self.specials) + max_size
70
+
71
+ self.id2token = self.specials[:]
72
+ self.token2id = DefaultLookupDict(UNK_IDX)
73
+ self.token2id.update({tok: i for i, tok in enumerate(self.id2token)})
74
+
75
+ # sort by frequency, then alphabetically
76
+ token_freqs = sorted(counter.items(), key=lambda tup: tup[0])
77
+ token_freqs.sort(key=lambda tup: tup[1], reverse=True)
78
+
79
+ for token, freq in token_freqs:
80
+ if freq < min_freq or len(self.id2token) == max_size:
81
+ break
82
+ if token not in self.specials:
83
+ self.id2token.append(token)
84
+ self.token2id[token] = len(self.id2token) - 1
85
+
86
+ # TODO
87
+ self.embedding = None
88
+
89
+ def __len__(self):
90
+ return len(self.id2token)
91
+
92
+ def __repr__(self):
93
+ return 'Vocab(size={}, specials="{}")'.format(len(self), self.specials)
94
+
95
+ def __getitem__(self, tokens):
96
+ """Looks up indices of text tokens according to the vocabulary.
97
+ If `unknown_token` of the vocabulary is None, looking up unknown tokens
98
+ results in KeyError.
99
+ Parameters
100
+ ----------
101
+ tokens : str or list of strs
102
+ A source token or tokens to be converted.
103
+ Returns
104
+ -------
105
+ int or list of ints
106
+ A token index or a list of token indices according to the vocabulary.
107
+ """
108
+
109
+ if not isinstance(tokens, (list, tuple)):
110
+ return self.token2id[tokens]
111
+ else:
112
+ return [self.token2id[token] for token in tokens]
113
+
114
+ def __call__(self, tokens):
115
+ """Looks up indices of text tokens according to the vocabulary.
116
+ Parameters
117
+ ----------
118
+ tokens : str or list of strs
119
+ A source token or tokens to be converted.
120
+ Returns
121
+ -------
122
+ int or list of ints
123
+ A token index or a list of token indices according to the
124
+ vocabulary.
125
+ """
126
+
127
+ return self[tokens]
128
+
129
+ @classmethod
130
+ def from_json(cls, json_str):
131
+ pass
132
+
133
+ def to_json(self):
134
+ pass
135
+
136
+ def set_embedding(self):
137
+ pass
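A minimal sketch of building and querying the Vocabulary above (mirrors the docstring example; the printed indices follow the insertion order described in __init__):

    from collections import Counter
    from torchlight import Vocabulary

    text = ['hello', 'world', 'hello', 'nice', 'world', 'hi', 'world']
    vocab = Vocabulary(Counter(text), specials=['[CLS]', '[SEP]'])
    print(len(vocab), vocab)                     # size counts <pad>, <unk> plus the extra specials
    print(vocab('world'), vocab(['hi', 'oov']))  # unknown tokens fall back to UNK_IDX (1)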