kevin110211
committed on
Commit · 5d58b52
Parent(s): bc0481a
Upload 51 files
This view is limited to 50 files because it contains too many changes.
- KTeleBERT/__pycache__/config.cpython-38.pyc +0 -0
- KTeleBERT/config.py +234 -0
- KTeleBERT/data_trans.py +56 -0
- KTeleBERT/get_chinese_ref.py +454 -0
- KTeleBERT/main.py +851 -0
- KTeleBERT/model/HWBert.py +146 -0
- KTeleBERT/model/KE_model.py +451 -0
- KTeleBERT/model/Numeric.py +218 -0
- KTeleBERT/model/OD_model.py +74 -0
- KTeleBERT/model/Tool_model.py +34 -0
- KTeleBERT/model/__init__.py +26 -0
- KTeleBERT/model/__pycache__/HWBert.cpython-38.pyc +0 -0
- KTeleBERT/model/__pycache__/KE_model.cpython-38.pyc +0 -0
- KTeleBERT/model/__pycache__/Numeric.cpython-38.pyc +0 -0
- KTeleBERT/model/__pycache__/OD_model.cpython-38.pyc +0 -0
- KTeleBERT/model/__pycache__/Tool_model.cpython-38.pyc +0 -0
- KTeleBERT/model/__pycache__/__init__.cpython-38.pyc +0 -0
- KTeleBERT/model/bert/__init__.py +201 -0
- KTeleBERT/model/bert/__pycache__/__init__.cpython-38.pyc +0 -0
- KTeleBERT/model/bert/__pycache__/configuration_bert.cpython-38.pyc +0 -0
- KTeleBERT/model/bert/__pycache__/modeling_bert.cpython-38.pyc +0 -0
- KTeleBERT/model/bert/__pycache__/tokenization_bert.cpython-38.pyc +0 -0
- KTeleBERT/model/bert/configuration_bert.py +191 -0
- KTeleBERT/model/bert/modeling_bert.py +2010 -0
- KTeleBERT/model/bert/tokenization_bert.py +574 -0
- KTeleBERT/requirements.txt +10 -0
- KTeleBERT/run.sh +35 -0
- KTeleBERT/run_get_ref.sh +22 -0
- KTeleBERT/special_token_pre_emb.py +119 -0
- KTeleBERT/src/__init__.py +1 -0
- KTeleBERT/src/__pycache__/__init__.cpython-38.pyc +0 -0
- KTeleBERT/src/__pycache__/data.cpython-38.pyc +0 -0
- KTeleBERT/src/__pycache__/distributed_utils.cpython-38.pyc +0 -0
- KTeleBERT/src/__pycache__/utils.cpython-38.pyc +0 -0
- KTeleBERT/src/data.py +651 -0
- KTeleBERT/src/distributed_utils.py +79 -0
- KTeleBERT/src/utils.py +374 -0
- KTeleBERT/test.sh +12 -0
- KTeleBERT/torchlight/__init__.py +20 -0
- KTeleBERT/torchlight/__pycache__/__init__.cpython-38.pyc +0 -0
- KTeleBERT/torchlight/__pycache__/logger.cpython-38.pyc +0 -0
- KTeleBERT/torchlight/__pycache__/metric.cpython-38.pyc +0 -0
- KTeleBERT/torchlight/__pycache__/module.cpython-38.pyc +0 -0
- KTeleBERT/torchlight/__pycache__/utils.cpython-38.pyc +0 -0
- KTeleBERT/torchlight/__pycache__/vocab.cpython-38.pyc +0 -0
- KTeleBERT/torchlight/logger.py +147 -0
- KTeleBERT/torchlight/metric.py +121 -0
- KTeleBERT/torchlight/module.py +133 -0
- KTeleBERT/torchlight/utils.py +195 -0
- KTeleBERT/torchlight/vocab.py +137 -0
KTeleBERT/__pycache__/config.cpython-38.pyc
ADDED
Binary file (7.01 kB).
KTeleBERT/config.py
ADDED
@@ -0,0 +1,234 @@
import os.path as osp
import numpy as np
import random
import torch
from easydict import EasyDict as edict
import argparse


LAYER_MAPPING = {
    0: 'od_layer_0',
    1: 'od_layer_1',
    2: 'od_layer_2',
}


class cfg():
    def __init__(self):
        self.this_dir = osp.dirname(__file__)
        # change
        self.data_root = osp.abspath(osp.join(self.this_dir, '..', '..', 'data', ''))

        # TODO: add some static variable (The frequency of change is low)

    def get_args(self):
        parser = argparse.ArgumentParser()
        # ------------ base ------------
        parser.add_argument('--train_strategy', default=1, type=int)
        parser.add_argument('--batch_size', default=64, type=int)
        parser.add_argument('--batch_size_ke', default=14, type=int)
        parser.add_argument('--batch_size_od', default=8, type=int)
        parser.add_argument('--batch_size_ad', default=32, type=int)

        parser.add_argument('--epoch', default=15, type=int)
        parser.add_argument("--save_model", default=1, type=int, choices=[0, 1])
        # save with the transformers save_pretrained method
        parser.add_argument("--save_pretrain", default=0, type=int, choices=[0, 1])
        parser.add_argument("--from_pretrain", default=0, type=int, choices=[0, 1])

        # torchlight
        parser.add_argument("--no_tensorboard", default=False, action="store_true")
        parser.add_argument("--exp_name", default="huawei_exp", type=str, help="Experiment name")
        parser.add_argument("--dump_path", default="dump/", type=str, help="Experiment dump path")
        parser.add_argument("--exp_id", default="ke256_raekt_ernie2_bs20_p3_c3_5e-6", type=str, help="Experiment ID")
        # or 3407
        parser.add_argument("--random_seed", default=42, type=int)
        # data parameters
        parser.add_argument("--data_path", default="huawei", type=str, help="Experiment path")
        parser.add_argument('--train_ratio', default=1, type=float, help='ratio for train/test')
        parser.add_argument("--seq_data_name", default='Seq_data_base', type=str, help="name of the seq_data file")
        parser.add_argument("--kg_data_name", default='KG_data_base_rule', type=str, help="name of the kg_data file")
        parser.add_argument("--order_data_name", default='event_order_data', type=str, help="name of the order_data file")
        # TODO: add some dynamic variable
        parser.add_argument("--model_name", default="MacBert", type=str, help="model name")

        # ------------ training stage ------------
        parser.add_argument("--scheduler", default="cos", type=str, choices=["linear", "cos"])
        parser.add_argument("--optim", default="adamw", type=str)
        parser.add_argument("--adam_epsilon", default=1e-8, type=float)
        parser.add_argument('--workers', type=int, default=8)
        parser.add_argument('--accumulation_steps', type=int, default=6)
        parser.add_argument('--accumulation_steps_ke', type=int, default=6)
        parser.add_argument('--accumulation_steps_ad', type=int, default=6)
        parser.add_argument('--accumulation_steps_od', type=int, default=6)
        parser.add_argument("--train_together", default=0, type=int)

        # 3e-5
        parser.add_argument('--lr', type=float, default=1e-5)
        # layer-wise learning rate decay
        parser.add_argument("--LLRD", default=0, type=int, choices=[0, 1])
        parser.add_argument('--weight_decay', type=float, default=0.01)
        parser.add_argument('--clip', type=float, default=1., help='gradient clipping')
        parser.add_argument('--scheduler_steps', type=int, default=None,
                            help='total number of step for the scheduler, if None then scheduler_total_step = total_step')
        parser.add_argument('--eval_step', default=100, type=int, help='evaluate each n step')

        # ------------ PLM ------------
        parser.add_argument('--maxlength', type=int, default=200)
        parser.add_argument('--mlm_probability', type=float, default=0.15)
        parser.add_argument('--final_mlm_probability', type=float, default=0.4)
        parser.add_argument('--mlm_probability_increase', type=str, default="curve", choices=["linear", "curve"])
        parser.add_argument("--mask_stratege", default="rand", type=str, choices=["rand", "wwm", "domain"])
        # use rand for the first n epochs, then wwm (multi-stage knowledge masking strategy)
        parser.add_argument("--ernie_stratege", default=-1, type=int)
        # train with the MLM task; uses chinese_ref by default and adds new special words
        parser.add_argument("--use_mlm_task", default=1, type=int, choices=[0, 1])
        # add new special words
        parser.add_argument("--add_special_word", default=1, type=int, choices=[0, 1])
        # freeze
        parser.add_argument("--freeze_layer", default=0, type=int, choices=[0, 1, 2, 3, 4])
        # whether to mask special tokens
        parser.add_argument("--special_token_mask", default=0, type=int, choices=[0, 1])
        parser.add_argument("--emb_init", default=1, type=int, choices=[0, 1])
        parser.add_argument("--cls_head_init", default=1, type=int, choices=[0, 1])
        # whether to use adaptive loss weighting
        parser.add_argument("--use_awl", default=1, type=int, choices=[0, 1])
        parser.add_argument("--mask_loss_scale", default=1.0, type=float)

        # ------------ KGE ------------
        parser.add_argument('--ke_norm', type=int, default=1)
        parser.add_argument('--ke_dim', type=int, default=768)
        parser.add_argument('--ke_margin', type=float, default=1.0)
        parser.add_argument('--neg_num', type=int, default=10)
        parser.add_argument('--adv_temp', type=float, default=1.0, help='The temperature of sampling in self-adversarial negative sampling.')
        # 5e-4
        parser.add_argument('--ke_lr', type=float, default=3e-5)
        parser.add_argument('--only_ke_loss', type=int, default=0)

        # ------------ numeric embedding ------------
        parser.add_argument('--use_NumEmb', type=int, default=1)
        parser.add_argument("--contrastive_loss", default=1, type=int, choices=[0, 1])
        parser.add_argument("--l_layers", default=2, type=int)
        parser.add_argument('--use_kpi_loss', type=int, default=1)

        # ------------ test stage ------------
        parser.add_argument("--only_test", default=0, type=int, choices=[0, 1])
        parser.add_argument("--mask_test", default=0, type=int, choices=[0, 1])
        parser.add_argument("--embed_gen", default=0, type=int, choices=[0, 1])
        parser.add_argument("--ke_test", default=0, type=int, choices=[0, 1])
        # -1: test on the full set
        parser.add_argument("--ke_test_num", default=-1, type=int)
        parser.add_argument("--path_gen", default="", type=str)

        # ------------ event-order stage ------------
        # 1: pre-training
        # 2: event-order finetune
        # 3: anomaly-detection finetune + event order, run iteratively
        # whether to load the OD model
        parser.add_argument("--order_load", default=0, type=int)
        parser.add_argument("--order_num", default=2, type=int)
        parser.add_argument("--od_type", default='linear_cat', type=str, choices=['linear_cat', 'vertical_attention'])
        parser.add_argument("--eps", default=0.2, type=float, help='label smoothing..')
        parser.add_argument("--num_od_layer", default=0, type=int)
        parser.add_argument("--plm_emb_type", default='cls', type=str, choices=['cls', 'last_avg'])
        parser.add_argument("--order_test_name", default='', type=str)
        parser.add_argument("--order_threshold", default=0.5, type=float)
        # ------------ distributed training ------------
        # whether to train in parallel
        parser.add_argument('--rank', type=int, default=0, help='rank to dist')
        parser.add_argument('--dist', type=int, default=0, help='whether to dist')
        # do not change this parameter; it is assigned automatically
        parser.add_argument('--device', default='cuda', help='device id (i.e. 0 or 0,1 or cpu)')
        # number of processes (not threads) to launch; no need to set it, it is derived from nproc_per_node
        parser.add_argument('--world-size', default=4, type=int,
                            help='number of distributed processes')
        parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
        parser.add_argument("--local_rank", default=-1, type=int)
        self.cfg = parser.parse_args()

    def update_train_configs(self):
        # add some constraint for parameters
        # e.g. cannot save and test at the same time
        # fix up default parameters
        # TODO: the test logic is problematic and needs to be revised
        if len(self.cfg.order_test_name) > 0:
            self.cfg.save_model = 0
        if len(self.cfg.order_test_name) == 0:
            self.cfg.train_ratio = min(0.8, self.cfg.train_ratio)
        # adapt the file name to load
        else:
            print("od test ... ")
            self.cfg.train_strategy = 5
            self.cfg.plm_emb_type = 'last_avg' if 'last_avg' in self.cfg.model_name else 'cls'
            for key in LAYER_MAPPING.keys():
                if LAYER_MAPPING[key] in self.cfg.model_name:
                    self.cfg.num_od_layer = key
            self.cfg.order_test_name = osp.join('downstream_task', f'{self.cfg.order_test_name}')

        if self.cfg.mask_test or self.cfg.embed_gen or self.cfg.ke_test or len(self.cfg.order_test_name) > 0:
            assert len(self.cfg.model_name) > 0
            self.cfg.only_test = 1
        if self.cfg.only_test == 1:
            self.save_model = 0
            self.save_pretrain = 0

        # TODO: update some dynamic variable
        self.cfg.data_root = self.data_root
        self.cfg.data_path = osp.join(self.data_root, self.cfg.data_path)
        self.cfg.plm_path = osp.join(self.data_root, 'transformer')
        self.cfg.dump_path = osp.join(self.cfg.data_path, self.cfg.dump_path)
        # try to keep the effective batch size around 32

        # number of adaptive loss weights
        self.cfg.awl_num = 1
        # ------------ numeric embedding ------------
        self.cfg.hidden_size = 768
        self.cfg.num_attention_heads = 8
        self.cfg.hidden_dropout_prob = 0.1
        self.cfg.num_kpi = 304
        self.cfg.specail_emb_path = None
        if self.cfg.emb_init:
            self.cfg.specail_emb_path = osp.join(self.cfg.data_path, 'added_vocab_embedding.pt')

        # ------------- multi-task learning -------------
        # four stages
        self.cfg.mask_epoch, self.cfg.ke_epoch, self.cfg.ad_epoch, self.cfg.od_epoch = None, None, None, None
        # trigger multi-task learning
        if self.cfg.train_strategy > 1:
            self.cfg.mask_epoch = [0, 1, 1, 1, 0]
            self.cfg.ke_epoch = [4, 3, 2, 2, 0]
            if self.cfg.only_ke_loss:
                self.cfg.mask_epoch = [0, 0, 0, 0, 0]
            self.cfg.epoch = sum(self.cfg.mask_epoch) + sum(self.cfg.ke_epoch)
        if self.cfg.train_strategy > 2:
            self.cfg.ad_epoch = [0, 6, 3, 1, 0]
            self.cfg.epoch += sum(self.cfg.ad_epoch)
        if self.cfg.train_strategy > 3 and not self.cfg.only_ke_loss:
            self.cfg.od_epoch = [0, 0, 9, 1, 0]
            # self.cfg.mask_epoch[3] = 1
            self.cfg.epoch += sum(self.cfg.od_epoch)
        self.cfg.epoch_matrix = []
        for epochs in [self.cfg.mask_epoch, self.cfg.ke_epoch, self.cfg.ad_epoch, self.cfg.od_epoch]:
            if epochs is not None:
                self.cfg.epoch_matrix.append(epochs)
        if self.cfg.train_together:
            # losses are simply summed, so the number of training epochs follows the mask epochs
            self.cfg.epoch = sum(self.cfg.mask_epoch)
            self.cfg.batch_size = int((self.cfg.batch_size - 16) / self.cfg.train_strategy)
            self.cfg.batch_size_ke = int(self.cfg.batch_size_ke / self.cfg.train_strategy) - 2
            self.cfg.batch_size_ad = int(self.cfg.batch_size_ad / self.cfg.train_strategy) - 1
            self.cfg.batch_size_od = int(self.cfg.batch_size_od / self.cfg.train_strategy) - 1
            self.cfg.accumulation_steps = (self.cfg.accumulation_steps - 1) * self.cfg.train_strategy

        self.cfg.neg_num = max(min(self.cfg.neg_num, self.cfg.batch_size_ke - 3), 1)

        self.cfg.accumulation_steps_dict = {0: self.cfg.accumulation_steps, 1: self.cfg.accumulation_steps_ke, 2: self.cfg.accumulation_steps_ad, 3: self.cfg.accumulation_steps_od}

        # using numeric embeddings also requires adding the new words, because the position info is tied to the tokenizer
        if self.cfg.use_mlm_task or self.cfg.use_NumEmb:
            assert self.cfg.add_special_word == 1

        if self.cfg.use_NumEmb:
            self.cfg.awl_num += 1

        return self.cfg
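For reference, a minimal usage sketch of this config class; the call pattern mirrors the local cfg class in get_chinese_ref.py below, and the exact wiring inside main.py (truncated later in this diff) is not reproduced here.

# Illustrative only: how config.cfg is typically driven from a launcher script.
from config import cfg

config = cfg()                         # resolves data_root relative to config.py
config.get_args()                      # parses CLI flags into config.cfg (an argparse Namespace)
args = config.update_train_configs()   # applies constraints / derived fields and returns the Namespace
print(args.model_name, args.epoch, args.accumulation_steps_dict)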
KTeleBERT/data_trans.py
ADDED
@@ -0,0 +1,56 @@
import os.path as osp
import numpy as np
import random
import torch
import argparse
import pdb
import json

'''
Merge the data sources,
and extract the needed subset of the data at the same time
'''

this_dir = osp.dirname(__file__)

data_root = osp.abspath(osp.join(this_dir, '..', '..', 'data', ''))

data_path = "huawei"
data_path = osp.join(data_root, data_path)


with open(osp.join(data_path, 'product_corpus.json'), "r") as f:
    data_doc = json.load(f)

with open(osp.join(data_path, '831_alarm_serialize.json'), "r") as f:
    data_alarm = json.load(f)
# kpi_info.json
with open(osp.join(data_path, '917_kpi_serialize_50_mn.json'), "r") as f:
    data_kpi = json.load(f)


# serialized entities
with open(osp.join(data_path, '5GC_KB/database_entity_serialize.json'), "r") as f:
    data_entity = json.load(f)

random.shuffle(data_kpi)
random.shuffle(data_doc)
random.shuffle(data_alarm)
random.shuffle(data_entity)
data = data_alarm + data_kpi + data_entity + data_doc
random.shuffle(data)

# 241527
pdb.set_trace()
with open(osp.join(data_path, 'Seq_data_large.json'), "w") as fp:
    json.dump(data, fp, ensure_ascii=False)


# triples
with open(osp.join(data_path, '5GC_KB/database_triples.json'), "r") as f:
    data = json.load(f)
random.shuffle(data)


with open(osp.join(data_path, 'KG_data_base.json'), "w") as fp:
    json.dump(data, fp, ensure_ascii=False)
KTeleBERT/get_chinese_ref.py
ADDED
@@ -0,0 +1,454 @@
import os.path as osp
import numpy as np
import random
import torch
from easydict import EasyDict as edict
import argparse
import pdb
import json
from model import BertTokenizer
from collections import Counter
from ltp import LTP
from tqdm import tqdm
from src.utils import add_special_token
from functools import reduce
from time import time
from numpy import mean
import math

from src.utils import Loss_log, time_trans
from collections import defaultdict


class cfg():
    def __init__(self):
        self.this_dir = osp.dirname(__file__)
        # change
        self.data_root = osp.abspath(osp.join(self.this_dir, '..', '..', 'data', ''))

    def get_args(self):
        parser = argparse.ArgumentParser()
        # seq_data_name = "Seq_data_tiny_831"
        parser.add_argument("--data_path", default="huawei", type=str, help="Experiment path")
        # TODO: a freq of 150 could also be considered
        parser.add_argument("--freq", default=50, type=int, help="minimum number of occurrences for a word to be considered important")
        parser.add_argument("--batch_size", default=100, type=int, help="batch size for word segmentation")
        parser.add_argument("--seq_data_name", default='Seq_data_large', type=str, help="name of the seq_data file")
        parser.add_argument("--deal_numeric", default=0, type=int, help="whether to process numeric data")

        parser.add_argument("--read_cws", default=0, type=int, help="whether to read a previously produced cws file")
        self.cfg = parser.parse_args()

    def update_train_configs(self):
        # TODO: update some dynamic variable
        self.cfg.data_root = self.data_root
        self.cfg.data_path = osp.join(self.data_root, self.cfg.data_path)

        return self.cfg


def refresh_data(ref, freq, special_token):
    '''
    Purpose: on top of the custom special tokens, use a minimum occurrence frequency to collect more
             new words as references for the word segmenter; these serve as the basis for wwm.
    Inputs:
        freq: minimum occurrence frequency in the (~370k-entry) semantic vocabulary (whitespace-separated)
        special_token: the manually defined special tokens (may overlap with the new words)
    Output:
        add_words: the new words selected under the minimum occurrence frequency
    '''
    # frequently occurring sub tokens
    seq_sub_data = [line.split() for line in ref]
    all_data = []
    for data in seq_sub_data:
        all_data.extend(data)
    sub_word_times = dict(Counter(all_data))
    asub_word_time_order = sorted(sub_word_times.items(), key=lambda x: x[1], reverse=True)
    # ('LST', 1218), ('RMV', 851), ('DSP', 821), ('ADD', 820), ('MOD', 590), ('SET', 406), ('AWS', 122)
    # ADD、ACT、ALM-XXX、DEL、DSP、LST
    add_words = []

    for i in asub_word_time_order:
        # add the words that occur very frequently
        if i[1] >= freq and len(i[0]) > 1 and len(i[0]) < 20 and not str.isdigit(i[0]):
            add_words.append(i[0])
    add_words.extend(special_token)
    # with a threshold of 100 this yields 935 special tokens
    print(f"[{len(add_words)}] special words will be added with frequency [{freq}]!")
    return add_words


def cws(seq_data, add_words, batch_size):
    '''
    Purpose: convert all sequence inputs into their word-segmented form
    Inputs:
        seq_data: all sequence inputs, e.g. ['KPI异常下降', 'KPI异常上升']
        add_words: the special words to add
        batch_size: how many sentences to segment per call
    Outputs:
        all_segment: segmented output for all sequences, e.g. [['KPI', '异常', '下降'], ['KPI', '异常', '上升']]
        data_size: number of input/output sequences (e.g. 2)
    '''
    # seq_data = seq_data.cuda()
    print(f"loading...")
    ltp = LTP("LTP/base2")  # loads the base2 model by default
    # ltp = LTP()
    print(f"begin adding words ...")
    # ltp.add_words(words=add_words, max_window=5) # 4.1.5
    ltp.add_words(words=add_words)  # 4.2.8
    ltp.to("cuda")
    # for word in add_words:
    #     ltp.add_word(word)
    print(f"{len(add_words)} special words are added!")

    #
    # for data in seq_data:
    #     output = ltp.pipeline([data], tasks=["cws"])
    data_size = len(seq_data)
    seq_data_cws = []
    size = int(data_size / batch_size) + 1
    b = 0
    e = b + batch_size
    # pdb.set_trace()

    log = Loss_log()

    with tqdm(total=size) as _tqdm:
        # pdb.set_trace()
        # log.time_init()
        # pdb.set_trace()
        error_data = []
        for i in range(size):

            output = []
            try:
                _output = ltp.pipeline(seq_data[b:e], tasks=["cws"])
                for data in _output.cws:
                    try:
                        data_out = ltp.pipeline(data, tasks=["cws"])
                        # data_out_ = reduce(lambda x, y: x.extend(y) or x, data_out.cws)
                        data_out_ = []
                        for i in data_out.cws:
                            data_out_.extend([k.strip() for k in i])
                        output.append(data_out_)
                    except:
                        print(f"Second-stage segmentation failed! Range: [{b}]-[{e}]")
                        error_data.append(data)

                # pdb.set_trace()
            except:
                print(f"First-stage segmentation failed! Range: [{b}]-[{e}]")
                error_data.append(f"First-stage segmentation failed! Range: [{b}]-[{e}]")
                # continue
            seq_data_cws.extend(output)
            b = e
            e += batch_size

            # timing statistics
            if e >= data_size:
                if b >= data_size:
                    break
                e = data_size
            _tqdm.set_description(f'from {b} to {e}:')
            _tqdm.update(1)

    print(f"Filtered out {data_size - len(seq_data_cws)} sentences")

    return seq_data_cws, data_size, error_data


def ltp_debug(ltp, op):
    output = []
    for data in op:
        data_out = ltp.pipeline(data, tasks=["cws"])
        # data_out_ = reduce(lambda x, y: x.extend(y) or x, data_out.cws)
        data_out_ = []
        for i in data_out.cws:
            # if whitespace is kept, it has to be stripped manually
            data_out_.append(i[0].strip())
            # previously there was no whitespace
            # data_out_.extend(i)
        output.append(data_out_)
    return output


def deal_sub_words(subwords, special_token):
    '''
    Purpose: within each whole word, prefix the non-initial sub-tokens with '##'; special tokens should not be masked
    '''
    for i in range(len(subwords)):
        if i == 0:
            continue
        if subwords[i] in special_token:
            continue
        if subwords[i].startswith("##"):
            continue

        subwords[i] = "##" + subwords[i]
    return subwords


def generate_chinese_ref(seq_data_cws, special_token, deal_numeric, kpi_dic):
    '''
    Inputs:
        seq_data_cws: segmented output for all sequences, e.g. [['KPI', '异常', '下降'], ['KPI', '异常', '上升']]
        special_token: tokens that must not be masked: ['[SEP]', '[MASK]', '[ALM]', '[KPI]', '[CLS]', '[LOC]', '[EOS]', '[ENT]', '[ATTR]', '[NUM]', '|']
        data_size: number of samples, e.g. 2
    Output:
        ww_return (whole word return): the tagged chinese ref, e.g. [['KPI', '异', '##常', '下', '##降'], ['KPI', '异', '##常', '上', '##升']]
    '''
    # a global set and a reverse dict to track which KPIs end up uncovered
    data_size = len(seq_data_cws)
    kpi_static_set = set()
    rev_kpi_dic = dict(zip(kpi_dic.values(), kpi_dic.keys()))
    max_len = 0
    sten_that_over_maxl = []
    with tqdm(total=data_size) as _tqdm:
        ww_return = []
        ww_list = []
        kpi_info = []
        not_in_KPI = defaultdict(int)
        for i in range(data_size):
            _tqdm.set_description(f'checking...[{i}/{data_size}] max len: [120,888]')
            orig = tokenizer.tokenize(" ".join(seq_data_cws[i]))

            if deal_numeric:
                # get the tuple info; the first two entries are the KPI index range
                _kpi_info, kpi_type_list = extract_kpi(orig, kpi_dic, not_in_KPI)
                kpi_info.append(_kpi_info)
                kpi_static_set.update(kpi_type_list)

            sub_total = []
            ww_seq_tmp = []
            ww_tmp = []
            for sub_data in seq_data_cws[i]:
                sub = tokenizer.tokenize(sub_data)
                sub_total.extend(sub)
                # add the '#' marks inside the whole word
                # input: ['异', '常']
                ref_token = deal_sub_words(sub, special_token)
                # output: ['异', '##常']
                ww_seq_tmp.extend(ref_token)
                ww_tmp.append(ref_token)

            if sub_total != orig:
                print("error in match... ")
                if len(orig) > 512:
                    print("the length is over the max length")
                pdb.set_trace()

            # becomes [[...], [...], [...], ...]
            # ww_return.append(ww_tmp)
            sz_ww_seq = len(ww_seq_tmp)
            # track the maximum length
            max_len = sz_ww_seq if sz_ww_seq > max_len else max_len
            if sz_ww_seq > 500:
                sten_that_over_maxl.append((ww_seq_tmp, sz_ww_seq))

            assert len(sub_total) == sz_ww_seq
            ww_return.append(ww_seq_tmp)
            ww_list.append(ww_tmp)
            # pdb.set_trace()
            _tqdm.update(1)
    # pdb.set_trace()
    if deal_numeric:
        in_kpi = []
        # pdb.set_trace()
        for key in rev_kpi_dic.keys():
            if key in kpi_static_set:
                in_kpi.append(rev_kpi_dic[key])
        if len(in_kpi) < len(rev_kpi_dic):
            print(f"[{len(in_kpi)}] KPI are covered by data: {in_kpi}")
            print(f" [{len(not_in_KPI)}] KPIs could not be matched: {not_in_KPI}")
        else:
            print("all KPI are covered!")
    return ww_return, kpi_info, sten_that_over_maxl


def extract_num(seq_data_cws):
    '''
    Purpose: extract the numeric values from the sequences,
    filtering out nan values at the same time
    '''
    num_ref = []
    seq_data_cws_new = []
    for j in range(len(seq_data_cws)):
        num_index = [i for i, x in enumerate(seq_data_cws[j]) if x == '[NUM]']
        # kpi_score = [float(seq_data_cws[i][index+1]) for index in num_index]
        kpi_score = []
        flag = 1
        for index in num_index:
            # if math.isnan(tmp):
            #     pdb.set_trace()
            try:
                tmp = float(seq_data_cws[j][index + 1])
            except:
                # pdb.set_trace()
                flag = 0
                continue
            if math.isnan(tmp):
                flag = 0
            else:
                kpi_score.append(tmp)

        if len(num_index) > 0:
            for index in reversed(num_index):
                seq_data_cws[j].pop(index + 1)
        if flag == 1:
            num_ref.append(kpi_score)
            seq_data_cws_new.append(seq_data_cws[j])
    return seq_data_cws_new, num_ref


def extract_kpi(token_data, kpi_dic, not_in_KPI):
    '''
    Purpose: extract the [KPI] index ranges and the [NUM] indices from the sequence
    Output format: [(1,2,4),(5,6,7)]
    '''
    kpi_and_num_info = []
    kpi_type = []
    kpi_index = [i for i, x in enumerate(token_data) if x.lower() == '[kpi]']
    num_index = [i for i, x in enumerate(token_data) if x.lower() == '[num]']
    sz = len(kpi_index)
    assert sz == len(num_index)
    for i in range(sz):
        # (KPI start, KPI end, position of the NUM token)
        # DONE: add the KPI category
        kpi_name = ''.join(token_data[kpi_index[i] + 1: num_index[i] - 1])
        kpi_name_clear = kpi_name.replace('##', '')

        if kpi_name in kpi_dic:
            kpi_id = int(kpi_dic[kpi_name])
        elif kpi_name_clear in kpi_dic:
            kpi_id = int(kpi_dic[kpi_name_clear])
        elif kpi_name_clear in not_in_KPI:
            kpi_id = -1
            not_in_KPI[kpi_name_clear] += 1
        else:
            # only report once
            not_in_KPI[kpi_name_clear] += 1
            kpi_id = -1
            # print(f"{kpi_name_clear} not in KPI dict")

        kpi_info = [kpi_index[i] + 1, num_index[i] - 2, num_index[i], kpi_id]
        kpi_and_num_info.append(kpi_info)
        kpi_type.append(kpi_id)
        # pdb.set_trace()

    return kpi_and_num_info, kpi_type


def kpi_combine(kpi_info, num_ref):
    sz = len(kpi_info)
    assert sz == len(num_ref)
    for i in range(sz):
        for j in range(len(kpi_info[i])):
            kpi_info[i][j].append(num_ref[i][j])
    # pdb.set_trace()
    return kpi_info

# lowercase all letters


def kpi_lower_update(kpi_dic):
    new_dic = {}
    for key in kpi_dic:
        kk = key.lower().split()
        kk = ''.join(kk).strip()
        new_dic[kk] = kpi_dic[key]
    return new_dic


if __name__ == '__main__':
    '''
    Purpose: produce the chinese ref file and refresh the train/test files (sequence text data only)
    '''
    cfg = cfg()
    cfg.get_args()
    cfgs = cfg.update_train_configs()

    # path setup
    domain_file_path = osp.join(cfgs.data_path, 'special_vocab.txt')
    with open(domain_file_path, encoding="utf-8") as f:
        ref = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
    tokenizer = BertTokenizer.from_pretrained(osp.join(cfgs.data_root, 'transformer', 'MacBert'), do_lower_case=True)
    seq_data_name = cfgs.seq_data_name
    with open(osp.join(cfgs.data_path, f'{seq_data_name}.json'), "r") as fp:
        seq_data = json.load(fp)
    kpi_dic_name = 'kpi2id'
    with open(osp.join(cfgs.data_path, f'{kpi_dic_name}.json'), "r") as fp:
        kpi_dic = json.load(fp)
    kpi_dic = kpi_lower_update(kpi_dic)
    # for testing
    random.shuffle(seq_data)
    # seq_data = seq_data[:500]
    print(f"tokenizer size before: {len(tokenizer)}")
    tokenizer, special_token, norm_token = add_special_token(tokenizer)
    special_token = special_token + norm_token

    print(f"tokenizer size after: {len(tokenizer)}")
    print('------------------------ refresh data --------------------------------')
    add_words = refresh_data(ref, cfgs.freq, special_token)

    if not cfgs.read_cws:
        print('------------------------ cws ----------------------------------')
        seq_data_cws, data_size, error_data = cws(seq_data, add_words, cfgs.batch_size)
        print(f'batch size is {cfgs.batch_size}')
        if len(error_data) > 0:
            with open(osp.join(cfgs.data_path, f'{seq_data_name}_error.json'), "w") as fp:
                json.dump(error_data, fp, ensure_ascii=False)
        save_path_cws_orig = osp.join(cfgs.data_path, f'{seq_data_name}_cws_orig.json')
        print("get the new training data! saving...")
        with open(save_path_cws_orig, 'w', ) as fp:
            json.dump(seq_data_cws, fp, ensure_ascii=False)
    else:
        print('------------------------ read ----------------------------------')
        save_path_cws = osp.join(cfgs.data_path, f'{seq_data_name}_cws_orig.json')
        print("get the new training data!")
        with open(save_path_cws, 'r', ) as fp:
            seq_data_cws = json.load(fp)
        data_size = len(seq_data_cws)

    sz_orig = len(seq_data_cws)
    if cfgs.deal_numeric:
        seq_data_cws, num_ref = extract_num(seq_data_cws)
        print(f"Filtered out {sz_orig - len(seq_data_cws)} invalid sentences")
        data_size = len(seq_data_cws)

    print('---------------------- generate chinese ref ------------------------------')
    chinese_ref, kpi_info, sten_that_over_maxl = generate_chinese_ref(seq_data_cws, special_token, cfgs.deal_numeric, kpi_dic)

    if len(sten_that_over_maxl) > 0:
        print(f"{len(sten_that_over_maxl)} over the 500 len!")
        save_path_max = osp.join(cfgs.data_path, f'{seq_data_name}_max_len_500.json')
        with open(save_path_max, 'w') as fp:
            json.dump(sten_that_over_maxl, fp, ensure_ascii=False)

    if cfgs.deal_numeric:
        print("KPI info combine")
        kpi_ref = kpi_combine(kpi_info, num_ref)
    # pdb.set_trace()
    print('------------------------- match finished ------------------------------')

    # write out the segmentation used for wwm at training time
    save_path_ref = osp.join(cfgs.data_path, f'{seq_data_name}_chinese_ref.json')
    with open(save_path_ref, 'w') as fp:
        json.dump(chinese_ref, fp, ensure_ascii=False)
    print(f"save chinese_ref done!")

    seq_data_cws_output = []
    for i in range(data_size):
        seq = " ".join(seq_data_cws[i])
        seq_data_cws_output.append(seq)

    save_path_cws = osp.join(cfgs.data_path, f'{seq_data_name}_cws.json')
    print("get the new training data!")
    with open(save_path_cws, 'w', ) as fp:
        json.dump(seq_data_cws_output, fp, ensure_ascii=False)

    print("save seq_data_cws done!")

    if cfgs.deal_numeric:
        kpi_ref_path = osp.join(cfgs.data_path, f'{seq_data_name}_kpi_ref.json')
        with open(kpi_ref_path, 'w', ) as fp:
            json.dump(kpi_ref, fp, ensure_ascii=False)
        print("save num and kpi done!")
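As a quick illustration of the chinese_ref format produced above, the '##' prefixing can be exercised on its own. This is a standalone sketch: deal_sub_words is copied from get_chinese_ref.py, while the sample tokens and the special-token list are made up for the example.

# Standalone sketch of the whole-word-masking reference: non-initial sub-tokens of a
# segmented word receive a '##' prefix, while special tokens are left untouched.
def deal_sub_words(subwords, special_token):
    for i in range(len(subwords)):
        if i == 0 or subwords[i] in special_token or subwords[i].startswith("##"):
            continue
        subwords[i] = "##" + subwords[i]
    return subwords

special = ['[kpi]', '[num]']                   # illustrative special tokens (lowercased)
print(deal_sub_words(['异', '常'], special))    # -> ['异', '##常']
print(deal_sub_words(['[kpi]'], special))       # -> ['[kpi]'] (special tokens are never prefixed)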
KTeleBERT/main.py
ADDED
@@ -0,0 +1,851 @@
1 |
+
import os
|
2 |
+
import os.path as osp
|
3 |
+
import torch
|
4 |
+
from torch.utils.tensorboard import SummaryWriter
|
5 |
+
from torch.utils.data import DataLoader, RandomSampler
|
6 |
+
from torch.cuda.amp import GradScaler, autocast
|
7 |
+
from datetime import datetime
|
8 |
+
from easydict import EasyDict as edict
|
9 |
+
from tqdm import tqdm
|
10 |
+
import pdb
|
11 |
+
import pprint
|
12 |
+
import json
|
13 |
+
import pickle
|
14 |
+
from collections import defaultdict
|
15 |
+
import copy
|
16 |
+
from time import time
|
17 |
+
|
18 |
+
from config import cfg
|
19 |
+
from torchlight import initialize_exp, set_seed, get_dump_path
|
20 |
+
from src.data import load_data, load_data_kg, Collator_base, Collator_kg, SeqDataset, KGDataset, Collator_order, load_order_data
|
21 |
+
from src.utils import set_optim, Loss_log, add_special_token, time_trans
|
22 |
+
from src.distributed_utils import init_distributed_mode, dist_pdb, is_main_process, reduce_value, cleanup
|
23 |
+
import torch.distributed as dist
|
24 |
+
|
25 |
+
from itertools import cycle
|
26 |
+
from model import BertTokenizer, HWBert, KGEModel, OD_model, KE_model
|
27 |
+
import torch.multiprocessing
|
28 |
+
from torch.nn.parallel import DistributedDataParallel
|
29 |
+
|
30 |
+
# 默认用cuda就行
|
31 |
+
|
32 |
+
|
33 |
+
class Runner:
|
34 |
+
def __init__(self, args, writer=None, logger=None, rank=0):
|
35 |
+
self.datapath = edict()
|
36 |
+
self.datapath.log_dir = get_dump_path(args)
|
37 |
+
self.datapath.model_dir = os.path.join(self.datapath.log_dir, 'model')
|
38 |
+
self.rank = rank
|
39 |
+
# init code
|
40 |
+
self.mlm_probability = args.mlm_probability
|
41 |
+
self.args = args
|
42 |
+
self.writer = writer
|
43 |
+
self.logger = logger
|
44 |
+
# 模型选择
|
45 |
+
self.model_list = []
|
46 |
+
self.model = HWBert(self.args)
|
47 |
+
# 数据加载。添加special_token,同时把模型的embedding layer进行resize
|
48 |
+
self.data_init()
|
49 |
+
self.model.cuda()
|
50 |
+
# 模型加载
|
51 |
+
self.od_model, self.ke_model = None, None
|
52 |
+
self.scaler = GradScaler()
|
53 |
+
|
54 |
+
# 只要不是第一种训练策略就有新模型
|
55 |
+
if self.args.train_strategy >= 2:
|
56 |
+
self.ke_model = KE_model(self.args)
|
57 |
+
if self.args.train_strategy >= 3:
|
58 |
+
# TODO: 异常检测
|
59 |
+
pass
|
60 |
+
if self.args.train_strategy >= 4:
|
61 |
+
self.od_model = OD_model(self.args)
|
62 |
+
|
63 |
+
if self.args.model_name not in ['MacBert', 'TeleBert', 'TeleBert2', 'TeleBert3'] and not self.args.from_pretrain:
|
64 |
+
# 如果不存在模型会直接返回None或者原始模型
|
65 |
+
self.model = self._load_model(self.model, self.args.model_name)
|
66 |
+
self.od_model = self._load_model(self.od_model, f"od_{self.args.model_name}")
|
67 |
+
self.ke_model = self._load_model(self.ke_model, f"ke_{self.args.model_name}")
|
68 |
+
# TODO: 异常检测
|
69 |
+
|
70 |
+
# 测试的情况
|
71 |
+
if self.args.only_test:
|
72 |
+
self.dataloader_init(self.seq_test_set)
|
73 |
+
else:
|
74 |
+
# 训练
|
75 |
+
if self.args.ernie_stratege > 0:
|
76 |
+
self.args.mask_stratege = 'rand'
|
77 |
+
# 初始化dataloader
|
78 |
+
self.dataloader_init(self.seq_train_set, self.kg_train_set, self.order_train_set)
|
79 |
+
if self.args.dist:
|
80 |
+
# 并行训练需要权值共享
|
81 |
+
self.model_sync()
|
82 |
+
else:
|
83 |
+
self.model_list = [model for model in [self.model, self.od_model, self.ke_model] if model is not None]
|
84 |
+
|
85 |
+
self.optim_init(self.args)
|
86 |
+
|
87 |
+
def model_sync(self):
|
88 |
+
checkpoint_path = osp.join(self.args.data_path, "tmp", "initial_weights.pt")
|
89 |
+
checkpoint_path_od = osp.join(self.args.data_path, "tmp", "initial_weights_od.pt")
|
90 |
+
checkpoint_path_ke = osp.join(self.args.data_path, "tmp", "initial_weights_ke.pt")
|
91 |
+
if self.rank == 0:
|
92 |
+
torch.save(self.model.state_dict(), checkpoint_path)
|
93 |
+
if self.od_model is not None:
|
94 |
+
torch.save(self.od_model.state_dict(), checkpoint_path_od)
|
95 |
+
if self.ke_model is not None:
|
96 |
+
torch.save(self.ke_model.state_dict(), checkpoint_path_ke)
|
97 |
+
dist.barrier()
|
98 |
+
|
99 |
+
# if self.rank != 0:
|
100 |
+
# 这里注意,一定要指定map_location参数,否则会导致第一块GPU占用更多资源
|
101 |
+
self.model = self._model_sync(self.model, checkpoint_path)
|
102 |
+
if self.od_model is not None:
|
103 |
+
self.od_model = self._model_sync(self.od_model, checkpoint_path_od)
|
104 |
+
if self.ke_model is not None:
|
105 |
+
self.ke_model = self._model_sync(self.ke_model, checkpoint_path_ke)
|
106 |
+
|
107 |
+
def _model_sync(self, model, checkpoint_path):
|
108 |
+
model.load_state_dict(torch.load(checkpoint_path, map_location=self.args.device))
|
109 |
+
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(self.args.device)
|
110 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[self.args.gpu], find_unused_parameters=True)
|
111 |
+
self.model_list.append(model)
|
112 |
+
model = model.module
|
113 |
+
return model
|
114 |
+
|
115 |
+
def optim_init(self, opt, total_step=None, accumulation_step=None):
|
116 |
+
step_per_epoch = len(self.train_dataloader)
|
117 |
+
# 占总step 10% 的warmup_steps
|
118 |
+
opt.total_steps = int(step_per_epoch * opt.epoch) if total_step is None else int(total_step)
|
119 |
+
opt.warmup_steps = int(opt.total_steps * 0.15)
|
120 |
+
|
121 |
+
if self.rank == 0 and total_step is None:
|
122 |
+
self.logger.info(f"warmup_steps: {opt.warmup_steps}")
|
123 |
+
self.logger.info(f"total_steps: {opt.total_steps}")
|
124 |
+
self.logger.info(f"weight_decay: {opt.weight_decay}")
|
125 |
+
|
126 |
+
freeze_part = ['bert.encoder.layer.1.', 'bert.encoder.layer.2.', 'bert.encoder.layer.3.', 'bert.encoder.layer.4.'][:self.args.freeze_layer]
|
127 |
+
self.optimizer, self.scheduler = set_optim(opt, self.model_list, freeze_part, accumulation_step)
|
128 |
+
|
129 |
+
def data_init(self):
|
130 |
+
# 载入数据, 两部分数据包括:载入mask loss部分的数据(序列化的数据) 和 载入triple loss部分的数据(三元组)
|
131 |
+
# train_test_split: 训练集长度
|
132 |
+
self.seq_train_set, self.seq_test_set, self.kg_train_set, self.kg_data = None, None, None, None
|
133 |
+
self.order_train_set, self.order_test_set = None, None
|
134 |
+
|
135 |
+
if self.args.train_strategy >= 1 and self.args.train_strategy <= 4:
|
136 |
+
# 预训练 or multi task pretrain
|
137 |
+
self.seq_train_set, self.seq_test_set, train_test_split = load_data(self.logger, self.args)
|
138 |
+
if self.args.train_strategy >= 2:
|
139 |
+
self.kg_train_set, self.kg_data = load_data_kg(self.logger, self.args)
|
140 |
+
if self.args.train_strategy >= 3:
|
141 |
+
# TODO: 异常检测的数据载入
|
142 |
+
pass
|
143 |
+
if self.args.train_strategy >= 4:
|
144 |
+
self.order_train_set, self.order_test_set, train_test_split = load_order_data(self.logger, self.args)
|
145 |
+
|
146 |
+
if self.args.dist and not self.args.only_test:
|
147 |
+
# 测试不需要并行
|
148 |
+
if self.args.train_strategy >= 1 and self.args.train_strategy <= 4:
|
149 |
+
self.seq_train_sampler = torch.utils.data.distributed.DistributedSampler(self.seq_train_set)
|
150 |
+
if self.args.train_strategy >= 2:
|
151 |
+
self.kg_train_sampler = torch.utils.data.distributed.DistributedSampler(self.kg_train_set)
|
152 |
+
if self.args.train_strategy >= 3:
|
153 |
+
# TODO: 异常检测的数据载入
|
154 |
+
pass
|
155 |
+
if self.args.train_strategy >= 4:
|
156 |
+
self.order_train_sampler = torch.utils.data.distributed.DistributedSampler(self.order_train_set)
|
157 |
+
|
158 |
+
# self.seq_train_batch_sampler = torch.utils.data.BatchSampler(self.seq_train_sampler, self.args.batch_size, drop_last=True)
|
159 |
+
# self.kg_train_batch_sampler = torch.utils.data.BatchSampler(self.kg_train_sampler, int(self.args.batch_size / 4), drop_last=True)
|
160 |
+
|
161 |
+
# Tokenizer 载入
|
162 |
+
model_name = self.args.model_name
|
163 |
+
if self.args.model_name in ['TeleBert', 'TeleBert2', 'TeleBert3']:
|
164 |
+
self.tokenizer = BertTokenizer.from_pretrained(osp.join(self.args.data_root, 'transformer', model_name), do_lower_case=True)
|
165 |
+
else:
|
166 |
+
if not osp.exists(osp.join(self.args.data_root, 'transformer', self.args.model_name)):
|
167 |
+
model_name = 'MacBert'
|
168 |
+
self.tokenizer = BertTokenizer.from_pretrained(osp.join(self.args.data_root, 'transformer', model_name), do_lower_case=True)
|
169 |
+
|
170 |
+
# 添加special_token,同时把模型的embedding layer进行resize
|
171 |
+
self.special_token = None
|
172 |
+
# 单纯的telebert在测试时不需要特殊embedding
|
173 |
+
if self.args.add_special_word and not (self.args.only_test and self.args.model_name in ['MacBert', 'TeleBert', 'TeleBert2', 'TeleBert3']):
|
174 |
+
# tokenizer, special_token, norm_token
|
175 |
+
# special_token 不应该被MASK
|
176 |
+
self.tokenizer, special_token, _ = add_special_token(self.tokenizer, model=self.model.encoder, rank=self.rank, cache_path=self.args.specail_emb_path)
|
177 |
+
# pdb.set_trace()
|
178 |
+
self.special_token = [token.lower() for token in special_token]
|
179 |
+
|
180 |
+
def _dataloader_dist(self, train_set, train_sampler, batch_size, collator):
|
181 |
+
train_dataloader = DataLoader(
|
182 |
+
train_set,
|
183 |
+
sampler=train_sampler,
|
184 |
+
pin_memory=True,
|
185 |
+
num_workers=self.args.workers,
|
186 |
+
persistent_workers=True,
|
187 |
+
drop_last=True,
|
188 |
+
batch_size=batch_size,
|
189 |
+
collate_fn=collator
|
190 |
+
)
|
191 |
+
return train_dataloader
|
192 |
+
|
193 |
+
def _dataloader(self, train_set, batch_size, collator):
|
194 |
+
train_dataloader = DataLoader(
|
195 |
+
train_set,
|
196 |
+
num_workers=self.args.workers,
|
197 |
+
persistent_workers=True,
|
198 |
+
shuffle=(self.args.only_test == 0),
|
199 |
+
drop_last=(self.args.only_test == 0),
|
200 |
+
batch_size=batch_size,
|
201 |
+
collate_fn=collator
|
202 |
+
)
|
203 |
+
return train_dataloader
|
204 |
+
|
205 |
+
def dataloader_init(self, train_set=None, kg_train_set=None, order_train_set=None):
|
206 |
+
bs = self.args.batch_size
|
207 |
+
bs_ke = self.args.batch_size_ke
|
208 |
+
bs_od = self.args.batch_size_od
|
209 |
+
bs_ad = self.args.batch_size_ad
|
210 |
+
# 分布式
|
211 |
+
if self.args.dist and not self.args.only_test:
|
212 |
+
self.args.workers = min([os.cpu_count(), self.args.batch_size, self.args.workers])
|
213 |
+
# if self.rank == 0:
|
214 |
+
# print(f'Using {self.args.workers} dataloader workers every process')
|
215 |
+
|
216 |
+
if train_set is not None:
|
217 |
+
seq_collator = Collator_base(self.args, tokenizer=self.tokenizer, special_token=self.special_token)
|
218 |
+
self.train_dataloader = self._dataloader_dist(train_set, self.seq_train_sampler, bs, seq_collator)
|
219 |
+
if kg_train_set is not None:
|
220 |
+
kg_collator = Collator_kg(self.args, tokenizer=self.tokenizer, data=self.kg_data)
|
221 |
+
self.train_dataloader_kg = self._dataloader_dist(kg_train_set, self.kg_train_sampler, bs_ke, kg_collator)
|
222 |
+
if order_train_set is not None:
|
223 |
+
order_collator = Collator_order(self.args, tokenizer=self.tokenizer)
|
224 |
+
self.train_dataloader_order = self._dataloader_dist(order_train_set, self.order_train_sampler, bs_od, order_collator)
|
225 |
+
else:
|
226 |
+
if train_set is not None:
|
227 |
+
seq_collator = Collator_base(self.args, tokenizer=self.tokenizer, special_token=self.special_token)
|
228 |
+
self.train_dataloader = self._dataloader(train_set, bs, seq_collator)
|
229 |
+
if kg_train_set is not None:
|
230 |
+
kg_collator = Collator_kg(self.args, tokenizer=self.tokenizer, data=self.kg_data)
|
231 |
+
self.train_dataloader_kg = self._dataloader(kg_train_set, bs_ke, kg_collator)
|
232 |
+
if order_train_set is not None:
|
233 |
+
order_collator = Collator_order(self.args, tokenizer=self.tokenizer)
|
234 |
+
self.train_dataloader_order = self._dataloader(order_train_set, bs_od, order_collator)
|
235 |
+
|
236 |
+
def dist_step(self, task=0):
|
237 |
+
# 分布式训练需要额外step
|
238 |
+
if self.args.dist:
|
239 |
+
if task == 0:
|
240 |
+
self.seq_train_sampler.set_epoch(self.dist_epoch)
|
241 |
+
if task == 1:
|
242 |
+
self.kg_train_sampler.set_epoch(self.dist_epoch)
|
243 |
+
if task == 2:
|
244 |
+
# TODO:异常检测
|
245 |
+
pass
|
246 |
+
if task == 3:
|
247 |
+
self.order_train_sampler.set_epoch(self.dist_epoch)
|
248 |
+
self.dist_epoch += 1
|
249 |
+
|
250 |
+
def mask_rate_update(self, i):
|
251 |
+
# 这种策略是曲线地增加 mask rate
|
252 |
+
if self.args.mlm_probability_increase == "curve":
|
253 |
+
self.args.mlm_probability += (i + 1) * ((self.args.final_mlm_probability - self.args.mlm_probability) / self.args.epoch)
|
254 |
+
# 这种是线性的
|
255 |
+
else:
|
256 |
+
assert self.args.mlm_probability_increase == "linear"
|
257 |
+
self.args.mlm_probability += (self.args.final_mlm_probability - self.mlm_probability) / self.args.epoch
|
258 |
+
|
259 |
+
if self.rank == 0:
|
260 |
+
self.logger.info(f"Moving Mlm_probability in next epoch to: {self.args.mlm_probability*100}%")
|
261 |
+
|
262 |
+
def task_switch(self, training_strategy):
|
263 |
+
# 同时训练或者策略1训练不需要切换任务,epoch也安装初始epoch就行
|
264 |
+
if training_strategy == 1 or self.args.train_together:
|
265 |
+
return (0, 0), None
|
266 |
+
|
267 |
+
# 4 阶段
|
268 |
+
# self.total_epoch -= 1
|
269 |
+
|
270 |
+
for i in range(4):
|
271 |
+
for task in range(training_strategy):
|
272 |
+
if self.args.epoch_matrix[task][i] > 0:
|
273 |
+
self.args.epoch_matrix[task][i] -= 1
|
274 |
+
return (task, i), self.args.epoch_matrix[task][i] + 1
|
275 |
+
|
276 |
+
def run(self):
|
277 |
+
self.loss_log = Loss_log()
|
278 |
+
self.curr_loss = 0.
|
279 |
+
self.lr = self.args.lr
|
280 |
+
self.curr_loss_dic = defaultdict(float)
|
281 |
+
self.curr_kpi_loss_dic = defaultdict(float)
|
282 |
+
self.loss_weight = [1, 1]
|
283 |
+
self.kpi_loss_weight = [1, 1]
|
284 |
+
self.step = 0
|
285 |
+
# 不同task 的累计step
|
286 |
+
self.total_step_sum = 0
|
287 |
+
task_last = 0
|
288 |
+
stage_last = 0
|
289 |
+
self.dist_epoch = 0
|
290 |
+
# 后面可以变成混合训练模式
|
291 |
+
# self.total_epoch = self.args.epoch
|
292 |
+
# --------- train -------------
|
293 |
+
with tqdm(total=self.args.epoch) as _tqdm: # 使用需要的参数对tqdm进行初始化
|
294 |
+
for i in range(self.args.epoch):
|
295 |
+
# 切换Task
|
296 |
+
(task, stage), task_epoch = self.task_switch(self.args.train_strategy)
|
297 |
+
self.dist_step(task)
|
298 |
+
dataloader = self.task_dataloader_choose(task)
|
299 |
+
# 并行
|
300 |
+
if self.args.train_together and self.args.train_strategy > 1:
|
301 |
+
self.dataloader_list = ['#']
|
302 |
+
# 一个list 存下所有需要的dataloader的迭代
|
303 |
+
for t in range(1, self.args.train_strategy):
|
304 |
+
self.dist_step(t)
|
305 |
+
self.dataloader_list.append(iter(self.task_dataloader_choose(t)))
|
306 |
+
|
307 |
+
if task != task_last or stage != stage_last:
|
308 |
+
self.step = 0
|
309 |
+
if self.rank == 0:
|
310 |
+
print(f"switch to task [{task}] in stage [{stage}]...")
|
311 |
+
if stage != stage_last:
|
312 |
+
# 每一个阶段结束保存一次
|
313 |
+
self._save_model(stage=f'_stg{stage_last}')
|
314 |
+
# task 转换状态时需要重新初始化优化器
|
315 |
+
# 并行训练或者单一task (0) 训练不需要切换opti
|
316 |
+
if task_epoch is not None:
|
317 |
+
self.optim_init(self.args, total_step=len(dataloader) * task_epoch, accumulation_step=self.args.accumulation_steps_dict[task])
|
318 |
+
task_last = task
|
319 |
+
stage_last = stage
|
320 |
+
|
321 |
+
# 调整学习阶段
|
322 |
+
if task == 0 and self.args.ernie_stratege > 0 and i >= self.args.ernie_stratege:
|
323 |
+
# 不会再触发第二次
|
324 |
+
self.args.ernie_stratege = 10000000
|
325 |
+
if self.rank == 0:
|
326 |
+
self.logger.info("switch to wwm stratege...")
|
327 |
+
self.args.mask_stratege = 'wwm'
|
328 |
+
|
329 |
+
if self.args.mlm_probability != self.args.final_mlm_probability:
|
330 |
+
# 更新 MASK rate
|
331 |
+
# 初始化训练数据, 可以随epoch切换
|
332 |
+
# 混合训练
|
333 |
+
self.mask_rate_update(i)
|
334 |
+
self.dataloader_init(self.seq_train_set, self.kg_train_set, self.order_train_set)
|
335 |
+
# -------------------------------
|
336 |
+
# 针对task 进行训练
|
337 |
+
self.train(_tqdm, dataloader, task)
|
338 |
+
# -------------------------------
|
339 |
+
_tqdm.update(1)
|
340 |
+
|
341 |
+
# DONE: save or load
|
342 |
+
if self.rank == 0:
|
343 |
+
self.logger.info(f"min loss {self.loss_log.get_min_loss()}")
|
344 |
+
# DONE: save or load
|
345 |
+
if not self.args.only_test and self.args.save_model:
|
346 |
+
self._save_model()
|
347 |
+
|
348 |
+
def task_dataloader_choose(self, task):
|
349 |
+
self.model.train()
|
350 |
+
# joint training just uses the base dataloader
|
351 |
+
if task == 0:
|
352 |
+
dataloader = self.train_dataloader
|
353 |
+
elif task == 1:
|
354 |
+
self.ke_model.train()
|
355 |
+
dataloader = self.train_dataloader_kg
|
356 |
+
elif task == 2:
|
357 |
+
pass
|
358 |
+
elif task == 3:
|
359 |
+
self.od_model.train()
|
360 |
+
dataloader = self.train_dataloader_order
|
361 |
+
return dataloader
|
362 |
+
# one time train
|
363 |
+
|
364 |
+
def loss_output(self, batch, task):
|
365 |
+
# -------- model forward: loss --------
|
366 |
+
if task == 0:
|
367 |
+
# forward pass
|
368 |
+
_output = self.model(batch)
|
369 |
+
loss = _output['loss']
|
370 |
+
elif task == 1:
|
371 |
+
loss = self.ke_model(batch, self.model)
|
372 |
+
elif task == 2:
|
373 |
+
pass
|
374 |
+
elif task == 3:
|
375 |
+
# TODO: adapt accumulation_steps per task for multi-task fine-tuning
|
376 |
+
# OD task
|
377 |
+
emb = self.model.cls_embedding(batch[0], tp=self.args.plm_emb_type)
|
378 |
+
loss, loss_dic = self.od_model(emb, batch[1].cuda())
|
379 |
+
order_score = self.od_model.predict(emb)
|
380 |
+
token_right = self.od_model.right_caculate(order_score, batch[1], threshold=0.5)
|
381 |
+
self.loss_log.update_token(batch[1].shape[0], [token_right])
|
382 |
+
return loss
|
383 |
+
|
384 |
+
def train(self, _tqdm, dataloader, task=0):
|
385 |
+
# cycle train
|
386 |
+
loss_weight, kpi_loss_weight, kpi_loss_dict, _output = None, None, None, None
|
387 |
+
# dataloader = zip(self.train_dataloader, cycle(self.train_dataloader_kg))
|
388 |
+
self.loss_log.acc_init()
|
389 |
+
# if self.train_dataloader is longer than self.train_dataloader_kg, the latter will not be fully trained
|
390 |
+
accumulation_steps = self.args.accumulation_steps_dict[task]
|
391 |
+
torch.cuda.empty_cache()
|
392 |
+
|
393 |
+
for batch in dataloader:
|
394 |
+
# with autocast():
|
395 |
+
loss = self.args.mask_loss_scale * self.loss_output(batch, task)
|
396 |
+
# for joint training, fetch the other tasks' batches through their iterators
|
397 |
+
if self.args.train_together and self.args.train_strategy > 1:
|
398 |
+
for t in range(1, self.args.train_strategy):
|
399 |
+
try:
|
400 |
+
batch = next(self.dataloader_list[t])
|
401 |
+
except StopIteration:
|
402 |
+
self.dist_step(t)
|
403 |
+
self.dataloader_list[t] = iter(self.task_dataloader_choose(t))
|
404 |
+
batch = next(self.dataloader_list[t])
|
405 |
+
# get the loss from the corresponding model
|
406 |
+
# torch.cuda.empty_cache()
|
407 |
+
loss += self.loss_output(batch, t)
|
408 |
+
# torch.cuda.empty_cache()
|
409 |
+
loss = loss / accumulation_steps
|
410 |
+
self.scaler.scale(loss).backward()
|
411 |
+
# loss.backward()
|
412 |
+
if self.args.dist:
|
413 |
+
loss = reduce_value(loss, average=True)
|
414 |
+
# torch.cuda.empty_cache()
|
415 |
+
self.step += 1
|
416 |
+
self.total_step_sum += 1
|
417 |
+
|
418 |
+
# -------- statistics --------
|
419 |
+
if not self.args.dist or is_main_process():
|
420 |
+
self.output_statistic(loss, _output)
|
421 |
+
acc_descrip = f"Acc: {self.loss_log.get_token_acc()}" if self.loss_log.get_token_acc() > 0 else ""
|
422 |
+
_tqdm.set_description(f'Train | step [{self.step}/{self.args.total_steps}] {acc_descrip} LR [{self.lr}] Loss {self.loss_log.get_loss():.5f} ')
|
423 |
+
if self.step % self.args.eval_step == 0 and self.step > 0:
|
424 |
+
self.loss_log.update(self.curr_loss)
|
425 |
+
self.update_loss_log()
|
426 |
+
# -------- gradient accumulation and model update --------
|
427 |
+
if self.step % accumulation_steps == 0 and self.step > 0:
|
428 |
+
# update the optimizer
|
429 |
+
self.scaler.unscale_(self.optimizer)
|
430 |
+
for model in self.model_list:
|
431 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.clip)
|
432 |
+
|
433 |
+
# self.optimizer.step()
|
434 |
+
scale = self.scaler.get_scale()
|
435 |
+
self.scaler.step(self.optimizer)
|
436 |
+
|
437 |
+
self.scaler.update()
|
438 |
+
skip_lr_sched = (scale > self.scaler.get_scale())
|
439 |
+
if not skip_lr_sched:
|
440 |
+
# pdb.set_trace()
|
441 |
+
self.scheduler.step()
|
442 |
+
|
443 |
+
if not self.args.dist or is_main_process():
|
444 |
+
# pdb.set_trace()
|
445 |
+
self.lr = self.scheduler.get_last_lr()[-1]
|
446 |
+
self.writer.add_scalars("lr", {"lr": self.lr}, self.total_step_sum)
|
447 |
+
# model update: reset the gradients
|
448 |
+
for model in self.model_list:
|
449 |
+
model.zero_grad(set_to_none=True)
|
450 |
+
|
451 |
+
if self.args.dist:
|
452 |
+
torch.cuda.synchronize(self.args.device)
|
453 |
+
return self.curr_loss, self.curr_loss_dic
|
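# --- Hedged sketch (added, not part of the original file): the AMP + gradient-accumulation
# --- pattern used in train() above, reduced to its core. The names (model, optimizer,
# --- scheduler, dataloader, accumulation_steps) are placeholders, not the real attributes.
import torch

def _amp_accumulation_loop(model, optimizer, scheduler, dataloader, accumulation_steps, clip=1.0):
    scaler = torch.cuda.amp.GradScaler()
    for step, batch in enumerate(dataloader, start=1):
        loss = model(batch)['loss'] / accumulation_steps   # scale the loss for accumulation
        scaler.scale(loss).backward()                      # accumulate scaled gradients
        if step % accumulation_steps == 0:
            scaler.unscale_(optimizer)                     # unscale before gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            scale_before = scaler.get_scale()
            scaler.step(optimizer)                         # skipped internally on inf/NaN grads
            scaler.update()
            if scale_before <= scaler.get_scale():         # only advance the LR when the step ran
                scheduler.step()
            model.zero_grad(set_to_none=True)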
454 |
+
|
455 |
+
def output_statistic(self, loss, output):
|
456 |
+
# accumulate the model's various outputs
|
457 |
+
self.curr_loss += loss.item()
|
458 |
+
if output is None:
|
459 |
+
return
|
460 |
+
for key in output['loss_dic'].keys():
|
461 |
+
self.curr_loss_dic[key] += output['loss_dic'][key]
|
462 |
+
if 'kpi_loss_dict' in output and output['kpi_loss_dict'] is not None:
|
463 |
+
for key in output['kpi_loss_dict'].keys():
|
464 |
+
self.curr_kpi_loss_dic[key] += output['kpi_loss_dict'][key]
|
465 |
+
if 'loss_weight' in output and output['loss_weight'] is not None:
|
466 |
+
self.loss_weight = output['loss_weight']
|
467 |
+
# a dict check is needed here
|
468 |
+
if 'kpi_loss_weight' in output and output['kpi_loss_weight'] is not None:
|
469 |
+
self.kpi_loss_weight = output['kpi_loss_weight']
|
470 |
+
|
471 |
+
def update_loss_log(self, task=0):
|
472 |
+
# write the accumulated model outputs to the logs
|
473 |
+
# https://zhuanlan.zhihu.com/p/382950853
|
474 |
+
# "mask_loss": self.curr_loss_dic['mask_loss'], "ke_loss": self.curr_loss_dic['ke_loss']
|
475 |
+
vis_dict = {"train_loss": self.curr_loss}
|
476 |
+
vis_dict.update(self.curr_loss_dic)
|
477 |
+
self.writer.add_scalars("loss", vis_dict, self.total_step_sum)
|
478 |
+
if self.loss_weight is not None:
|
479 |
+
# pre-training
|
480 |
+
loss_weight_dic = {}
|
481 |
+
if self.args.train_strategy == 1:
|
482 |
+
loss_weight_dic["mask"] = 1 / (self.loss_weight[0]**2)
|
483 |
+
if self.args.use_NumEmb:
|
484 |
+
loss_weight_dic["kpi"] = 1 / (self.loss_weight[1]**2)
|
485 |
+
vis_kpi_dic = {"recover": 1 / (self.kpi_loss_weight[0]**2), "classifier": 1 / (self.kpi_loss_weight[1]**2)}
|
486 |
+
if self.args.contrastive_loss and len(self.kpi_loss_weight) > 2:
|
487 |
+
vis_kpi_dic.update({"contrastive": 1 / (self.kpi_loss_weight[2]**2)})
|
488 |
+
self.writer.add_scalars("kpi_loss_weight", vis_kpi_dic, self.total_step_sum)
|
489 |
+
self.writer.add_scalars("kpi_loss", self.curr_kpi_loss_dic, self.total_step_sum)
|
490 |
+
self.writer.add_scalars("loss_weight", loss_weight_dic, self.total_step_sum)
|
491 |
+
# TODO: Finetune
|
492 |
+
|
493 |
+
# init log loss
|
494 |
+
self.curr_loss = 0.
|
495 |
+
for key in self.curr_loss_dic:
|
496 |
+
self.curr_loss_dic[key] = 0.
|
497 |
+
if len(self.curr_kpi_loss_dic) > 0:
|
498 |
+
for key in self.curr_kpi_loss_dic:
|
499 |
+
self.curr_kpi_loss_dic[key] = 0.
|
500 |
+
|
501 |
+
# TODO: fine-tuning stage
|
502 |
+
def eval(self):
|
503 |
+
self.model.eval()
|
504 |
+
torch.cuda.empty_cache()
|
505 |
+
|
506 |
+
def mask_test(self, test_log):
|
507 |
+
# mask testing is impossible if train_ratio >= 1
|
508 |
+
assert self.args.train_ratio < 1
|
509 |
+
topk = (1, 100, 500)
|
510 |
+
test_log.acc_init(topk)
|
511 |
+
# mask prediction needs training mode so that tokens are randomly masked
|
512 |
+
self.args.only_test = 0
|
513 |
+
self.dataloader_init(self.seq_test_set)
|
514 |
+
# pdb.set_trace()
|
515 |
+
sz_test = len(self.train_dataloader)
|
516 |
+
loss_sum = 0
|
517 |
+
with tqdm(total=sz_test) as _tqdm:  # initialize tqdm with the required arguments
|
518 |
+
for step, batch in enumerate(self.train_dataloader):
|
519 |
+
# DONE: mask_prediction implements masked-token prediction
|
520 |
+
with torch.no_grad():
|
521 |
+
token_num, token_right, loss = self.model.mask_prediction(batch, len(self.tokenizer), topk)
|
522 |
+
test_log.update_token(token_num, token_right)
|
523 |
+
loss_sum += loss
|
524 |
+
# test_log.update_word(word_num, word_right)
|
525 |
+
_tqdm.update(1)
|
526 |
+
_tqdm.set_description(f'Test | step [{step}/{sz_test}] Top{topk} Token_Acc: {test_log.get_token_acc()}')
|
527 |
+
print(f"perplexity: {loss_sum}")
|
528 |
+
# restore the test-only flag
|
529 |
+
self.args.only_test = 1
|
530 |
+
# if topk is not None:
|
531 |
+
print(f"Top{topk} acc is {test_log.get_token_acc()}")
|
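# --- Note (added, not part of the original file): loss_sum above is the cross-entropy summed
# --- over batches; a conventional perplexity would be the exponential of the mean loss, e.g.:
# import math
# perplexity = math.exp(loss_sum / max(sz_test, 1))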
532 |
+
|
533 |
+
def emb_generate(self, path_gen):
|
534 |
+
assert len(self.args.path_gen) > 0 or path_gen is not None
|
535 |
+
data_path = self.args.data_path
|
536 |
+
if path_gen is None:
|
537 |
+
path_gen = self.args.path_gen
|
538 |
+
with open(osp.join(data_path, 'downstream_task', f'{path_gen}.json'), "r") as fp:
|
539 |
+
data = json.load(fp)
|
540 |
+
print(f"read file {path_gen} done!")
|
541 |
+
test_set = SeqDataset(data)
|
542 |
+
self.dataloader_init(test_set)
|
543 |
+
sz_test = len(self.train_dataloader)
|
544 |
+
all_emb_dic = defaultdict(list)
|
545 |
+
emb_output = {}
|
546 |
+
all_emb_ent = []
|
547 |
+
# tps = ['cls', 'last_avg', 'last2avg', 'last3avg', 'first_last_avg']
|
548 |
+
tps = ['cls', 'last_avg']
|
549 |
+
# with tqdm(total=sz_test) as _tqdm:
|
550 |
+
for step, batch in enumerate(self.train_dataloader):
|
551 |
+
for tp in tps:
|
552 |
+
with torch.no_grad():
|
553 |
+
batch_embedding = self.model.cls_embedding(batch, tp=tp)
|
554 |
+
# batch_embedding = self.model.cls_embedding(batch, tp=tp)
|
555 |
+
if tp in self.args.model_name and self.ke_model is not None:
|
556 |
+
batch_embedding_ent = self.ke_model.get_embedding(batch_embedding, is_ent=True)
|
557 |
+
# batch_embedding_ent = self.ke_model(batch, self.model)
|
558 |
+
batch_embedding_ent = batch_embedding_ent.cpu()
|
559 |
+
all_emb_ent.append(batch_embedding_ent)
|
560 |
+
|
561 |
+
batch_embedding = batch_embedding.cpu()
|
562 |
+
all_emb_dic[tp].append(batch_embedding)
|
563 |
+
# _tqdm.update(1)
|
564 |
+
# _tqdm.set_description(f'Test | step [{step}/{sz_test}]')
|
565 |
+
torch.cuda.empty_cache()
|
566 |
+
for tp in tps:
|
567 |
+
emb_output[tp] = torch.cat(all_emb_dic[tp])
|
568 |
+
assert emb_output[tp].shape[0] == len(data)
|
569 |
+
if len(all_emb_ent) > 0:
|
570 |
+
emb_output_ent = torch.cat(all_emb_ent)
|
571 |
+
# file-name suffix
|
572 |
+
save_path = osp.join(data_path, 'downstream_task', 'output')
|
573 |
+
os.makedirs(save_path, exist_ok=True)
|
574 |
+
for tp in tps:
|
575 |
+
save_dir = osp.join(save_path, f'{path_gen}_emb_{self.args.model_name.replace("DistributedDataParallel", "")}_{tp}.pt')
|
576 |
+
torch.save(emb_output[tp], save_dir)
|
577 |
+
# trained entity embeddings are available
|
578 |
+
if len(all_emb_ent) > 0:
|
579 |
+
save_dir = osp.join(save_path, f'{path_gen}_emb_{self.args.model_name.replace("DistributedDataParallel", "")}_ent.pt')
|
580 |
+
torch.save(emb_output_ent, save_dir)
|
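# --- Usage sketch (added, not part of the original file): loading the tensors saved by
# --- emb_generate. The path below is illustrative; the real file name depends on
# --- path_gen, model_name, and the pooling type tp.
# import torch
# emb = torch.load('data/downstream_task/output/<path_gen>_emb_<model_name>_cls.pt')
# assert emb.dim() == 2   # [num_samples, hidden_size]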
581 |
+
|
582 |
+
def KGE_test(self):
|
583 |
+
# test KGE directly on the full KG
|
584 |
+
sz_test = len(self.kg_train_set)
|
585 |
+
# first convert the data
|
586 |
+
ent_set = set()
|
587 |
+
rel_set = set()
|
588 |
+
with tqdm(total=sz_test) as _tqdm:  # initialize tqdm with the required arguments
|
589 |
+
_tqdm.set_description('trans entity/relation ID')
|
590 |
+
for batch in self.kg_train_set:
|
591 |
+
ent_set.add(batch[0])
|
592 |
+
ent_set.add(batch[2])
|
593 |
+
rel_set.add(batch[1])
|
594 |
+
_tqdm.update(1)
|
595 |
+
all_ent, all_rel = list(ent_set), list(rel_set)
|
596 |
+
nent, nrel = len(all_ent), len(all_rel)
|
597 |
+
ent_dic, rel_dic = {}, {}
|
598 |
+
for i in range(nent):
|
599 |
+
ent_dic[all_ent[i]] = i
|
600 |
+
for i in range(nrel):
|
601 |
+
rel_dic[all_rel[i]] = i
|
602 |
+
id_format_triple = []
|
603 |
+
with tqdm(total=sz_test) as _tqdm:
|
604 |
+
_tqdm.set_description('trans triple ID')
|
605 |
+
for triple in self.kg_train_set:
|
606 |
+
id_format_triple.append((ent_dic[triple[0]], rel_dic[triple[1]], ent_dic[triple[2]]))
|
607 |
+
_tqdm.update(1)
|
608 |
+
|
609 |
+
# pdb.set_trace()
|
610 |
+
# generate and save the entity embeddings
|
611 |
+
ent_dataset = KGDataset(all_ent)
|
612 |
+
rel_dataset = KGDataset(all_rel)
|
613 |
+
|
614 |
+
ent_dataloader = DataLoader(
|
615 |
+
ent_dataset,
|
616 |
+
batch_size=self.args.batch_size * 32,
|
617 |
+
num_workers=self.args.workers,
|
618 |
+
persistent_workers=True,
|
619 |
+
shuffle=False
|
620 |
+
)
|
621 |
+
rel_dataloader = DataLoader(
|
622 |
+
rel_dataset,
|
623 |
+
batch_size=self.args.batch_size * 32,
|
624 |
+
num_workers=self.args.workers,
|
625 |
+
persistent_workers=True,
|
626 |
+
shuffle=False
|
627 |
+
)
|
628 |
+
|
629 |
+
sz_test = len(ent_dataloader) + len(rel_dataloader)
|
630 |
+
with tqdm(total=sz_test) as _tqdm:
|
631 |
+
ent_emb = []
|
632 |
+
rel_emb = []
|
633 |
+
step = 0
|
634 |
+
_tqdm.set_description('get the ent embedding')
|
635 |
+
with torch.no_grad():
|
636 |
+
for batch in ent_dataloader:
|
637 |
+
batch = self.tokenizer.batch_encode_plus(
|
638 |
+
batch,
|
639 |
+
padding='max_length',
|
640 |
+
max_length=self.args.maxlength,
|
641 |
+
truncation=True,
|
642 |
+
return_tensors="pt",
|
643 |
+
return_token_type_ids=False,
|
644 |
+
return_attention_mask=True,
|
645 |
+
add_special_tokens=False
|
646 |
+
)
|
647 |
+
|
648 |
+
batch_emb = self.model.cls_embedding(batch, tp=self.args.plm_emb_type)
|
649 |
+
batch_emb = self.ke_model.get_embedding(batch_emb, is_ent=True)
|
650 |
+
|
651 |
+
ent_emb.append(batch_emb.cpu())
|
652 |
+
_tqdm.update(1)
|
653 |
+
step += 1
|
654 |
+
torch.cuda.empty_cache()
|
655 |
+
_tqdm.set_description(f'ENT emb: [{step}/{sz_test}]')
|
656 |
+
|
657 |
+
_tqdm.set_description('get the rel embedding')
|
658 |
+
for batch in rel_dataloader:
|
659 |
+
batch = self.tokenizer.batch_encode_plus(
|
660 |
+
batch,
|
661 |
+
padding='max_length',
|
662 |
+
max_length=self.args.maxlength,
|
663 |
+
truncation=True,
|
664 |
+
return_tensors="pt",
|
665 |
+
return_token_type_ids=False,
|
666 |
+
return_attention_mask=True,
|
667 |
+
add_special_tokens=False
|
668 |
+
)
|
669 |
+
batch_emb = self.model.cls_embedding(batch, tp=self.args.plm_emb_type)
|
670 |
+
batch_emb = self.ke_model.get_embedding(batch_emb, is_ent=False)
|
671 |
+
# batch_emb = self.model.get_embedding(batch, is_ent=False)
|
672 |
+
rel_emb.append(batch_emb.cpu())
|
673 |
+
_tqdm.update(1)
|
674 |
+
step += 1
|
675 |
+
torch.cuda.empty_cache()
|
676 |
+
_tqdm.set_description(f'REL emb: [{step}/{sz_test}]')
|
677 |
+
|
678 |
+
all_ent_emb = torch.cat(ent_emb).cuda()
|
679 |
+
all_rel_emb = torch.cat(rel_emb).cuda()
|
680 |
+
# embedding:emb_output
|
681 |
+
# dim 256
|
682 |
+
kge_model_for_test = KGEModel(nentity=len(all_ent), nrelation=len(all_rel), hidden_dim=self.args.ke_dim,
|
683 |
+
gamma=self.args.ke_margin, entity_embedding=all_ent_emb, relation_embedding=all_rel_emb).cuda()
|
684 |
+
if self.args.ke_test_num > 0:
|
685 |
+
test_triples = id_format_triple[:self.args.ke_test_num]
|
686 |
+
else:
|
687 |
+
test_triples = id_format_triple
|
688 |
+
with torch.no_grad():
|
689 |
+
metrics = kge_model_for_test.test_step(test_triples=test_triples, all_true_triples=id_format_triple, args=self.args, nentity=len(all_ent), nrelation=len(all_rel))
|
690 |
+
# pdb.set_trace()
|
691 |
+
print(f"result:{metrics}")
|
692 |
+
|
693 |
+
def OD_test(self):
|
694 |
+
# data_path = self.args.data_path
|
695 |
+
# with open(osp.join(data_path, f'{self.args.order_test_name}.json'), "r") as fp:
|
696 |
+
# data = json.load(fp)
|
697 |
+
self.od_model.eval()
|
698 |
+
test_log = Loss_log()
|
699 |
+
test_log.acc_init()
|
700 |
+
sz_test = len(self.train_dataloader)
|
701 |
+
all_emb_ent = []
|
702 |
+
with tqdm(total=sz_test) as _tqdm:  # initialize tqdm with the required arguments
|
703 |
+
for step, batch in enumerate(self.train_dataloader):
|
704 |
+
with torch.no_grad():
|
705 |
+
emb = self.model.cls_embedding(batch[0], tp=self.args.plm_emb_type)
|
706 |
+
out_emb = self.od_model.encode(emb)
|
707 |
+
emb_cpu = out_emb.cpu()
|
708 |
+
all_emb_ent.append(emb_cpu)
|
709 |
+
order_score = self.od_model.predict(emb)
|
710 |
+
token_right = self.od_model.right_caculate(order_score, batch[1], threshold=self.args.order_threshold)
|
711 |
+
test_log.update_token(batch[1].shape[0], [token_right])
|
712 |
+
_tqdm.update(1)
|
713 |
+
_tqdm.set_description(f'Test | step [{step}/{sz_test}] Acc: {test_log.get_token_acc()}')
|
714 |
+
|
715 |
+
emb_output = torch.cat(all_emb_ent)
|
716 |
+
data_path = self.args.data_path
|
717 |
+
save_path = osp.join(data_path, 'downstream_task', 'output')
|
718 |
+
os.makedirs(save_path, exist_ok=True)
|
719 |
+
save_dir = osp.join(save_path, f'ratio{self.args.train_ratio}_{emb_output.shape[0]}emb_{self.args.model_name.replace("DistributedDataParallel", "")}.pt')
|
720 |
+
torch.save(emb_output, save_dir)
|
721 |
+
print(f"save {emb_output.shape[0]} embeddings done...")
|
722 |
+
|
723 |
+
@ torch.no_grad()
|
724 |
+
def test(self, path_gen=None):
|
725 |
+
test_log = Loss_log()
|
726 |
+
self.model.eval()
|
727 |
+
if not (self.args.mask_test or self.args.embed_gen or self.args.ke_test or len(self.args.order_test_name) > 0):
|
728 |
+
return
|
729 |
+
if self.args.mask_test:
|
730 |
+
self.mask_test(test_log)
|
731 |
+
if self.args.embed_gen:
|
732 |
+
self.emb_generate(path_gen)
|
733 |
+
if self.args.ke_test:
|
734 |
+
self.KGE_test()
|
735 |
+
if len(self.args.order_test_name) > 0:
|
736 |
+
self.OD_test()
|
737 |
+
|
738 |
+
def _load_model(self, model, name):
|
739 |
+
if model is None:
|
740 |
+
return None
|
741 |
+
# no trained checkpoint exists
|
742 |
+
_name = name if name[:3] not in ['od_', 'ke_'] else name[3:]
|
743 |
+
save_path = osp.join(self.args.data_path, 'save', _name)
|
744 |
+
save_name = osp.join(save_path, f'{name}.pkl')
|
745 |
+
if not osp.exists(save_path) or not osp.exists(save_name):
|
746 |
+
return model.cuda()
|
747 |
+
# load the model
|
748 |
+
if 'Distribute' in self.args.model_name:
|
749 |
+
model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(os.path.join(save_name), map_location=self.args.device).items()})
|
750 |
+
else:
|
751 |
+
model.load_state_dict(torch.load(save_name, map_location=self.args.device))
|
752 |
+
model.cuda()
|
753 |
+
if self.rank == 0:
|
754 |
+
print(f"loading model [{name}.pkl] done!")
|
755 |
+
|
756 |
+
return model
|
757 |
+
|
758 |
+
def _save_model(self, stage=''):
|
759 |
+
model_name = type(self.model).__name__
|
760 |
+
# TODO: path
|
761 |
+
save_path = osp.join(self.args.data_path, 'save')
|
762 |
+
os.makedirs(save_path, exist_ok=True)
|
763 |
+
if self.args.train_strategy == 1:
|
764 |
+
save_name = f'{self.args.exp_name}_{self.args.exp_id}_s{self.args.random_seed}{stage}'
|
765 |
+
else:
|
766 |
+
save_name = f'{self.args.exp_name}_{self.args.exp_id}_s{self.args.random_seed}_{self.args.plm_emb_type}{stage}'
|
767 |
+
save_path = osp.join(save_path, save_name)
|
768 |
+
os.makedirs(save_path, exist_ok=True)
|
769 |
+
# save the pre-trained model
|
770 |
+
self._save(self.model, save_path, save_name)
|
771 |
+
|
772 |
+
# save the downstream models
|
773 |
+
save_name_od = f'od_{save_name}'
|
774 |
+
self._save(self.od_model, save_path, save_name_od)
|
775 |
+
save_name_ke = f'ke_{save_name}'
|
776 |
+
self._save(self.ke_model, save_path, save_name_ke)
|
777 |
+
return save_path
|
778 |
+
|
779 |
+
def _save(self, model, save_path, save_name):
|
780 |
+
if model is None:
|
781 |
+
return
|
782 |
+
if self.args.save_model:
|
783 |
+
torch.save(model.state_dict(), osp.join(save_path, f'{save_name}.pkl'))
|
784 |
+
print(f"saving {save_name} done!")
|
785 |
+
|
786 |
+
if self.args.save_pretrain and not save_name.startswith('od_') and not save_name.startswith('ke_'):
|
787 |
+
self.tokenizer.save_pretrained(osp.join(self.args.plm_path, f'{save_name}'))
|
788 |
+
self.model.encoder.save_pretrained(osp.join(self.args.plm_path, f'{save_name}'))
|
789 |
+
print(f"saving [pretrained] {save_name} done!")
|
790 |
+
|
791 |
+
|
792 |
+
if __name__ == '__main__':
|
793 |
+
cfg = cfg()
|
794 |
+
cfg.get_args()
|
795 |
+
cfgs = cfg.update_train_configs()
|
796 |
+
set_seed(cfgs.random_seed)
|
797 |
+
# initialize the environment of each process
|
798 |
+
# pdb.set_trace()
|
799 |
+
if cfgs.dist and not cfgs.only_test:
|
800 |
+
init_distributed_mode(args=cfgs)
|
801 |
+
# cfgs.lr *= cfgs.world_size
|
802 |
+
# cfgs.ke_lr *= cfgs.world_size
|
803 |
+
else:
|
804 |
+
# the statement below may leak memory in the parallel setting and prevent shutdown
|
805 |
+
torch.multiprocessing.set_sharing_strategy('file_system')
|
806 |
+
rank = cfgs.rank
|
807 |
+
|
808 |
+
writer, logger = None, None
|
809 |
+
if rank == 0:
|
810 |
+
# when running in parallel, only rank 0 prints
|
811 |
+
logger = initialize_exp(cfgs)
|
812 |
+
logger_path = get_dump_path(cfgs)
|
813 |
+
cfgs.time_stamp = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now())
|
814 |
+
comment = f'batch_size={cfgs.batch_size} exp_id={cfgs.exp_id}'
|
815 |
+
if not cfgs.no_tensorboard and not cfgs.only_test:
|
816 |
+
writer = SummaryWriter(log_dir=os.path.join(logger_path, 'tensorboard', cfgs.time_stamp), comment=comment)
|
817 |
+
|
818 |
+
cfgs.device = torch.device(cfgs.device)
|
819 |
+
|
820 |
+
# ----- Begin ----------
|
821 |
+
runner = Runner(cfgs, writer, logger, rank)
|
822 |
+
|
823 |
+
if cfgs.only_test:
|
824 |
+
if cfgs.embed_gen:
|
825 |
+
# first run the tests that do not need embedding generation
|
826 |
+
if cfgs.mask_test or cfgs.ke_test:
|
827 |
+
runner.args.embed_gen = 0
|
828 |
+
runner.test()
|
829 |
+
runner.args.embed_gen = 1
|
830 |
+
# gen_dir = ['yht_data_merge', 'yht_data_whole5gc', 'yz_data_whole5gc', 'yz_data_merge', 'zyc_data_merge', 'zyc_data_whole5gc']
|
831 |
+
gen_dir = ['yht_serialize_withAttribute', 'yht_serialize_withoutAttr', 'yht_name_serialize', 'zyc_serialize_withAttribute', 'zyc_serialize_withoutAttr', 'zyc_name_serialize',
|
832 |
+
'yz_serialize_withAttribute', 'yz_serialize_withoutAttr', 'yz_name_serialize', 'yz_serialize_net']
|
833 |
+
# gen_dir = ['zyc_serialize_withAttribute', 'zyc_normal_serialize', 'zyc_data_whole5gc', 'zyc_data_merge', 'yht_normal_serialize',
|
834 |
+
# 'yht_serialize_withAttribute', 'yz_serialize_withAttribute', 'yz_serialize_net', 'yz_normal_serialize']
|
835 |
+
runner.args.mask_test, runner.args.ke_test = 0, 0
|
836 |
+
for item in gen_dir:
|
837 |
+
runner.test(item)
|
838 |
+
else:
|
839 |
+
runner.test()
|
840 |
+
else:
|
841 |
+
runner.run()
|
842 |
+
|
843 |
+
# ----- End ----------
|
844 |
+
if not cfgs.no_tensorboard and not cfgs.only_test and rank == 0:
|
845 |
+
writer.close()
|
846 |
+
logger.info("done!")
|
847 |
+
|
848 |
+
if cfgs.dist and not cfgs.only_test:
|
849 |
+
dist.barrier()
|
850 |
+
dist.destroy_process_group()
|
851 |
+
# print("shut down...")
|
KTeleBERT/model/HWBert.py
ADDED
@@ -0,0 +1,146 @@
|
1 |
+
import os
|
2 |
+
import os.path as osp
|
3 |
+
import pdb
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import numpy as np
|
7 |
+
from random import *
|
8 |
+
import json
|
9 |
+
from packaging import version
|
10 |
+
import torch.distributed as dist
|
11 |
+
|
12 |
+
from .Tool_model import AutomaticWeightedLoss
|
13 |
+
from .Numeric import AttenNumeric
|
14 |
+
from .KE_model import KE_model
|
15 |
+
# from modeling_transformer import Transformer
|
16 |
+
|
17 |
+
|
18 |
+
from .bert import BertModel, BertTokenizer, BertForMaskedLM, BertConfig
|
19 |
+
import torch.nn.functional as F
|
20 |
+
|
21 |
+
from copy import deepcopy
|
22 |
+
from src.utils import torch_accuracy
|
23 |
+
# 4.21.2
|
24 |
+
|
25 |
+
|
26 |
+
def debug(input, kk, begin=None):
|
27 |
+
aaa = deepcopy(input[0])
|
28 |
+
if begin is None:
|
29 |
+
aaa.input_ids = input[0].input_ids[:kk]
|
30 |
+
aaa.attention_mask = input[0].attention_mask[:kk]
|
31 |
+
aaa.chinese_ref = input[0].chinese_ref[:kk]
|
32 |
+
aaa.kpi_ref = input[0].kpi_ref[:kk]
|
33 |
+
aaa.labels = input[0].labels[:kk]
|
34 |
+
else:
|
35 |
+
aaa.input_ids = input[0].input_ids[begin:kk]
|
36 |
+
aaa.attention_mask = input[0].attention_mask[begin:kk]
|
37 |
+
aaa.chinese_ref = input[0].chinese_ref[begin:kk]
|
38 |
+
aaa.kpi_ref = input[0].kpi_ref[begin:kk]
|
39 |
+
aaa.labels = input[0].labels[begin:kk]
|
40 |
+
|
41 |
+
return aaa
|
42 |
+
|
43 |
+
|
44 |
+
class HWBert(nn.Module):
|
45 |
+
def __init__(self, args):
|
46 |
+
super().__init__()
|
47 |
+
self.loss_awl = AutomaticWeightedLoss(args.awl_num, args)
|
48 |
+
self.args = args
|
49 |
+
self.config = BertConfig()
|
50 |
+
model_name = args.model_name
|
51 |
+
if args.model_name in ['TeleBert', 'TeleBert2', 'TeleBert3']:
|
52 |
+
self.encoder = BertForMaskedLM.from_pretrained(osp.join(args.data_root, 'transformer', model_name))
|
53 |
+
# initialize the predictions layer from MacBert
|
54 |
+
if args.cls_head_init:
|
55 |
+
tmp = BertForMaskedLM.from_pretrained(osp.join(args.data_root, 'transformer', 'MacBert'))
|
56 |
+
self.encoder.cls.predictions = tmp.cls.predictions
|
57 |
+
else:
|
58 |
+
if not osp.exists(osp.join(args.data_root, 'transformer', args.model_name)):
|
59 |
+
model_name = 'MacBert'
|
60 |
+
self.encoder = BertForMaskedLM.from_pretrained(osp.join(args.data_root, 'transformer', model_name))
|
61 |
+
self.numeric_model = AttenNumeric(self.args)
|
62 |
+
|
63 |
+
# ----------------------- main forward function ----------------------------------
|
64 |
+
def forward(self, input):
|
65 |
+
mask_loss, kpi_loss, kpi_loss_weight, kpi_loss_dict = self.mask_forward(input)
|
66 |
+
mask_loss = mask_loss.loss
|
67 |
+
loss_dic = {}
|
68 |
+
if not self.args.use_kpi_loss:
|
69 |
+
kpi_loss = None
|
70 |
+
if kpi_loss is not None:
|
71 |
+
loss_sum = self.loss_awl(mask_loss, 0.3 * kpi_loss)
|
72 |
+
loss_dic['kpi_loss'] = kpi_loss.item()
|
73 |
+
else:
|
74 |
+
loss_sum = self.loss_awl(mask_loss)
|
75 |
+
loss_dic['mask_loss'] = mask_loss.item()
|
76 |
+
return {
|
77 |
+
'loss': loss_sum,
|
78 |
+
'loss_dic': loss_dic,
|
79 |
+
'loss_weight': self.loss_awl.params.tolist(),
|
80 |
+
'kpi_loss_weight': kpi_loss_weight,
|
81 |
+
'kpi_loss_dict': kpi_loss_dict
|
82 |
+
}
|
83 |
+
|
84 |
+
# loss_sum, loss_dic, self.loss_awl.params.tolist(), kpi_loss_weight, kpi_loss_dict
|
85 |
+
|
86 |
+
# ----------------------------------------------------------------
|
87 |
+
# evaluation code: check whether the masked tokens are predicted correctly
|
88 |
+
def mask_prediction(self, inputs, tokenizer_sz, topk=(1,)):
|
89 |
+
token_num, token_right, word_num, word_right = None, None, None, None
|
90 |
+
outputs, kpi_loss, kpi_loss_weight, kpi_loss_dict = self.mask_forward(inputs)
|
91 |
+
inputs = inputs['labels'].view(-1)
|
92 |
+
input_list = inputs.tolist()
|
93 |
+
# positions of the masked tokens
|
94 |
+
change_token_index = [i for i, x in enumerate(input_list) if x != -100]
|
95 |
+
change_token = torch.tensor(change_token_index)
|
96 |
+
inputs_used = inputs[change_token]
|
97 |
+
pred = outputs.logits.view(-1, tokenizer_sz)
|
98 |
+
pred_used = pred[change_token].cpu()
|
99 |
+
# returned as a list
|
100 |
+
# compute accuracy
|
101 |
+
acc, token_right = torch_accuracy(pred_used, inputs_used, topk)
|
102 |
+
# compute the perplexity score
|
103 |
+
|
104 |
+
token_num = inputs_used.shape[0]
|
105 |
+
# TODO: add word_num, word_right
|
106 |
+
# token_right:list
|
107 |
+
return token_num, token_right, outputs.loss.item()
|
108 |
+
|
109 |
+
def mask_forward(self, inputs):
|
110 |
+
kpi_ref = None
|
111 |
+
if 'kpi_ref' in inputs:
|
112 |
+
kpi_ref = inputs['kpi_ref']
|
113 |
+
|
114 |
+
outputs, kpi_loss, kpi_loss_weight, kpi_loss_dict = self.encoder(
|
115 |
+
input_ids=inputs['input_ids'].cuda(),
|
116 |
+
attention_mask=inputs['attention_mask'].cuda(),
|
117 |
+
# token_type_ids=inputs.token_type_ids.cuda(),
|
118 |
+
labels=inputs['labels'].cuda(),
|
119 |
+
kpi_ref=kpi_ref,
|
120 |
+
kpi_model=self.numeric_model
|
121 |
+
)
|
122 |
+
return outputs, kpi_loss, kpi_loss_weight, kpi_loss_dict
|
123 |
+
|
124 |
+
# TODO: consider axial (vertical) attention: https://github.com/lucidrains/axial-attention
|
125 |
+
|
126 |
+
def cls_embedding(self, inputs, tp='cls'):
|
127 |
+
hidden_states = self.encoder(
|
128 |
+
input_ids=inputs['input_ids'].cuda(),
|
129 |
+
attention_mask=inputs['attention_mask'].cuda(),
|
130 |
+
output_hidden_states=True)[0].hidden_states
|
131 |
+
if tp == 'cls':
|
132 |
+
return hidden_states[-1][:, 0]
|
133 |
+
else:
|
134 |
+
index_real = inputs['input_ids'].clone().detach().bool()
|
135 |
+
res = []
|
136 |
+
for i in range(hidden_states[-1].shape[0]):
|
137 |
+
if tp == 'last_avg':
|
138 |
+
res.append(hidden_states[-1][i][index_real[i]][:-1].mean(dim=0))
|
139 |
+
elif tp == 'last2avg':
|
140 |
+
res.append((hidden_states[-1][i][index_real[i]][:-1] + hidden_states[-2][i][index_real[i]][:-1]).mean(dim=0))
|
141 |
+
elif tp == 'last3avg':
|
142 |
+
res.append((hidden_states[-1][i][index_real[i]][:-1] + hidden_states[-2][i][index_real[i]][:-1] + hidden_states[-3][i][index_real[i]][:-1]).mean(dim=0))
|
143 |
+
elif tp == 'first_last_avg':
|
144 |
+
res.append((hidden_states[-1][i][index_real[i]][:-1] + hidden_states[1][i][index_real[i]][:-1]).mean(dim=0))
|
145 |
+
|
146 |
+
return torch.stack(res)
|
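# --- Hedged sketch (added, not part of the original file): the pooling variants implemented in
# --- cls_embedding above, written for a single sequence. `hs` is assumed to be the tuple of
# --- per-layer hidden states ([seq_len, dim] each) and `mask` marks real (non-padding) tokens.
import torch

def _pool_single(hs, mask, tp='cls'):
    if tp == 'cls':
        return hs[-1][0]                                        # [CLS] vector of the last layer
    keep = mask.bool()
    if tp == 'last_avg':
        return hs[-1][keep][:-1].mean(dim=0)                    # mean of last layer, dropping the final real token
    if tp == 'first_last_avg':
        return (hs[-1][keep][:-1] + hs[1][keep][:-1]).mean(dim=0)
    raise ValueError(tp)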
KTeleBERT/model/KE_model.py
ADDED
@@ -0,0 +1,451 @@
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
from sklearn.metrics import average_precision_score
|
6 |
+
from tqdm import tqdm
|
7 |
+
import pdb
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
from collections import defaultdict
|
10 |
+
import os.path as osp
|
11 |
+
import json
|
12 |
+
|
13 |
+
|
14 |
+
class KE_model(nn.Module):
|
15 |
+
def __init__(self, args):
|
16 |
+
super().__init__()
|
17 |
+
"""
|
18 |
+
triple task: mask tail entity, total entity size-class classification
|
19 |
+
"""
|
20 |
+
"""
|
21 |
+
:param hidden: BERT model output size
|
22 |
+
"""
|
23 |
+
self.args = args
|
24 |
+
self.ke_dim = args.ke_dim
|
25 |
+
|
26 |
+
self.linear_ent = nn.Linear(args.hidden_size, self.ke_dim)
|
27 |
+
self.linear_rel = nn.Linear(args.hidden_size, self.ke_dim)
|
28 |
+
|
29 |
+
self.ke_margin = nn.Parameter(
|
30 |
+
torch.Tensor([args.ke_margin]),
|
31 |
+
requires_grad=False
|
32 |
+
)
|
33 |
+
|
34 |
+
def forward(self, batch, hw_model):
|
35 |
+
batch_triple = batch
|
36 |
+
pos_sample = batch_triple["positive_sample"]
|
37 |
+
neg_sample = batch_triple["negative_sample"]
|
38 |
+
neg_index = batch_triple["neg_index"]
|
39 |
+
|
40 |
+
# save GPU memory
|
41 |
+
all_entity = []
|
42 |
+
all_entity_mask = []
|
43 |
+
for i in range(3):
|
44 |
+
all_entity.append(pos_sample[i]['input_ids'])
|
45 |
+
all_entity_mask.append(pos_sample[i]['attention_mask'])
|
46 |
+
|
47 |
+
all_entity = torch.cat(all_entity)
|
48 |
+
all_entity_mask = torch.cat(all_entity_mask)
|
49 |
+
entity_data = {'input_ids':all_entity, 'attention_mask':all_entity_mask}
|
50 |
+
entity_emb = hw_model.cls_embedding(entity_data, tp=self.args.plm_emb_type)
|
51 |
+
|
52 |
+
bs = pos_sample[0]['input_ids'].shape[0]
|
53 |
+
pos_sample_emb= [entity_emb[:bs], entity_emb[bs:2*bs], entity_emb[2*bs:3*bs]]
|
54 |
+
neg_sample_emb = entity_emb[neg_index]
|
55 |
+
mode = batch_triple["mode"]
|
56 |
+
# pos_score = self.get_score(pos_sample, hw_model)
|
57 |
+
# neg_score = self.get_score(pos_sample, hw_model, neg_sample, mode)
|
58 |
+
pos_score = self.get_score(pos_sample_emb, hw_model)
|
59 |
+
neg_score = self.get_score(pos_sample_emb, hw_model, neg_sample_emb, mode)
|
60 |
+
triple_loss = self.adv_loss(pos_score, neg_score, self.args)
|
61 |
+
|
62 |
+
return triple_loss
|
63 |
+
|
64 |
+
# pdb.set_trace()
|
65 |
+
# return emb.div_(emb.detach().norm(p=1, dim=-1, keepdim=True))
|
66 |
+
|
67 |
+
# KE loss
|
68 |
+
def tri2emb(self, triples, hw_model, negs=None, mode="single"):
|
69 |
+
"""Get embedding of triples.
|
70 |
+
This function get the embeddings of head, relation, and tail
|
71 |
+
respectively. each embedding has three dimensions.
|
72 |
+
Args:
|
73 |
+
triples (tensor): This tensor stores triple ids; its shape is
|
74 |
+
[number of triples, 3].
|
75 |
+
negs (tensor, optional): This tensor stores the ids of the entities to
|
76 |
+
be replaced; it has one dimension. When negs is None, we are
|
77 |
+
in the test/eval phase. Defaults to None.
|
78 |
+
mode (str, optional): This arg indicates that the negative entity
|
79 |
+
will replace the head or tail entity. when it is 'single', it
|
80 |
+
means that entity will not be replaced. Defaults to 'single'.
|
81 |
+
Returns:
|
82 |
+
head_emb: Head entity embedding.
|
83 |
+
relation_emb: Relation embedding.
|
84 |
+
tail_emb: Tail entity embedding.
|
85 |
+
"""
|
86 |
+
|
87 |
+
if mode == "single":
|
88 |
+
head_emb = self.get_embedding(triples[0]).unsqueeze(1) # [bs, 1, dim]
|
89 |
+
relation_emb = self.get_embedding(triples[1], is_ent=False).unsqueeze(1) # [bs, 1, dim]
|
90 |
+
tail_emb = self.get_embedding(triples[2]).unsqueeze(1) # [bs, 1, dim]
|
91 |
+
|
92 |
+
elif mode == "head-batch" or mode == "head_predict":
|
93 |
+
if negs is None:  # evaluation phase, so all entity embeddings are used directly
|
94 |
+
# TODO: the KGC test setting is not handled for now
|
95 |
+
head_emb = self.ent_emb.weight.data.unsqueeze(0) # [1, num_ent, dim]
|
96 |
+
else:
|
97 |
+
head_emb = self.get_embedding(negs).reshape(-1, self.args.neg_num, self.args.ke_dim) # [bs, num_neg, dim]
|
98 |
+
relation_emb = self.get_embedding(triples[1], is_ent=False).unsqueeze(1) # [bs, 1, dim]
|
99 |
+
tail_emb = self.get_embedding(triples[2]).unsqueeze(1) # [bs, 1, dim]
|
100 |
+
|
101 |
+
elif mode == "tail-batch" or mode == "tail_predict":
|
102 |
+
head_emb = self.get_embedding(triples[0]).unsqueeze(1) # [bs, 1, dim]
|
103 |
+
relation_emb = self.get_embedding(triples[1], is_ent=False).unsqueeze(1) # [bs, 1, dim]
|
104 |
+
if negs is None:
|
105 |
+
tail_emb = self.ent_emb.weight.data.unsqueeze(0) # [1, num_ent, dim]
|
106 |
+
else:
|
107 |
+
# pdb.set_trace()
|
108 |
+
tail_emb = self.get_embedding(negs).reshape(-1, self.args.neg_num, self.args.ke_dim) # [bs, num_neg, dim]
|
109 |
+
|
110 |
+
return head_emb, relation_emb, tail_emb
|
111 |
+
|
112 |
+
def get_embedding(self, inputs, is_ent=True):
|
113 |
+
# pdb.set_trace()
|
114 |
+
if is_ent:
|
115 |
+
return self.linear_ent(inputs)
|
116 |
+
else:
|
117 |
+
return self.linear_rel(inputs)
|
118 |
+
|
119 |
+
def score_func(self, head_emb, relation_emb, tail_emb):
|
120 |
+
"""Calculating the score of triples.
|
121 |
+
|
122 |
+
The formula for calculating the score is :math:`\gamma - ||h + r - t||_1`
|
123 |
+
Args:
|
124 |
+
head_emb: The head entity embedding.
|
125 |
+
relation_emb: The relation embedding.
|
126 |
+
tail_emb: The tail entity embedding.
|
127 |
+
mode: Choose head-predict or tail-predict.
|
128 |
+
Returns:
|
129 |
+
score: The score of triples.
|
130 |
+
"""
|
131 |
+
score = (head_emb + relation_emb) - tail_emb
|
132 |
+
# pdb.set_trace()
|
133 |
+
score = self.ke_margin.item() - torch.norm(score, p=1, dim=-1)
|
134 |
+
return score
|
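# --- Illustration (added, not part of the original file): the TransE-style score used above,
# --- gamma - ||h + r - t||_1, evaluated on random embeddings. Shapes follow tri2emb:
# --- positives are [bs, 1, dim], negatives [bs, num_neg, dim]; all values are made up.
import torch
h, r = torch.randn(4, 1, 256), torch.randn(4, 1, 256)
t_neg = torch.randn(4, 8, 256)                          # 8 negative tails per triple
score = 12.0 - torch.norm(h + r - t_neg, p=1, dim=-1)   # -> [4, 8]; higher means more plausible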
135 |
+
|
136 |
+
def get_score(self, triples, hw_model, negs=None, mode='single'):
|
137 |
+
"""The functions used in the training phase
|
138 |
+
|
139 |
+
Args:
|
140 |
+
triples: The triples ids, as (h, r, t), shape:[batch_size, 3].
|
141 |
+
negs: Negative samples, defaults to None.
|
142 |
+
mode: Choose head-predict or tail-predict, Defaults to 'single'.
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
score: The score of triples.
|
146 |
+
"""
|
147 |
+
head_emb, relation_emb, tail_emb = self.tri2emb(triples, hw_model, negs, mode)
|
148 |
+
score = self.score_func(head_emb, relation_emb, tail_emb)
|
149 |
+
|
150 |
+
return score
|
151 |
+
|
152 |
+
def adv_loss(self, pos_score, neg_score, args):
|
153 |
+
"""Negative sampling loss with self-adversarial training. In math:
|
154 |
+
|
155 |
+
L=-\log \sigma\left(\gamma-d_{r}(\mathbf{h}, \mathbf{t})\right)-\sum_{i=1}^{n} p\left(h_{i}^{\prime}, r, t_{i}^{\prime}\right) \log \sigma\left(d_{r}\left(\mathbf{h}_{i}^{\prime}, \mathbf{t}_{i}^{\prime}\right)-\gamma\right)
|
156 |
+
|
157 |
+
Args:
|
158 |
+
pos_score: The score of positive samples.
|
159 |
+
neg_score: The score of negative samples.
|
160 |
+
subsampling_weight: The weight for correcting pos_score and neg_score.
|
161 |
+
|
162 |
+
Returns:
|
163 |
+
loss: The training loss for back propagation.
|
164 |
+
"""
|
165 |
+
neg_score = (F.softmax(neg_score * args.adv_temp, dim=1).detach()
|
166 |
+
* F.logsigmoid(-neg_score)).sum(dim=1) # shape:[bs]
|
167 |
+
pos_score = F.logsigmoid(pos_score).view(neg_score.shape[0]) # shape:[bs]
|
168 |
+
positive_sample_loss = - pos_score.mean()
|
169 |
+
negative_sample_loss = - neg_score.mean()
|
170 |
+
loss = (positive_sample_loss + negative_sample_loss) / 2
|
171 |
+
return loss
|
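# --- Note (added, not part of the original file): adv_loss weights each negative by a detached
# --- softmax over its score (temperature adv_temp), i.e. RotatE-style self-adversarial negative
# --- sampling, so harder negatives contribute more. A minimal re-statement:
import torch
import torch.nn.functional as F

def _self_adv_neg_term(neg_score: torch.Tensor, adv_temp: float) -> torch.Tensor:
    w = F.softmax(neg_score * adv_temp, dim=1).detach()   # [bs, num_neg], sums to 1 per row
    return (w * F.logsigmoid(-neg_score)).sum(dim=1)      # [bs]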
172 |
+
|
173 |
+
|
174 |
+
class KGEModel(nn.Module):
|
175 |
+
def __init__(self, nentity, nrelation, hidden_dim, gamma, entity_embedding, relation_embedding):
|
176 |
+
super(KGEModel, self).__init__()
|
177 |
+
self.nentity = nentity
|
178 |
+
self.nrelation = nrelation
|
179 |
+
self.hidden_dim = hidden_dim
|
180 |
+
|
181 |
+
self.gamma = nn.Parameter(
|
182 |
+
torch.Tensor([gamma]),
|
183 |
+
requires_grad=False
|
184 |
+
)
|
185 |
+
self.entity_embedding = entity_embedding
|
186 |
+
self.relation_embedding = relation_embedding
|
187 |
+
|
188 |
+
assert self.relation_embedding.shape[0] == nrelation
|
189 |
+
assert self.entity_embedding.shape[0] == nentity
|
190 |
+
|
191 |
+
def forward(self, sample, mode='single'):
|
192 |
+
'''
|
193 |
+
Forward function that calculate the score of a batch of triples.
|
194 |
+
In the 'single' mode, sample is a batch of triple.
|
195 |
+
In the 'head-batch' or 'tail-batch' mode, sample consists two part.
|
196 |
+
The first part is usually the positive sample.
|
197 |
+
And the second part is the entities in the negative samples.
|
198 |
+
Because negative samples and positive samples usually share two elements
|
199 |
+
in their triple ((head, relation) or (relation, tail)).
|
200 |
+
'''
|
201 |
+
|
202 |
+
if mode == 'single':
|
203 |
+
batch_size, negative_sample_size = sample.size(0), 1
|
204 |
+
|
205 |
+
head = torch.index_select(
|
206 |
+
self.entity_embedding,
|
207 |
+
dim=0,
|
208 |
+
index=sample[:, 0]
|
209 |
+
).unsqueeze(1)
|
210 |
+
|
211 |
+
relation = torch.index_select(
|
212 |
+
self.relation_embedding,
|
213 |
+
dim=0,
|
214 |
+
index=sample[:, 1]
|
215 |
+
).unsqueeze(1)
|
216 |
+
|
217 |
+
tail = torch.index_select(
|
218 |
+
self.entity_embedding,
|
219 |
+
dim=0,
|
220 |
+
index=sample[:, 2]
|
221 |
+
).unsqueeze(1)
|
222 |
+
|
223 |
+
elif mode == 'head-batch':
|
224 |
+
tail_part, head_part = sample
|
225 |
+
batch_size, negative_sample_size = head_part.size(0), head_part.size(1)
|
226 |
+
|
227 |
+
head = torch.index_select(
|
228 |
+
self.entity_embedding,
|
229 |
+
dim=0,
|
230 |
+
index=head_part.view(-1)
|
231 |
+
).view(batch_size, negative_sample_size, -1)
|
232 |
+
|
233 |
+
relation = torch.index_select(
|
234 |
+
self.relation_embedding,
|
235 |
+
dim=0,
|
236 |
+
index=tail_part[:, 1]
|
237 |
+
).unsqueeze(1)
|
238 |
+
|
239 |
+
tail = torch.index_select(
|
240 |
+
self.entity_embedding,
|
241 |
+
dim=0,
|
242 |
+
index=tail_part[:, 2]
|
243 |
+
).unsqueeze(1)
|
244 |
+
|
245 |
+
elif mode == 'tail-batch':
|
246 |
+
head_part, tail_part = sample
|
247 |
+
batch_size, negative_sample_size = tail_part.size(0), tail_part.size(1)
|
248 |
+
|
249 |
+
head = torch.index_select(
|
250 |
+
self.entity_embedding,
|
251 |
+
dim=0,
|
252 |
+
index=head_part[:, 0]
|
253 |
+
).unsqueeze(1)
|
254 |
+
|
255 |
+
relation = torch.index_select(
|
256 |
+
self.relation_embedding,
|
257 |
+
dim=0,
|
258 |
+
index=head_part[:, 1]
|
259 |
+
).unsqueeze(1)
|
260 |
+
|
261 |
+
tail = torch.index_select(
|
262 |
+
self.entity_embedding,
|
263 |
+
dim=0,
|
264 |
+
index=tail_part.view(-1)
|
265 |
+
).view(batch_size, negative_sample_size, -1)
|
266 |
+
|
267 |
+
else:
|
268 |
+
raise ValueError('mode %s not supported' % mode)
|
269 |
+
|
270 |
+
score = self.TransE(head, relation, tail, mode)
|
271 |
+
|
272 |
+
return score
|
273 |
+
|
274 |
+
def TransE(self, head, relation, tail, mode):
|
275 |
+
if mode == 'head-batch':
|
276 |
+
score = head + (relation - tail)
|
277 |
+
else:
|
278 |
+
score = (head + relation) - tail
|
279 |
+
|
280 |
+
score = self.gamma.item() - torch.norm(score, p=1, dim=-1)
|
281 |
+
return score
|
282 |
+
|
283 |
+
@torch.no_grad()
|
284 |
+
def test_step(self, test_triples, all_true_triples, args, nentity, nrelation):
|
285 |
+
'''
|
286 |
+
Evaluate the model on test or valid datasets
|
287 |
+
'''
|
288 |
+
# Otherwise use standard (filtered) MRR, MR, HITS@1, HITS@3, and HITS@10 metrics
|
289 |
+
# Prepare dataloader for evaluation
|
290 |
+
test_dataloader_head = DataLoader(
|
291 |
+
KGTestDataset(
|
292 |
+
test_triples,
|
293 |
+
all_true_triples,
|
294 |
+
nentity,
|
295 |
+
nrelation,
|
296 |
+
'head-batch'
|
297 |
+
),
|
298 |
+
batch_size=args.batch_size,
|
299 |
+
num_workers=args.workers,
|
300 |
+
persistent_workers=True,
|
301 |
+
collate_fn=KGTestDataset.collate_fn
|
302 |
+
)
|
303 |
+
|
304 |
+
test_dataloader_tail = DataLoader(
|
305 |
+
KGTestDataset(
|
306 |
+
test_triples,
|
307 |
+
all_true_triples,
|
308 |
+
nentity,
|
309 |
+
nrelation,
|
310 |
+
'tail-batch'
|
311 |
+
),
|
312 |
+
batch_size=args.batch_size,
|
313 |
+
num_workers=args.workers,
|
314 |
+
persistent_workers=True,
|
315 |
+
collate_fn=KGTestDataset.collate_fn
|
316 |
+
)
|
317 |
+
|
318 |
+
test_dataset_list = [test_dataloader_head, test_dataloader_tail]
|
319 |
+
|
320 |
+
logs = []
|
321 |
+
|
322 |
+
step = 0
|
323 |
+
total_steps = sum([len(dataset) for dataset in test_dataset_list])
|
324 |
+
|
325 |
+
# pdb.set_trace()
|
326 |
+
with tqdm(total=total_steps) as _tqdm:
|
327 |
+
_tqdm.set_description(f'eval KGC')
|
328 |
+
for test_dataset in test_dataset_list:
|
329 |
+
for positive_sample, negative_sample, filter_bias, mode in test_dataset:
|
330 |
+
|
331 |
+
positive_sample = positive_sample.cuda()
|
332 |
+
negative_sample = negative_sample.cuda()
|
333 |
+
filter_bias = filter_bias.cuda()
|
334 |
+
|
335 |
+
batch_size = positive_sample.size(0)
|
336 |
+
|
337 |
+
score = self.forward((positive_sample, negative_sample), mode)
|
338 |
+
score += filter_bias
|
339 |
+
|
340 |
+
# Explicitly sort all the entities to ensure that there is no test exposure bias
|
341 |
+
argsort = torch.argsort(score, dim=1, descending=True)
|
342 |
+
|
343 |
+
if mode == 'head-batch':
|
344 |
+
positive_arg = positive_sample[:, 0]
|
345 |
+
elif mode == 'tail-batch':
|
346 |
+
positive_arg = positive_sample[:, 2]
|
347 |
+
else:
|
348 |
+
raise ValueError('mode %s not supported' % mode)
|
349 |
+
|
350 |
+
for i in range(batch_size):
|
351 |
+
# Notice that argsort is not ranking
|
352 |
+
# ranking = (argsort[i, :] == positive_arg[i]).nonzero()
|
353 |
+
ranking = (argsort[i, :] == positive_arg[i]).nonzero(as_tuple=False)
|
354 |
+
assert ranking.size(0) == 1
|
355 |
+
|
356 |
+
# ranking + 1 is the true ranking used in evaluation metrics
|
357 |
+
ranking = 1 + ranking.item()
|
358 |
+
logs.append({
|
359 |
+
'MRR': 1.0 / ranking,
|
360 |
+
'MR': float(ranking),
|
361 |
+
'HITS@1': 1.0 if ranking <= 1 else 0.0,
|
362 |
+
'HITS@3': 1.0 if ranking <= 3 else 0.0,
|
363 |
+
'HITS@10': 1.0 if ranking <= 10 else 0.0,
|
364 |
+
})
|
365 |
+
|
366 |
+
# if step % args.test_log_steps == 0:
|
367 |
+
# logging.info('Evaluating the model... (%d/%d)' % (step, total_steps))
|
368 |
+
_tqdm.update(1)
|
369 |
+
_tqdm.set_description(f'KGC Eval:')
|
370 |
+
step += 1
|
371 |
+
|
372 |
+
metrics = {}
|
373 |
+
for metric in logs[0].keys():
|
374 |
+
metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
|
375 |
+
|
376 |
+
return metrics
|
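# --- Illustration (added, not part of the original file): how a single filtered rank is
# --- converted into the metrics averaged above. `ranking` is 1-based.
def _rank_to_metrics(ranking: int) -> dict:
    return {
        'MRR': 1.0 / ranking,
        'MR': float(ranking),
        'HITS@1': 1.0 if ranking <= 1 else 0.0,
        'HITS@3': 1.0 if ranking <= 3 else 0.0,
        'HITS@10': 1.0 if ranking <= 10 else 0.0,
    }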
377 |
+
|
378 |
+
|
379 |
+
# a dataset designed specifically for KGE evaluation
|
380 |
+
class KGTestDataset(torch.utils.data.Dataset):
|
381 |
+
def __init__(self, triples, all_true_triples, nentity, nrelation, mode, head4rel_tail=None, tail4head_rel=None):
|
382 |
+
self.len = len(triples)
|
383 |
+
self.triple_set = set(all_true_triples)
|
384 |
+
self.triples = triples
|
385 |
+
|
386 |
+
# obtained by counting over the KG
|
387 |
+
self.nentity = nentity
|
388 |
+
self.nrelation = nrelation
|
389 |
+
self.mode = mode
|
390 |
+
|
391 |
+
# head entities corresponding to a given (relation, tail) pair
|
392 |
+
# print("build head4rel_tail")
|
393 |
+
# self.head4rel_tail = self.find_head4rel_tail()
|
394 |
+
# print("build tail4head_rel")
|
395 |
+
# self.tail4head_rel = self.find_tail4head_rel()
|
396 |
+
|
397 |
+
def __len__(self):
|
398 |
+
return self.len
|
399 |
+
|
400 |
+
def find_head4rel_tail(self):
|
401 |
+
ans = defaultdict(list)
|
402 |
+
for (h, r, t) in self.triple_set:
|
403 |
+
ans[(r, t)].append(h)
|
404 |
+
return ans
|
405 |
+
|
406 |
+
def find_tail4head_rel(self):
|
407 |
+
ans = defaultdict(list)
|
408 |
+
for (h, r, t) in self.triple_set:
|
409 |
+
ans[(h, r)].append(t)
|
410 |
+
return ans
|
411 |
+
|
412 |
+
def __getitem__(self, idx):
|
413 |
+
head, relation, tail = self.triples[idx]
|
414 |
+
|
415 |
+
if self.mode == 'head-batch':
|
416 |
+
tmp = [(0, rand_head) if (rand_head, relation, tail) not in self.triple_set
|
417 |
+
else (-100, head) for rand_head in range(self.nentity)]
|
418 |
+
tmp[head] = (0, head)
|
419 |
+
elif self.mode == 'tail-batch':
|
420 |
+
tmp = [(0, rand_tail) if (head, relation, rand_tail) not in self.triple_set
|
421 |
+
else (-100, tail) for rand_tail in range(self.nentity)]
|
422 |
+
tmp[tail] = (0, tail)
|
423 |
+
else:
|
424 |
+
raise ValueError('negative batch mode %s not supported' % self.mode)
|
425 |
+
# if self.mode == 'head-batch':
|
426 |
+
#
|
427 |
+
# tmp = [(0, rand_head) if rand_head not in self.head4rel_tail[(relation, tail)]
|
428 |
+
# else (-100, head) for rand_head in range(self.nentity)]
|
429 |
+
# tmp[head] = (0, head)
|
430 |
+
# elif self.mode == 'tail-batch':
|
431 |
+
# tmp = [(0, rand_tail) if rand_tail not in self.tail4head_rel[(head, relation)]
|
432 |
+
# else (-100, tail) for rand_tail in range(self.nentity)]
|
433 |
+
# tmp[tail] = (0, tail)
|
434 |
+
# else:
|
435 |
+
# raise ValueError('negative batch mode %s not supported' % self.mode)
|
436 |
+
|
437 |
+
tmp = torch.LongTensor(tmp)
|
438 |
+
filter_bias = tmp[:, 0].float()
|
439 |
+
negative_sample = tmp[:, 1]
|
440 |
+
|
441 |
+
positive_sample = torch.LongTensor((head, relation, tail))
|
442 |
+
|
443 |
+
return positive_sample, negative_sample, filter_bias, self.mode
|
444 |
+
|
445 |
+
@staticmethod
|
446 |
+
def collate_fn(data):
|
447 |
+
positive_sample = torch.stack([_[0] for _ in data], dim=0)
|
448 |
+
negative_sample = torch.stack([_[1] for _ in data], dim=0)
|
449 |
+
filter_bias = torch.stack([_[2] for _ in data], dim=0)
|
450 |
+
mode = data[0][3]
|
451 |
+
return positive_sample, negative_sample, filter_bias, mode
|
KTeleBERT/model/Numeric.py
ADDED
@@ -0,0 +1,218 @@
|
1 |
+
import types
|
2 |
+
import torch
|
3 |
+
import transformers
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import CrossEntropyLoss
|
7 |
+
import numpy as np
|
8 |
+
import pdb
|
9 |
+
import math
|
10 |
+
from .Tool_model import AutomaticWeightedLoss
|
11 |
+
import os.path as osp
|
12 |
+
import json
|
13 |
+
|
14 |
+
|
15 |
+
def ortho_penalty(t):
|
16 |
+
return ((t @ t.T - torch.eye(t.shape[0]).cuda())**2).sum()
|
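# --- Note (added, not part of the original file): ortho_penalty(t) is the squared Frobenius
# --- distance between t @ t.T and the identity, so it is 0 exactly when the rows of t are
# --- orthonormal, e.g.:
# t = torch.eye(4).cuda()          # orthonormal rows
# assert ortho_penalty(t).item() == 0.0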
17 |
+
|
18 |
+
|
19 |
+
class AttenNumeric(nn.Module):
|
20 |
+
def __init__(self, config):
|
21 |
+
super(AttenNumeric, self).__init__()
|
22 |
+
# ----------- load kpi2id --------------------
|
23 |
+
kpi_file_path = osp.join(config.data_path, 'kpi2id.json')
|
24 |
+
|
25 |
+
with open(kpi_file_path, 'r') as f:
|
26 |
+
# pdb.set_trace()
|
27 |
+
kpi2id = json.load(f)
|
28 |
+
config.num_kpi = 303
|
29 |
+
# config.num_kpi = len(kpi2id)
|
30 |
+
# -------------------------------
|
31 |
+
|
32 |
+
self.config = config
|
33 |
+
self.fc = nn.Linear(1, config.hidden_size)
|
34 |
+
# self.actication = nn.ReLU()
|
35 |
+
self.actication = nn.LeakyReLU()
|
36 |
+
# self.embedding = nn.Linear(config.hidden_size, self.attention_head_size)
|
37 |
+
if config.contrastive_loss:
|
38 |
+
self.loss_awl = AutomaticWeightedLoss(3, config)
|
39 |
+
else:
|
40 |
+
self.loss_awl = AutomaticWeightedLoss(2, config)
|
41 |
+
self.encoder = AttNumEncoder(config)
|
42 |
+
self.decoder = AttNumDecoder(config)
|
43 |
+
self.classifier = NumClassifier(config)
|
44 |
+
self.ce_loss = nn.CrossEntropyLoss()
|
45 |
+
|
46 |
+
def contrastive_loss(self, hidden, kpi):
|
47 |
+
# in batch negative
|
48 |
+
bs_tmp = hidden.shape[0]
|
49 |
+
eye = torch.eye(bs_tmp).cuda()
|
50 |
+
hidden = F.normalize(hidden, dim=1)
|
51 |
+
# [12,12]
|
52 |
+
# subtract the identity so that self-similarity does not affect the result
|
53 |
+
hidden_sim = (torch.matmul(hidden, hidden.T) - eye) / 0.07
|
54 |
+
kpi = kpi.expand(-1, bs_tmp)
|
55 |
+
kpi_sim = torch.abs(kpi - kpi.T) + eye
|
56 |
+
kpi_sim = torch.min(kpi_sim, 1)[1]
|
57 |
+
sc_loss = self.ce_loss(hidden_sim, kpi_sim)
|
58 |
+
return sc_loss
|
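# --- Hedged sketch (added, not part of the original file): how contrastive_loss picks its target
# --- index for each sample, namely the other batch element whose raw KPI value is closest
# --- (self excluded); the similarity logits are then trained with cross-entropy at temperature 0.07.
# --- The values below are made up.
import torch
kpi = torch.tensor([[0.10], [0.12], [0.90]])
diff = torch.abs(kpi.expand(-1, 3) - kpi.expand(-1, 3).T) + torch.eye(3)  # +I masks out self-pairs
target = torch.min(diff, 1)[1]   # tensor([1, 0, 1]): index of the nearest-valued sample per row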
59 |
+
|
60 |
+
def _encode(self, kpi, query):
|
61 |
+
kpi_emb = self.actication(self.fc(kpi))
|
62 |
+
# name_emb = self.embedding(query)
|
63 |
+
hidden, en_loss, scalar_list = self.encoder(kpi_emb, query)
|
64 |
+
|
65 |
+
# contrastive learning is meaningless with two or fewer samples
|
66 |
+
if self.config.contrastive_loss and hidden.shape[0] > 2:
|
67 |
+
con_loss = self.contrastive_loss(hidden.squeeze(1), kpi.squeeze(1))
|
68 |
+
else:
|
69 |
+
con_loss = None
|
70 |
+
hidden = self.actication(hidden)
|
71 |
+
assert query.shape[0] > 0
|
72 |
+
return hidden, en_loss, scalar_list, con_loss
|
73 |
+
|
74 |
+
def forward(self, kpi, query, kpi_id):
|
75 |
+
hidden, en_loss, scalar_list, con_loss = self._encode(kpi, query)
|
76 |
+
dec_kpi_score, de_loss = self.decoder(kpi, hidden)
|
77 |
+
cls_kpi, cls_loss = self.classifier(hidden, kpi_id)
|
78 |
+
if con_loss is not None:
|
79 |
+
# 0.001 * con_loss
|
80 |
+
loss_sum = self.loss_awl(de_loss, cls_loss, 0.1 * con_loss)
|
81 |
+
loss_all = loss_sum + en_loss
|
82 |
+
loss_dic = {'cls_loss': cls_loss.item(), 'reg_loss': de_loss.item(), 'orth_loss': en_loss.item(), 'con_loss': con_loss.item()}
|
83 |
+
# pdb.set_trace()
|
84 |
+
else:
|
85 |
+
loss_sum = self.loss_awl(de_loss, cls_loss)
|
86 |
+
loss_all = loss_sum + en_loss
|
87 |
+
loss_dic = {'cls_loss': cls_loss.item(), 'reg_loss': de_loss.item(), 'orth_loss': en_loss.item()}
|
88 |
+
|
89 |
+
return dec_kpi_score, cls_kpi, hidden, loss_all, self.loss_awl.params.tolist(), loss_dic, scalar_list
|
90 |
+
|
91 |
+
|
92 |
+
class AttNumEncoder(nn.Module):
|
93 |
+
def __init__(self, config):
|
94 |
+
super(AttNumEncoder, self).__init__()
|
95 |
+
self.num_l_layers = config.l_layers
|
96 |
+
self.layer = nn.ModuleList([AttNumLayer(config) for _ in range(self.num_l_layers)])
|
97 |
+
|
98 |
+
def forward(self, kpi_emb, name_emb):
|
99 |
+
loss = 0.
|
100 |
+
scalar_list = []
|
101 |
+
for layer_module in self.layer:
|
102 |
+
kpi_emb, orth_loss, scalar = layer_module(kpi_emb, name_emb)
|
103 |
+
loss += orth_loss
|
104 |
+
scalar_list.append(scalar)
|
105 |
+
return kpi_emb, loss, scalar_list
|
106 |
+
|
107 |
+
|
108 |
+
class AttNumDecoder(nn.Module):
|
109 |
+
def __init__(self, config):
|
110 |
+
super(AttNumDecoder, self).__init__()
|
111 |
+
self.dense_1 = nn.Linear(config.hidden_size, config.hidden_size)
|
112 |
+
self.dense_2 = nn.Linear(config.hidden_size, 1)
|
113 |
+
self.actication = nn.LeakyReLU()
|
114 |
+
self.loss_func = nn.MSELoss(reduction='mean')
|
115 |
+
|
116 |
+
def forward(self, kpi_label, hidden):
|
117 |
+
# reconstruct the (possibly anomalous) KPI value
|
118 |
+
pre = self.actication(self.dense_2(self.actication(self.dense_1(hidden))))
|
119 |
+
loss = self.loss_func(pre, kpi_label)
|
120 |
+
# pdb.set_trace()
|
121 |
+
return pre, loss
|
122 |
+
|
123 |
+
|
124 |
+
class NumClassifier(nn.Module):
|
125 |
+
def __init__(self, config):
|
126 |
+
super(NumClassifier, self).__init__()
|
127 |
+
self.dense_1 = nn.Linear(config.hidden_size, int(config.hidden_size / 3))
|
128 |
+
self.dense_2 = nn.Linear(int(config.hidden_size / 3), config.num_kpi)
|
129 |
+
self.loss_func = nn.CrossEntropyLoss()
|
130 |
+
# self.actication = nn.ReLU()
|
131 |
+
self.actication = nn.LeakyReLU()
|
132 |
+
|
133 |
+
def forward(self, hidden, kpi_id):
|
134 |
+
hidden = self.actication(self.dense_1(hidden))
|
135 |
+
pre = self.actication(self.dense_2(hidden)).squeeze(1)
|
136 |
+
loss = self.loss_func(pre, kpi_id)
|
137 |
+
return pre, loss
|
138 |
+
|
139 |
+
|
140 |
+
class AttNumLayer(nn.Module):
|
141 |
+
def __init__(self, config):
|
142 |
+
super(AttNumLayer, self).__init__()
|
143 |
+
self.config = config
|
144 |
+
# 768 / 8 = 96
|
145 |
+
self.num_attention_heads = config.num_attention_heads
|
146 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads) # 96
|
147 |
+
# self.head_size = config.hidden_size
|
148 |
+
|
149 |
+
# scaler
|
150 |
+
self.scalar = nn.Parameter(.3 * torch.ones(1, requires_grad=True))
|
151 |
+
self.key = nn.Parameter(torch.empty(self.num_attention_heads, self.attention_head_size))
|
152 |
+
|
153 |
+
self.dense_down = nn.Linear(config.hidden_size, 128)
|
154 |
+
self.dense_up = nn.Linear(128, config.hidden_size)
|
155 |
+
|
156 |
+
# name embedding
|
157 |
+
self.embedding = nn.Linear(config.hidden_size, self.attention_head_size)
|
158 |
+
# project the input into num_attention_heads value vectors (original comment was mis-encoded)
|
159 |
+
self.value = nn.Linear(config.hidden_size, config.hidden_size * self.num_attention_heads)
|
160 |
+
|
161 |
+
# add & norm
|
162 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
163 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
|
164 |
+
# 0.1
|
165 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
166 |
+
|
167 |
+
# for m in self.modules().modules():
|
168 |
+
# pdb.set_trace()
|
169 |
+
|
170 |
+
nn.init.kaiming_normal_(self.key, mode='fan_out', nonlinearity='leaky_relu')
|
171 |
+
# nn.init.orthogonal_(self.key)
|
172 |
+
|
173 |
+
def transpose_for_scores(self, x):
|
174 |
+
new_x_shape = x.size()[:-1] + (
|
175 |
+
self.num_attention_heads,
|
176 |
+
self.config.hidden_size,
|
177 |
+
)
|
178 |
+
x = x.view(*new_x_shape)
|
179 |
+
return x
|
180 |
+
# return x.permute(0, 2, 1, 3)
|
181 |
+
|
182 |
+
def forward(self, kpi_emb, name_emb):
|
183 |
+
# [64, 1, 96]
|
184 |
+
name_emb = self.embedding(name_emb)
|
185 |
+
|
186 |
+
mixed_value_layer = self.value(kpi_emb)
|
187 |
+
|
188 |
+
# [64, 1, 8, 768]
|
189 |
+
value_layer = self.transpose_for_scores(mixed_value_layer)
|
190 |
+
|
191 |
+
# key: [8, 96] self.key.transpose(-1, -2): [96, 8]
|
192 |
+
# name_emb: [64, 1, 96]
|
193 |
+
attention_scores = torch.matmul(name_emb, self.key.transpose(-1, -2))
|
194 |
+
# [64, 1, 8]
|
195 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
196 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
197 |
+
|
198 |
+
# This is actually dropping out entire tokens to attend to, which might
|
199 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
200 |
+
attention_probs = self.dropout(attention_probs)
|
201 |
+
attention_probs = attention_probs.unsqueeze(1)
|
202 |
+
# weighted sum over the value heads
|
203 |
+
# [64, 1, 1, 8] * [64, 1, 8, 768] = [64, 1, 1, 768]
|
204 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
205 |
+
# context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
206 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.config.hidden_size,)
|
207 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
208 |
+
# add & norm
|
209 |
+
output_emb = self.dense(context_layer)
|
210 |
+
output_emb = self.dropout(output_emb)
|
211 |
+
output_emb = self.LayerNorm(output_emb + self.scalar * self.dense_up(self.dense_down(kpi_emb)))
|
212 |
+
# output_emb = self.LayerNorm(self.LayerNorm(output_emb) + self.scalar * kpi_emb)
|
213 |
+
# pdb.set_trace()
|
214 |
+
wei = self.value.weight.chunk(8, dim=0)
|
215 |
+
orth_loss_value = sum([ortho_penalty(k) for k in wei])
|
216 |
+
# 0.01 * ortho_penalty(self.key) + ortho_penalty(self.value.weight)
|
217 |
+
orth_loss = 0.0001 * orth_loss_value + 0.0001 * ortho_penalty(self.dense.weight) + 0.01 * ((self.scalar[0])**2).sum()
|
218 |
+
return output_emb, orth_loss, self.scalar.tolist()[0]
|
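The numeric modules above compose into one pipeline: AttNumLayer fuses a KPI value embedding with its pooled KPI-name embedding, AttNumDecoder regresses the raw value back from the fused feature, and NumClassifier recovers the KPI id. A minimal, illustrative sketch of that composition follows; the config field values and the import path are assumptions for illustration, not part of this commit.

import torch
from types import SimpleNamespace
from model.Numeric import AttNumLayer, AttNumDecoder, NumClassifier  # hypothetical import path

# assumed config values; num_kpi is the number of distinct KPI names
cfg = SimpleNamespace(hidden_size=768, num_attention_heads=8,
                      hidden_dropout_prob=0.1, num_kpi=100)
layer, decoder, classifier = AttNumLayer(cfg), AttNumDecoder(cfg), NumClassifier(cfg)

kpi_emb = torch.randn(64, 1, 768)    # embedding at the numeric [NUM] position
name_emb = torch.randn(64, 1, 768)   # pooled embedding of the KPI-name tokens
fused, orth_loss, scalar = layer(kpi_emb, name_emb)               # fused: [64, 1, 768]
value_pred, mse_loss = decoder(torch.randn(64, 1, 1), fused)      # regress the raw value
kpi_logits, ce_loss = classifier(fused, torch.randint(0, 100, (64,)))  # recover the KPI id

The orthogonality penalty and the squared residual scalar returned by AttNumLayer are meant to be added to the main training loss as regularizers.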
KTeleBERT/model/OD_model.py
ADDED
@@ -0,0 +1,74 @@
1 |
+
import os
|
2 |
+
import os.path as osp
|
3 |
+
import pdb
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import numpy as np
|
7 |
+
# from transformers import BertModel, BertTokenizer, BertForMaskedLM
|
8 |
+
import json
|
9 |
+
from packaging import version
|
10 |
+
import torch.distributed as dist
|
11 |
+
|
12 |
+
|
13 |
+
class OD_model(nn.Module):
|
14 |
+
def __init__(self, args):
|
15 |
+
super().__init__()
|
16 |
+
self.args = args
|
17 |
+
self.order_num = args.order_num
|
18 |
+
if args.od_type == 'linear_cat':
|
19 |
+
# self.order_dense_1 = nn.Linear(args.hidden_size * self.order_num, args.hidden_size)
|
20 |
+
# self.order_dense_2 = nn.Linear(args.hidden_size, 1)
|
21 |
+
self.order_dense_1 = nn.Linear(args.hidden_size * self.order_num, args.hidden_size)
|
22 |
+
if self.args.num_od_layer > 0:
|
23 |
+
self.layer = nn.ModuleList([OD_Layer_linear(args) for _ in range(args.num_od_layer)])
|
24 |
+
|
25 |
+
self.order_dense_2 = nn.Linear(args.hidden_size, 1)
|
26 |
+
|
27 |
+
self.actication = nn.LeakyReLU()
|
28 |
+
self.bn = torch.nn.BatchNorm1d(args.hidden_size)
|
29 |
+
self.dp = nn.Dropout(p=args.hidden_dropout_prob)
|
30 |
+
self.loss_func = nn.BCEWithLogitsLoss()
|
31 |
+
# self.loss_func = nn.CrossEntropyLoss()
|
32 |
+
|
33 |
+
def forward(self, input, labels):
|
34 |
+
# the input batch is split into two halves
|
35 |
+
# and re-concatenated along the feature dimension (see encode)
|
36 |
+
loss_dic = {}
|
37 |
+
pre = self.predict(input)
|
38 |
+
# pdb.set_trace()
|
39 |
+
loss = self.loss_func(pre, labels.unsqueeze(1))
|
40 |
+
loss_dic['order_loss'] = loss.item()
|
41 |
+
return loss, loss_dic
|
42 |
+
|
43 |
+
def encode(self, input):
|
44 |
+
if self.args.num_od_layer > 0:
|
45 |
+
for layer_module in self.layer:
|
46 |
+
input = layer_module(input)
|
47 |
+
inputs = torch.chunk(input, 2, dim=0)
|
48 |
+
emb = torch.concat(inputs, dim=1)
|
49 |
+
return self.actication(self.order_dense_1(self.dp(emb)))
|
50 |
+
|
51 |
+
def predict(self, input):
|
52 |
+
return self.order_dense_2(self.bn(self.encode(input)))
|
53 |
+
|
54 |
+
def right_caculate(self, input, labels, threshold=0.5):
|
55 |
+
input = input.squeeze(1).tolist()
|
56 |
+
labels = labels.tolist()
|
57 |
+
right = 0
|
58 |
+
for i in range(len(input)):
|
59 |
+
if (input[i] >= threshold and labels[i] >= 0.5) or (input[i] < threshold and labels[i] < 0.5):
|
60 |
+
right += 1
|
61 |
+
return right
|
62 |
+
|
63 |
+
|
64 |
+
class OD_Layer_linear(nn.Module):
|
65 |
+
def __init__(self, args):
|
66 |
+
super().__init__()
|
67 |
+
self.args = args
|
68 |
+
self.dense = nn.Linear(args.hidden_size, args.hidden_size)
|
69 |
+
self.actication = nn.LeakyReLU()
|
70 |
+
self.bn = torch.nn.BatchNorm1d(args.hidden_size)
|
71 |
+
self.dropout = nn.Dropout(p=args.hidden_dropout_prob)
|
72 |
+
|
73 |
+
def forward(self, input):
|
74 |
+
return self.actication(self.bn(self.dense(self.dropout(input))))
|
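OD_model scores whether two context embeddings occur in the right order: the first half of the batch is treated as the "left" segments, the second half as the "right" segments, the two halves are concatenated feature-wise and reduced to a single logit per pair. A small usage sketch under assumed hyper-parameters; the argument values and import path are illustrative only.

import torch
from types import SimpleNamespace
from model.OD_model import OD_model  # hypothetical import path

args = SimpleNamespace(order_num=2, od_type='linear_cat', hidden_size=768,
                       num_od_layer=1, hidden_dropout_prob=0.1)  # assumed values
model = OD_model(args)

# 4 pairs -> 8 embeddings: first half are the "left" segments, second half the "right" ones
doc_emb = torch.randn(8, 768)
labels = torch.randint(0, 2, (4,)).float()   # one 0/1 order label per pair
loss, loss_dic = model(doc_emb, labels)

Note that predict() returns raw logits; BCEWithLogitsLoss applies the sigmoid internally, so probabilities are only needed when thresholding with right_caculate.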
KTeleBERT/model/Tool_model.py
ADDED
@@ -0,0 +1,34 @@
1 |
+
# -*- coding: UTF-8 -*-
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
# https://github.com/Mikoto10032/AutomaticWeightedLoss/blob/master/AutomaticWeightedLoss.py
|
7 |
+
|
8 |
+
|
9 |
+
class AutomaticWeightedLoss(nn.Module):
|
10 |
+
# '''
|
11 |
+
# automatically weighted multi-task loss
|
12 |
+
# Params:
|
13 |
+
# num: int, the number of losses
|
14 |
+
# x: multi-task loss
|
15 |
+
# Examples:
|
16 |
+
# loss1=1
|
17 |
+
# loss2=2
|
18 |
+
# awl = AutomaticWeightedLoss(2)
|
19 |
+
# loss_sum = awl(loss1, loss2)
|
20 |
+
# '''
|
21 |
+
def __init__(self, num=2, args=None):
|
22 |
+
super(AutomaticWeightedLoss, self).__init__()
|
23 |
+
if args is None or args.use_awl:
|
24 |
+
params = torch.ones(num, requires_grad=True)
|
25 |
+
self.params = torch.nn.Parameter(params)
|
26 |
+
else:
|
27 |
+
params = torch.ones(num, requires_grad=False)
|
28 |
+
self.params = torch.nn.Parameter(params, requires_grad=False)
|
29 |
+
|
30 |
+
def forward(self, *x):
|
31 |
+
loss_sum = 0
|
32 |
+
for i, loss in enumerate(x):
|
33 |
+
loss_sum += 0.5 / (self.params[i] ** 2) * loss + torch.log(1 + self.params[i] ** 2)
|
34 |
+
return loss_sum
|
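AutomaticWeightedLoss implements uncertainty-style multi-task weighting: each task loss i is scaled by a learnable 0.5 / params[i]**2 and regularized by log(1 + params[i]**2). A minimal sketch of combining two task losses; the loss tensors and import path are placeholders.

import torch
from model.Tool_model import AutomaticWeightedLoss  # hypothetical import path

awl = AutomaticWeightedLoss(num=2)        # args=None -> the weights stay learnable
mlm_loss = torch.tensor(1.0)              # placeholder task losses
ke_loss = torch.tensor(2.0)
total = awl(mlm_loss, ke_loss)            # sum_i 0.5/p_i^2 * loss_i + log(1 + p_i^2)
total.backward()                          # gradients also flow into awl.params

In practice awl.parameters() has to be registered with the optimizer alongside the model parameters, otherwise the task weights are never updated.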
KTeleBERT/model/__init__.py
ADDED
@@ -0,0 +1,26 @@
1 |
+
# from .vector import Vector
|
2 |
+
# from .classifier import SimpleClassifier
|
3 |
+
# # from .updn import UpDn
|
4 |
+
# # from .ban import Ban
|
5 |
+
|
6 |
+
from .bert import (
|
7 |
+
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
8 |
+
BertForMaskedLM,
|
9 |
+
BertForMultipleChoice,
|
10 |
+
BertForNextSentencePrediction,
|
11 |
+
BertForPreTraining,
|
12 |
+
BertForQuestionAnswering,
|
13 |
+
BertForSequenceClassification,
|
14 |
+
BertForTokenClassification,
|
15 |
+
BertLayer,
|
16 |
+
BertLMHeadModel,
|
17 |
+
BertModel,
|
18 |
+
BertPreTrainedModel,
|
19 |
+
load_tf_weights_in_bert,
|
20 |
+
)
|
21 |
+
|
22 |
+
from .bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
|
23 |
+
from .bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer
|
24 |
+
from .HWBert import HWBert
|
25 |
+
from .KE_model import KGEModel, KE_model
|
26 |
+
from .OD_model import OD_model
|
KTeleBERT/model/__pycache__/HWBert.cpython-38.pyc
ADDED
Binary file (4.42 kB). View file
|
|
KTeleBERT/model/__pycache__/KE_model.cpython-38.pyc
ADDED
Binary file (12.3 kB). View file
|
|
KTeleBERT/model/__pycache__/Numeric.cpython-38.pyc
ADDED
Binary file (6.88 kB). View file
|
|
KTeleBERT/model/__pycache__/OD_model.cpython-38.pyc
ADDED
Binary file (2.95 kB). View file
|
|
KTeleBERT/model/__pycache__/Tool_model.cpython-38.pyc
ADDED
Binary file (1.01 kB). View file
|
|
KTeleBERT/model/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (849 Bytes). View file
|
|
KTeleBERT/model/bert/__init__.py
ADDED
@@ -0,0 +1,201 @@
1 |
+
# flake8: noqa
|
2 |
+
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
3 |
+
# module, but to preserve other warnings. So, don't check this module at all.
|
4 |
+
|
5 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
6 |
+
#
|
7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
8 |
+
# you may not use this file except in compliance with the License.
|
9 |
+
# You may obtain a copy of the License at
|
10 |
+
#
|
11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12 |
+
#
|
13 |
+
# Unless required by applicable law or agreed to in writing, software
|
14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
16 |
+
# See the License for the specific language governing permissions and
|
17 |
+
# limitations under the License.
|
18 |
+
|
19 |
+
from typing import TYPE_CHECKING
|
20 |
+
|
21 |
+
from transformers.utils import (
|
22 |
+
OptionalDependencyNotAvailable,
|
23 |
+
_LazyModule,
|
24 |
+
is_flax_available,
|
25 |
+
is_tensorflow_text_available,
|
26 |
+
is_tf_available,
|
27 |
+
is_tokenizers_available,
|
28 |
+
is_torch_available,
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
_import_structure = {
|
33 |
+
"configuration_bert": ["BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BertConfig", "BertOnnxConfig"],
|
34 |
+
"tokenization_bert": ["BasicTokenizer", "BertTokenizer", "WordpieceTokenizer"],
|
35 |
+
}
|
36 |
+
|
37 |
+
try:
|
38 |
+
if not is_tokenizers_available():
|
39 |
+
raise OptionalDependencyNotAvailable()
|
40 |
+
except OptionalDependencyNotAvailable:
|
41 |
+
pass
|
42 |
+
else:
|
43 |
+
_import_structure["tokenization_bert_fast"] = ["BertTokenizerFast"]
|
44 |
+
|
45 |
+
try:
|
46 |
+
if not is_torch_available():
|
47 |
+
raise OptionalDependencyNotAvailable()
|
48 |
+
except OptionalDependencyNotAvailable:
|
49 |
+
pass
|
50 |
+
else:
|
51 |
+
_import_structure["modeling_bert"] = [
|
52 |
+
"BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
53 |
+
"BertForMaskedLM",
|
54 |
+
"BertForMultipleChoice",
|
55 |
+
"BertForNextSentencePrediction",
|
56 |
+
"BertForPreTraining",
|
57 |
+
"BertForQuestionAnswering",
|
58 |
+
"BertForSequenceClassification",
|
59 |
+
"BertForTokenClassification",
|
60 |
+
"BertLayer",
|
61 |
+
"BertLMHeadModel",
|
62 |
+
"BertModel",
|
63 |
+
"BertPreTrainedModel",
|
64 |
+
"load_tf_weights_in_bert",
|
65 |
+
]
|
66 |
+
|
67 |
+
try:
|
68 |
+
if not is_tf_available():
|
69 |
+
raise OptionalDependencyNotAvailable()
|
70 |
+
except OptionalDependencyNotAvailable:
|
71 |
+
pass
|
72 |
+
else:
|
73 |
+
_import_structure["modeling_tf_bert"] = [
|
74 |
+
"TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
75 |
+
"TFBertEmbeddings",
|
76 |
+
"TFBertForMaskedLM",
|
77 |
+
"TFBertForMultipleChoice",
|
78 |
+
"TFBertForNextSentencePrediction",
|
79 |
+
"TFBertForPreTraining",
|
80 |
+
"TFBertForQuestionAnswering",
|
81 |
+
"TFBertForSequenceClassification",
|
82 |
+
"TFBertForTokenClassification",
|
83 |
+
"TFBertLMHeadModel",
|
84 |
+
"TFBertMainLayer",
|
85 |
+
"TFBertModel",
|
86 |
+
"TFBertPreTrainedModel",
|
87 |
+
]
|
88 |
+
try:
|
89 |
+
if not is_tensorflow_text_available():
|
90 |
+
raise OptionalDependencyNotAvailable()
|
91 |
+
except OptionalDependencyNotAvailable:
|
92 |
+
pass
|
93 |
+
else:
|
94 |
+
_import_structure["tokenization_bert_tf"] = ["TFBertTokenizer"]
|
95 |
+
|
96 |
+
try:
|
97 |
+
if not is_flax_available():
|
98 |
+
raise OptionalDependencyNotAvailable()
|
99 |
+
except OptionalDependencyNotAvailable:
|
100 |
+
pass
|
101 |
+
else:
|
102 |
+
_import_structure["modeling_flax_bert"] = [
|
103 |
+
"FlaxBertForCausalLM",
|
104 |
+
"FlaxBertForMaskedLM",
|
105 |
+
"FlaxBertForMultipleChoice",
|
106 |
+
"FlaxBertForNextSentencePrediction",
|
107 |
+
"FlaxBertForPreTraining",
|
108 |
+
"FlaxBertForQuestionAnswering",
|
109 |
+
"FlaxBertForSequenceClassification",
|
110 |
+
"FlaxBertForTokenClassification",
|
111 |
+
"FlaxBertModel",
|
112 |
+
"FlaxBertPreTrainedModel",
|
113 |
+
]
|
114 |
+
|
115 |
+
if TYPE_CHECKING:
|
116 |
+
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig, BertOnnxConfig
|
117 |
+
from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer
|
118 |
+
|
119 |
+
try:
|
120 |
+
if not is_tokenizers_available():
|
121 |
+
raise OptionalDependencyNotAvailable()
|
122 |
+
except OptionalDependencyNotAvailable:
|
123 |
+
pass
|
124 |
+
else:
|
125 |
+
from .tokenization_bert_fast import BertTokenizerFast
|
126 |
+
|
127 |
+
try:
|
128 |
+
if not is_torch_available():
|
129 |
+
raise OptionalDependencyNotAvailable()
|
130 |
+
except OptionalDependencyNotAvailable:
|
131 |
+
pass
|
132 |
+
else:
|
133 |
+
from .modeling_bert import (
|
134 |
+
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
135 |
+
BertForMaskedLM,
|
136 |
+
BertForMultipleChoice,
|
137 |
+
BertForNextSentencePrediction,
|
138 |
+
BertForPreTraining,
|
139 |
+
BertForQuestionAnswering,
|
140 |
+
BertForSequenceClassification,
|
141 |
+
BertForTokenClassification,
|
142 |
+
BertLayer,
|
143 |
+
BertLMHeadModel,
|
144 |
+
BertModel,
|
145 |
+
BertPreTrainedModel,
|
146 |
+
load_tf_weights_in_bert,
|
147 |
+
)
|
148 |
+
|
149 |
+
try:
|
150 |
+
if not is_tf_available():
|
151 |
+
raise OptionalDependencyNotAvailable()
|
152 |
+
except OptionalDependencyNotAvailable:
|
153 |
+
pass
|
154 |
+
else:
|
155 |
+
from .modeling_tf_bert import (
|
156 |
+
TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
157 |
+
TFBertEmbeddings,
|
158 |
+
TFBertForMaskedLM,
|
159 |
+
TFBertForMultipleChoice,
|
160 |
+
TFBertForNextSentencePrediction,
|
161 |
+
TFBertForPreTraining,
|
162 |
+
TFBertForQuestionAnswering,
|
163 |
+
TFBertForSequenceClassification,
|
164 |
+
TFBertForTokenClassification,
|
165 |
+
TFBertLMHeadModel,
|
166 |
+
TFBertMainLayer,
|
167 |
+
TFBertModel,
|
168 |
+
TFBertPreTrainedModel,
|
169 |
+
)
|
170 |
+
|
171 |
+
try:
|
172 |
+
if not is_tensorflow_text_available():
|
173 |
+
raise OptionalDependencyNotAvailable()
|
174 |
+
except OptionalDependencyNotAvailable:
|
175 |
+
pass
|
176 |
+
else:
|
177 |
+
from .tokenization_bert_tf import TFBertTokenizer
|
178 |
+
|
179 |
+
try:
|
180 |
+
if not is_flax_available():
|
181 |
+
raise OptionalDependencyNotAvailable()
|
182 |
+
except OptionalDependencyNotAvailable:
|
183 |
+
pass
|
184 |
+
else:
|
185 |
+
from .modeling_flax_bert import (
|
186 |
+
FlaxBertForCausalLM,
|
187 |
+
FlaxBertForMaskedLM,
|
188 |
+
FlaxBertForMultipleChoice,
|
189 |
+
FlaxBertForNextSentencePrediction,
|
190 |
+
FlaxBertForPreTraining,
|
191 |
+
FlaxBertForQuestionAnswering,
|
192 |
+
FlaxBertForSequenceClassification,
|
193 |
+
FlaxBertForTokenClassification,
|
194 |
+
FlaxBertModel,
|
195 |
+
FlaxBertPreTrainedModel,
|
196 |
+
)
|
197 |
+
|
198 |
+
else:
|
199 |
+
import sys
|
200 |
+
|
201 |
+
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
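This __init__.py mirrors the upstream HuggingFace lazy-module layout, so the vendored BERT classes can be imported by name while each submodule (PyTorch, TF, Flax) is only loaded when one of its attributes is first accessed. A small sketch, assuming the KTeleBERT directory is on sys.path:

from model.bert import BertConfig, BertForMaskedLM  # resolved through _LazyModule

config = BertConfig()               # defaults correspond to bert-base-uncased
model = BertForMaskedLM(config)     # randomly initialized; modeling_bert is imported here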
KTeleBERT/model/bert/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (3.22 kB). View file
|
|
KTeleBERT/model/bert/__pycache__/configuration_bert.cpython-38.pyc
ADDED
Binary file (8.86 kB). View file
|
|
KTeleBERT/model/bert/__pycache__/modeling_bert.cpython-38.pyc
ADDED
Binary file (57.4 kB). View file
|
|
KTeleBERT/model/bert/__pycache__/tokenization_bert.cpython-38.pyc
ADDED
Binary file (19 kB). View file
|
|
KTeleBERT/model/bert/configuration_bert.py
ADDED
@@ -0,0 +1,191 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
""" BERT model configuration"""
|
17 |
+
from collections import OrderedDict
|
18 |
+
from typing import Mapping
|
19 |
+
|
20 |
+
from transformers.configuration_utils import PretrainedConfig
|
21 |
+
from transformers.onnx import OnnxConfig
|
22 |
+
from transformers.utils import logging
|
23 |
+
|
24 |
+
|
25 |
+
logger = logging.get_logger(__name__)
|
26 |
+
|
27 |
+
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
28 |
+
"bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/config.json",
|
29 |
+
"bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/config.json",
|
30 |
+
"bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/config.json",
|
31 |
+
"bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/config.json",
|
32 |
+
"bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/config.json",
|
33 |
+
"bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/config.json",
|
34 |
+
"bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/config.json",
|
35 |
+
"bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/config.json",
|
36 |
+
"bert-large-uncased-whole-word-masking": (
|
37 |
+
"https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/config.json"
|
38 |
+
),
|
39 |
+
"bert-large-cased-whole-word-masking": (
|
40 |
+
"https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/config.json"
|
41 |
+
),
|
42 |
+
"bert-large-uncased-whole-word-masking-finetuned-squad": (
|
43 |
+
"https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/config.json"
|
44 |
+
),
|
45 |
+
"bert-large-cased-whole-word-masking-finetuned-squad": (
|
46 |
+
"https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/config.json"
|
47 |
+
),
|
48 |
+
"bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/config.json",
|
49 |
+
"bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/config.json",
|
50 |
+
"bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/config.json",
|
51 |
+
"cl-tohoku/bert-base-japanese": "https://huggingface.co/cl-tohoku/bert-base-japanese/resolve/main/config.json",
|
52 |
+
"cl-tohoku/bert-base-japanese-whole-word-masking": (
|
53 |
+
"https://huggingface.co/cl-tohoku/bert-base-japanese-whole-word-masking/resolve/main/config.json"
|
54 |
+
),
|
55 |
+
"cl-tohoku/bert-base-japanese-char": (
|
56 |
+
"https://huggingface.co/cl-tohoku/bert-base-japanese-char/resolve/main/config.json"
|
57 |
+
),
|
58 |
+
"cl-tohoku/bert-base-japanese-char-whole-word-masking": (
|
59 |
+
"https://huggingface.co/cl-tohoku/bert-base-japanese-char-whole-word-masking/resolve/main/config.json"
|
60 |
+
),
|
61 |
+
"TurkuNLP/bert-base-finnish-cased-v1": (
|
62 |
+
"https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/config.json"
|
63 |
+
),
|
64 |
+
"TurkuNLP/bert-base-finnish-uncased-v1": (
|
65 |
+
"https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/config.json"
|
66 |
+
),
|
67 |
+
"wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/config.json",
|
68 |
+
# See all BERT models at https://huggingface.co/models?filter=bert
|
69 |
+
}
|
70 |
+
|
71 |
+
|
72 |
+
class BertConfig(PretrainedConfig):
|
73 |
+
r"""
|
74 |
+
This is the configuration class to store the configuration of a [`BertModel`] or a [`TFBertModel`]. It is used to
|
75 |
+
instantiate a BERT model according to the specified arguments, defining the model architecture. Instantiating a
|
76 |
+
configuration with the defaults will yield a similar configuration to that of the BERT
|
77 |
+
[bert-base-uncased](https://huggingface.co/bert-base-uncased) architecture.
|
78 |
+
|
79 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
80 |
+
documentation from [`PretrainedConfig`] for more information.
|
81 |
+
|
82 |
+
|
83 |
+
Args:
|
84 |
+
vocab_size (`int`, *optional*, defaults to 30522):
|
85 |
+
Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
|
86 |
+
`inputs_ids` passed when calling [`BertModel`] or [`TFBertModel`].
|
87 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
88 |
+
Dimensionality of the encoder layers and the pooler layer.
|
89 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
90 |
+
Number of hidden layers in the Transformer encoder.
|
91 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
92 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
93 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
94 |
+
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
|
95 |
+
hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
|
96 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
97 |
+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
98 |
+
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
|
99 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
100 |
+
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
|
101 |
+
The dropout ratio for the attention probabilities.
|
102 |
+
max_position_embeddings (`int`, *optional*, defaults to 512):
|
103 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
104 |
+
just in case (e.g., 512 or 1024 or 2048).
|
105 |
+
type_vocab_size (`int`, *optional*, defaults to 2):
|
106 |
+
The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or [`TFBertModel`].
|
107 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
108 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
109 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
110 |
+
The epsilon used by the layer normalization layers.
|
111 |
+
position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
|
112 |
+
Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
|
113 |
+
positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
|
114 |
+
[Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
|
115 |
+
For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
|
116 |
+
with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
|
117 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
118 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
119 |
+
relevant if `config.is_decoder=True`.
|
120 |
+
classifier_dropout (`float`, *optional*):
|
121 |
+
The dropout ratio for the classification head.
|
122 |
+
|
123 |
+
Examples:
|
124 |
+
|
125 |
+
```python
|
126 |
+
>>> from transformers import BertModel, BertConfig
|
127 |
+
|
128 |
+
>>> # Initializing a BERT bert-base-uncased style configuration
|
129 |
+
>>> configuration = BertConfig()
|
130 |
+
|
131 |
+
>>> # Initializing a model from the bert-base-uncased style configuration
|
132 |
+
>>> model = BertModel(configuration)
|
133 |
+
|
134 |
+
>>> # Accessing the model configuration
|
135 |
+
>>> configuration = model.config
|
136 |
+
```"""
|
137 |
+
model_type = "bert"
|
138 |
+
|
139 |
+
def __init__(
|
140 |
+
self,
|
141 |
+
vocab_size=30522,
|
142 |
+
hidden_size=768,
|
143 |
+
num_hidden_layers=12,
|
144 |
+
num_attention_heads=12,
|
145 |
+
intermediate_size=3072,
|
146 |
+
hidden_act="gelu",
|
147 |
+
hidden_dropout_prob=0.1,
|
148 |
+
attention_probs_dropout_prob=0.1,
|
149 |
+
max_position_embeddings=512,
|
150 |
+
type_vocab_size=2,
|
151 |
+
initializer_range=0.02,
|
152 |
+
layer_norm_eps=1e-12,
|
153 |
+
pad_token_id=0,
|
154 |
+
position_embedding_type="absolute",
|
155 |
+
use_cache=True,
|
156 |
+
classifier_dropout=None,
|
157 |
+
**kwargs
|
158 |
+
):
|
159 |
+
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
160 |
+
|
161 |
+
self.vocab_size = vocab_size
|
162 |
+
self.hidden_size = hidden_size
|
163 |
+
self.num_hidden_layers = num_hidden_layers
|
164 |
+
self.num_attention_heads = num_attention_heads
|
165 |
+
self.hidden_act = hidden_act
|
166 |
+
self.intermediate_size = intermediate_size
|
167 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
168 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
169 |
+
self.max_position_embeddings = max_position_embeddings
|
170 |
+
self.type_vocab_size = type_vocab_size
|
171 |
+
self.initializer_range = initializer_range
|
172 |
+
self.layer_norm_eps = layer_norm_eps
|
173 |
+
self.position_embedding_type = position_embedding_type
|
174 |
+
self.use_cache = use_cache
|
175 |
+
self.classifier_dropout = classifier_dropout
|
176 |
+
|
177 |
+
|
178 |
+
class BertOnnxConfig(OnnxConfig):
|
179 |
+
@property
|
180 |
+
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
181 |
+
if self.task == "multiple-choice":
|
182 |
+
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
183 |
+
else:
|
184 |
+
dynamic_axis = {0: "batch", 1: "sequence"}
|
185 |
+
return OrderedDict(
|
186 |
+
[
|
187 |
+
("input_ids", dynamic_axis),
|
188 |
+
("attention_mask", dynamic_axis),
|
189 |
+
("token_type_ids", dynamic_axis),
|
190 |
+
]
|
191 |
+
)
|
KTeleBERT/model/bert/modeling_bert.py
ADDED
@@ -0,0 +1,2010 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""PyTorch BERT model."""
|
17 |
+
|
18 |
+
import pdb
|
19 |
+
import math
|
20 |
+
import os
|
21 |
+
import warnings
|
22 |
+
from dataclasses import dataclass
|
23 |
+
from typing import List, Optional, Tuple, Union
|
24 |
+
|
25 |
+
import torch
|
26 |
+
import torch.utils.checkpoint
|
27 |
+
from packaging import version
|
28 |
+
from torch import nn
|
29 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
30 |
+
|
31 |
+
from transformers.activations import ACT2FN
|
32 |
+
|
33 |
+
from transformers.modeling_outputs import (
|
34 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
35 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
36 |
+
CausalLMOutputWithCrossAttentions,
|
37 |
+
MaskedLMOutput,
|
38 |
+
MultipleChoiceModelOutput,
|
39 |
+
NextSentencePredictorOutput,
|
40 |
+
QuestionAnsweringModelOutput,
|
41 |
+
SequenceClassifierOutput,
|
42 |
+
TokenClassifierOutput,
|
43 |
+
)
|
44 |
+
from transformers.modeling_utils import PreTrainedModel
|
45 |
+
from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
46 |
+
from transformers.utils import (
|
47 |
+
ModelOutput,
|
48 |
+
add_code_sample_docstrings,
|
49 |
+
add_start_docstrings,
|
50 |
+
add_start_docstrings_to_model_forward,
|
51 |
+
logging,
|
52 |
+
replace_return_docstrings,
|
53 |
+
)
|
54 |
+
from .configuration_bert import BertConfig
|
55 |
+
|
56 |
+
|
57 |
+
logger = logging.get_logger(__name__)
|
58 |
+
|
59 |
+
_CHECKPOINT_FOR_DOC = "bert-base-uncased"
|
60 |
+
_CONFIG_FOR_DOC = "BertConfig"
|
61 |
+
_TOKENIZER_FOR_DOC = "BertTokenizer"
|
62 |
+
|
63 |
+
# TokenClassification docstring
|
64 |
+
_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english"
|
65 |
+
_TOKEN_CLASS_EXPECTED_OUTPUT = (
|
66 |
+
"['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] "
|
67 |
+
)
|
68 |
+
_TOKEN_CLASS_EXPECTED_LOSS = 0.01
|
69 |
+
|
70 |
+
# QuestionAnswering docstring
|
71 |
+
_CHECKPOINT_FOR_QA = "deepset/bert-base-cased-squad2"
|
72 |
+
_QA_EXPECTED_OUTPUT = "'a nice puppet'"
|
73 |
+
_QA_EXPECTED_LOSS = 7.41
|
74 |
+
_QA_TARGET_START_INDEX = 14
|
75 |
+
_QA_TARGET_END_INDEX = 15
|
76 |
+
|
77 |
+
# SequenceClassification docstring
|
78 |
+
_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-polarity"
|
79 |
+
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
|
80 |
+
_SEQ_CLASS_EXPECTED_LOSS = 0.01
|
81 |
+
|
82 |
+
|
83 |
+
BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
84 |
+
"bert-base-uncased",
|
85 |
+
"bert-large-uncased",
|
86 |
+
"bert-base-cased",
|
87 |
+
"bert-large-cased",
|
88 |
+
"bert-base-multilingual-uncased",
|
89 |
+
"bert-base-multilingual-cased",
|
90 |
+
"bert-base-chinese",
|
91 |
+
"bert-base-german-cased",
|
92 |
+
"bert-large-uncased-whole-word-masking",
|
93 |
+
"bert-large-cased-whole-word-masking",
|
94 |
+
"bert-large-uncased-whole-word-masking-finetuned-squad",
|
95 |
+
"bert-large-cased-whole-word-masking-finetuned-squad",
|
96 |
+
"bert-base-cased-finetuned-mrpc",
|
97 |
+
"bert-base-german-dbmdz-cased",
|
98 |
+
"bert-base-german-dbmdz-uncased",
|
99 |
+
"cl-tohoku/bert-base-japanese",
|
100 |
+
"cl-tohoku/bert-base-japanese-whole-word-masking",
|
101 |
+
"cl-tohoku/bert-base-japanese-char",
|
102 |
+
"cl-tohoku/bert-base-japanese-char-whole-word-masking",
|
103 |
+
"TurkuNLP/bert-base-finnish-cased-v1",
|
104 |
+
"TurkuNLP/bert-base-finnish-uncased-v1",
|
105 |
+
"wietsedv/bert-base-dutch-cased",
|
106 |
+
# See all BERT models at https://huggingface.co/models?filter=bert
|
107 |
+
]
|
108 |
+
|
109 |
+
|
110 |
+
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
111 |
+
"""Load tf checkpoints in a pytorch model."""
|
112 |
+
try:
|
113 |
+
import re
|
114 |
+
|
115 |
+
import numpy as np
|
116 |
+
import tensorflow as tf
|
117 |
+
except ImportError:
|
118 |
+
logger.error(
|
119 |
+
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
120 |
+
"https://www.tensorflow.org/install/ for installation instructions."
|
121 |
+
)
|
122 |
+
raise
|
123 |
+
tf_path = os.path.abspath(tf_checkpoint_path)
|
124 |
+
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
|
125 |
+
# Load weights from TF model
|
126 |
+
init_vars = tf.train.list_variables(tf_path)
|
127 |
+
names = []
|
128 |
+
arrays = []
|
129 |
+
for name, shape in init_vars:
|
130 |
+
logger.info(f"Loading TF weight {name} with shape {shape}")
|
131 |
+
array = tf.train.load_variable(tf_path, name)
|
132 |
+
names.append(name)
|
133 |
+
arrays.append(array)
|
134 |
+
|
135 |
+
for name, array in zip(names, arrays):
|
136 |
+
name = name.split("/")
|
137 |
+
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
138 |
+
# which are not required for using pretrained model
|
139 |
+
if any(
|
140 |
+
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
|
141 |
+
for n in name
|
142 |
+
):
|
143 |
+
logger.info(f"Skipping {'/'.join(name)}")
|
144 |
+
continue
|
145 |
+
pointer = model
|
146 |
+
for m_name in name:
|
147 |
+
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
|
148 |
+
scope_names = re.split(r"_(\d+)", m_name)
|
149 |
+
else:
|
150 |
+
scope_names = [m_name]
|
151 |
+
if scope_names[0] == "kernel" or scope_names[0] == "gamma":
|
152 |
+
pointer = getattr(pointer, "weight")
|
153 |
+
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
|
154 |
+
pointer = getattr(pointer, "bias")
|
155 |
+
elif scope_names[0] == "output_weights":
|
156 |
+
pointer = getattr(pointer, "weight")
|
157 |
+
elif scope_names[0] == "squad":
|
158 |
+
pointer = getattr(pointer, "classifier")
|
159 |
+
else:
|
160 |
+
try:
|
161 |
+
pointer = getattr(pointer, scope_names[0])
|
162 |
+
except AttributeError:
|
163 |
+
logger.info(f"Skipping {'/'.join(name)}")
|
164 |
+
continue
|
165 |
+
if len(scope_names) >= 2:
|
166 |
+
num = int(scope_names[1])
|
167 |
+
pointer = pointer[num]
|
168 |
+
if m_name[-11:] == "_embeddings":
|
169 |
+
pointer = getattr(pointer, "weight")
|
170 |
+
elif m_name == "kernel":
|
171 |
+
array = np.transpose(array)
|
172 |
+
try:
|
173 |
+
if pointer.shape != array.shape:
|
174 |
+
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
|
175 |
+
except AssertionError as e:
|
176 |
+
e.args += (pointer.shape, array.shape)
|
177 |
+
raise
|
178 |
+
logger.info(f"Initialize PyTorch weight {name}")
|
179 |
+
pointer.data = torch.from_numpy(array)
|
180 |
+
return model
|
181 |
+
|
182 |
+
|
183 |
+
class BertEmbeddings(nn.Module):
|
184 |
+
"""Construct the embeddings from word, position and token_type embeddings."""
|
185 |
+
|
186 |
+
def __init__(self, config):
|
187 |
+
super().__init__()
|
188 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
189 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
190 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
191 |
+
|
192 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
193 |
+
# any TensorFlow checkpoint file
|
194 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
195 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
196 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
197 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
198 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
199 |
+
if version.parse(torch.__version__) > version.parse("1.6.0"):
|
200 |
+
self.register_buffer(
|
201 |
+
"token_type_ids",
|
202 |
+
torch.zeros(self.position_ids.size(), dtype=torch.long),
|
203 |
+
persistent=False,
|
204 |
+
)
|
205 |
+
|
206 |
+
def forward(
|
207 |
+
self,
|
208 |
+
input_ids: Optional[torch.LongTensor] = None,
|
209 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
210 |
+
position_ids: Optional[torch.LongTensor] = None,
|
211 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
212 |
+
past_key_values_length: int = 0,
|
213 |
+
kpi_ref = None,  # positions where KPI numeric values are substituted, plus the reference KPI name span, value and category (id)
|
214 |
+
kpi_model = None,
|
215 |
+
) -> torch.Tensor:
|
216 |
+
if input_ids is not None:
|
217 |
+
input_shape = input_ids.size()
|
218 |
+
else:
|
219 |
+
input_shape = inputs_embeds.size()[:-1]
|
220 |
+
|
221 |
+
seq_length = input_shape[1]
|
222 |
+
|
223 |
+
if position_ids is None:
|
224 |
+
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
225 |
+
|
226 |
+
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
|
227 |
+
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
|
228 |
+
# issue #5664
|
229 |
+
if token_type_ids is None:
|
230 |
+
if hasattr(self, "token_type_ids"):
|
231 |
+
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
|
232 |
+
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
|
233 |
+
token_type_ids = buffered_token_type_ids_expanded
|
234 |
+
else:
|
235 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
236 |
+
|
237 |
+
|
238 |
+
if inputs_embeds is None:
|
239 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
240 |
+
|
241 |
+
# pdb.set_trace()
|
242 |
+
# TODO: pool the KPI-name embedding, feed it together with the numeric value into the numeric encoder,
|
243 |
+
# and replace the original vector at the masked [NUM] position; no new embedding table is created, the generated embedding is used directly
|
244 |
+
en_loss, scalar_list, con_loss, numeric_input, kpi_input = None, None, None, None, None
|
245 |
+
if kpi_ref is not None:
|
246 |
+
max_len = inputs_embeds.shape[1]
|
247 |
+
# build the numeric-value embeddings
|
248 |
+
numeric_list = []
|
249 |
+
kpi_emb_list = []
|
250 |
+
kpi_id_list = []
|
251 |
+
for i in range(len(kpi_ref)):
|
252 |
+
if len(kpi_ref[i])>0:
|
253 |
+
for item in kpi_ref[i]:
|
254 |
+
# the [NUM] token may have been truncated away by max_len
|
255 |
+
if item[2]>=max_len:
|
256 |
+
continue
|
257 |
+
numeric_list.append(item[4])
|
258 |
+
kpi_id_list.append(item[3])
|
259 |
+
# requires_grad=True
|
260 |
+
kpi_name_embedding = torch.mean(inputs_embeds[i][item[0]:item[1]+1], dim=0)
|
261 |
+
kpi_emb_list.append(kpi_name_embedding)
|
262 |
+
# the batch may contain no KPI at all
|
263 |
+
if len(kpi_emb_list)>0:
|
264 |
+
kpi_emb = torch.stack(kpi_emb_list).unsqueeze(1)
|
265 |
+
|
266 |
+
# , dtype=torch.float64
|
267 |
+
numeric_input = torch.Tensor(numeric_list).unsqueeze(1).unsqueeze(1).cuda()
|
268 |
+
kpi_input = torch.tensor(kpi_id_list, dtype=torch.long).cuda()
|
269 |
+
# pdb.set_trace()
|
270 |
+
hidden, en_loss, scalar_list, con_loss = kpi_model._encode(numeric_input, kpi_emb)
|
271 |
+
# substitute the encoded KPI features back into the input embeddings
|
272 |
+
key = 0
|
273 |
+
for i in range(len(kpi_ref)):
|
274 |
+
if len(kpi_ref[i])>0:
|
275 |
+
for item in kpi_ref[i]:
|
276 |
+
if item[2]>=max_len:
|
277 |
+
continue
|
278 |
+
# (sample index, token index) position of the [NUM] token
|
279 |
+
inputs_embeds[i,item[2]] = hidden[key][0]
|
280 |
+
key += 1
|
281 |
+
assert key == hidden.shape[0]
|
282 |
+
|
283 |
+
|
284 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
285 |
+
|
286 |
+
embeddings = inputs_embeds + token_type_embeddings
|
287 |
+
if self.position_embedding_type == "absolute":
|
288 |
+
position_embeddings = self.position_embeddings(position_ids)
|
289 |
+
embeddings += position_embeddings
|
290 |
+
embeddings = self.LayerNorm(embeddings)
|
291 |
+
embeddings = self.dropout(embeddings)
|
292 |
+
# extended output: also return the numeric-encoding losses and inputs
|
293 |
+
return embeddings, en_loss, scalar_list, con_loss, numeric_input, kpi_input
|
294 |
+
# return embeddings
|
295 |
+
|
296 |
+
|
297 |
+
class BertSelfAttention(nn.Module):
|
298 |
+
def __init__(self, config, position_embedding_type=None):
|
299 |
+
super().__init__()
|
300 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
301 |
+
raise ValueError(
|
302 |
+
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
303 |
+
f"heads ({config.num_attention_heads})"
|
304 |
+
)
|
305 |
+
|
306 |
+
self.num_attention_heads = config.num_attention_heads
|
307 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
308 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
309 |
+
|
310 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
311 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
312 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
313 |
+
|
314 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
315 |
+
self.position_embedding_type = position_embedding_type or getattr(
|
316 |
+
config, "position_embedding_type", "absolute"
|
317 |
+
)
|
318 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
319 |
+
self.max_position_embeddings = config.max_position_embeddings
|
320 |
+
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
321 |
+
|
322 |
+
self.is_decoder = config.is_decoder
|
323 |
+
|
324 |
+
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
325 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
326 |
+
x = x.view(new_x_shape)
|
327 |
+
return x.permute(0, 2, 1, 3)
|
328 |
+
|
329 |
+
def forward(
|
330 |
+
self,
|
331 |
+
hidden_states: torch.Tensor,
|
332 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
333 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
334 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
335 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
336 |
+
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
337 |
+
output_attentions: Optional[bool] = False,
|
338 |
+
) -> Tuple[torch.Tensor]:
|
339 |
+
mixed_query_layer = self.query(hidden_states)
|
340 |
+
|
341 |
+
# If this is instantiated as a cross-attention module, the keys
|
342 |
+
# and values come from an encoder; the attention mask needs to be
|
343 |
+
# such that the encoder's padding tokens are not attended to.
|
344 |
+
is_cross_attention = encoder_hidden_states is not None
|
345 |
+
|
346 |
+
if is_cross_attention and past_key_value is not None:
|
347 |
+
# reuse k,v, cross_attentions
|
348 |
+
key_layer = past_key_value[0]
|
349 |
+
value_layer = past_key_value[1]
|
350 |
+
attention_mask = encoder_attention_mask
|
351 |
+
elif is_cross_attention:
|
352 |
+
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
353 |
+
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
354 |
+
attention_mask = encoder_attention_mask
|
355 |
+
elif past_key_value is not None:
|
356 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
357 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
358 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
359 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
360 |
+
else:
|
361 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
362 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
363 |
+
|
364 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
365 |
+
|
366 |
+
if self.is_decoder:
|
367 |
+
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
368 |
+
# Further calls to cross_attention layer can then reuse all cross-attention
|
369 |
+
# key/value_states (first "if" case)
|
370 |
+
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
371 |
+
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
372 |
+
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
373 |
+
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
374 |
+
past_key_value = (key_layer, value_layer)
|
375 |
+
|
376 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
377 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
378 |
+
|
379 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
380 |
+
seq_length = hidden_states.size()[1]
|
381 |
+
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
382 |
+
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
383 |
+
distance = position_ids_l - position_ids_r
|
384 |
+
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
385 |
+
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
386 |
+
|
387 |
+
if self.position_embedding_type == "relative_key":
|
388 |
+
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
389 |
+
attention_scores = attention_scores + relative_position_scores
|
390 |
+
elif self.position_embedding_type == "relative_key_query":
|
391 |
+
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
392 |
+
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
393 |
+
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
394 |
+
|
395 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
396 |
+
if attention_mask is not None:
|
397 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
398 |
+
attention_scores = attention_scores + attention_mask
|
399 |
+
|
400 |
+
# Normalize the attention scores to probabilities.
|
401 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
402 |
+
|
403 |
+
# This is actually dropping out entire tokens to attend to, which might
|
404 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
405 |
+
attention_probs = self.dropout(attention_probs)
|
406 |
+
|
407 |
+
# Mask heads if we want to
|
408 |
+
if head_mask is not None:
|
409 |
+
attention_probs = attention_probs * head_mask
|
410 |
+
|
411 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
412 |
+
|
413 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
414 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
415 |
+
context_layer = context_layer.view(new_context_layer_shape)
|
416 |
+
|
417 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
418 |
+
|
419 |
+
if self.is_decoder:
|
420 |
+
outputs = outputs + (past_key_value,)
|
421 |
+
return outputs
|
422 |
+
|
423 |
+
|
424 |
+
class BertSelfOutput(nn.Module):
|
425 |
+
def __init__(self, config):
|
426 |
+
super().__init__()
|
427 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
428 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
429 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
430 |
+
|
431 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
432 |
+
hidden_states = self.dense(hidden_states)
|
433 |
+
hidden_states = self.dropout(hidden_states)
|
434 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
435 |
+
return hidden_states
|
436 |
+
|
437 |
+
|
438 |
+
class BertAttention(nn.Module):
|
439 |
+
def __init__(self, config, position_embedding_type=None):
|
440 |
+
super().__init__()
|
441 |
+
self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
|
442 |
+
self.output = BertSelfOutput(config)
|
443 |
+
self.pruned_heads = set()
|
444 |
+
|
445 |
+
def prune_heads(self, heads):
|
446 |
+
if len(heads) == 0:
|
447 |
+
return
|
448 |
+
heads, index = find_pruneable_heads_and_indices(
|
449 |
+
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
450 |
+
)
|
451 |
+
|
452 |
+
# Prune linear layers
|
453 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
454 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
455 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
456 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
457 |
+
|
458 |
+
# Update hyper params and store pruned heads
|
459 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
460 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
461 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
462 |
+
|
463 |
+
def forward(
|
464 |
+
self,
|
465 |
+
hidden_states: torch.Tensor,
|
466 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
467 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
468 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
469 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
470 |
+
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
471 |
+
output_attentions: Optional[bool] = False,
|
472 |
+
) -> Tuple[torch.Tensor]:
|
473 |
+
self_outputs = self.self(
|
474 |
+
hidden_states,
|
475 |
+
attention_mask,
|
476 |
+
head_mask,
|
477 |
+
encoder_hidden_states,
|
478 |
+
encoder_attention_mask,
|
479 |
+
past_key_value,
|
480 |
+
output_attentions,
|
481 |
+
)
|
482 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
483 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
484 |
+
return outputs
|
485 |
+
|
486 |
+
|
487 |
+
class BertIntermediate(nn.Module):
|
488 |
+
def __init__(self, config):
|
489 |
+
super().__init__()
|
490 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
491 |
+
if isinstance(config.hidden_act, str):
|
492 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
493 |
+
else:
|
494 |
+
self.intermediate_act_fn = config.hidden_act
|
495 |
+
|
496 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
497 |
+
hidden_states = self.dense(hidden_states)
|
498 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
499 |
+
return hidden_states
|
500 |
+
|
501 |
+
|
502 |
+
class BertOutput(nn.Module):
|
503 |
+
def __init__(self, config):
|
504 |
+
super().__init__()
|
505 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
506 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
507 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
508 |
+
|
509 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
510 |
+
hidden_states = self.dense(hidden_states)
|
511 |
+
hidden_states = self.dropout(hidden_states)
|
512 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
513 |
+
return hidden_states
|
514 |
+
|
515 |
+
|
516 |
+
class BertLayer(nn.Module):
|
517 |
+
def __init__(self, config):
|
518 |
+
super().__init__()
|
519 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
520 |
+
self.seq_len_dim = 1
|
521 |
+
self.attention = BertAttention(config)
|
522 |
+
self.is_decoder = config.is_decoder
|
523 |
+
self.add_cross_attention = config.add_cross_attention
|
524 |
+
if self.add_cross_attention:
|
525 |
+
if not self.is_decoder:
|
526 |
+
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
527 |
+
self.crossattention = BertAttention(config, position_embedding_type="absolute")
|
528 |
+
self.intermediate = BertIntermediate(config)
|
529 |
+
self.output = BertOutput(config)
|
530 |
+
|
531 |
+
def forward(
|
532 |
+
self,
|
533 |
+
hidden_states: torch.Tensor,
|
534 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
535 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
536 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
537 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
538 |
+
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
539 |
+
output_attentions: Optional[bool] = False,
|
540 |
+
) -> Tuple[torch.Tensor]:
|
541 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
542 |
+
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
543 |
+
self_attention_outputs = self.attention(
|
544 |
+
hidden_states,
|
545 |
+
attention_mask,
|
546 |
+
head_mask,
|
547 |
+
output_attentions=output_attentions,
|
548 |
+
past_key_value=self_attn_past_key_value,
|
549 |
+
)
|
550 |
+
attention_output = self_attention_outputs[0]
|
551 |
+
|
552 |
+
# if decoder, the last output is tuple of self-attn cache
|
553 |
+
if self.is_decoder:
|
554 |
+
outputs = self_attention_outputs[1:-1]
|
555 |
+
present_key_value = self_attention_outputs[-1]
|
556 |
+
else:
|
557 |
+
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
558 |
+
|
559 |
+
cross_attn_present_key_value = None
|
560 |
+
if self.is_decoder and encoder_hidden_states is not None:
|
561 |
+
if not hasattr(self, "crossattention"):
|
562 |
+
raise ValueError(
|
563 |
+
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
564 |
+
" by setting `config.add_cross_attention=True`"
|
565 |
+
)
|
566 |
+
|
567 |
+
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
568 |
+
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
569 |
+
cross_attention_outputs = self.crossattention(
|
570 |
+
attention_output,
|
571 |
+
attention_mask,
|
572 |
+
head_mask,
|
573 |
+
encoder_hidden_states,
|
574 |
+
encoder_attention_mask,
|
575 |
+
cross_attn_past_key_value,
|
576 |
+
output_attentions,
|
577 |
+
)
|
578 |
+
attention_output = cross_attention_outputs[0]
|
579 |
+
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
580 |
+
|
581 |
+
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
582 |
+
cross_attn_present_key_value = cross_attention_outputs[-1]
|
583 |
+
present_key_value = present_key_value + cross_attn_present_key_value
|
584 |
+
|
585 |
+
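# apply_chunking_to_forward optionally splits the sequence dimension into chunks of
# chunk_size_feed_forward and runs the feed-forward sub-block chunk by chunk,
# trading a little overhead for a lower peak memory footprint.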
layer_output = apply_chunking_to_forward(
|
586 |
+
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
587 |
+
)
|
588 |
+
outputs = (layer_output,) + outputs
|
589 |
+
|
590 |
+
# if decoder, return the attn key/values as the last output
|
591 |
+
if self.is_decoder:
|
592 |
+
outputs = outputs + (present_key_value,)
|
593 |
+
|
594 |
+
return outputs
|
595 |
+
|
596 |
+
def feed_forward_chunk(self, attention_output):
|
597 |
+
intermediate_output = self.intermediate(attention_output)
|
598 |
+
layer_output = self.output(intermediate_output, attention_output)
|
599 |
+
return layer_output
|
600 |
+
|
601 |
+
|
602 |
+
class BertEncoder(nn.Module):
|
603 |
+
def __init__(self, config):
|
604 |
+
super().__init__()
|
605 |
+
self.config = config
|
606 |
+
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
607 |
+
self.gradient_checkpointing = False
|
608 |
+
|
609 |
+
def forward(
|
610 |
+
self,
|
611 |
+
hidden_states: torch.Tensor,
|
612 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
613 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
614 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
615 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
616 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
617 |
+
use_cache: Optional[bool] = None,
|
618 |
+
output_attentions: Optional[bool] = False,
|
619 |
+
output_hidden_states: Optional[bool] = False,
|
620 |
+
return_dict: Optional[bool] = True,
|
621 |
+
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
622 |
+
all_hidden_states = () if output_hidden_states else None
|
623 |
+
all_self_attentions = () if output_attentions else None
|
624 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
625 |
+
|
626 |
+
next_decoder_cache = () if use_cache else None
|
627 |
+
for i, layer_module in enumerate(self.layer):
|
628 |
+
if output_hidden_states:
|
629 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
630 |
+
|
631 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
632 |
+
past_key_value = past_key_values[i] if past_key_values is not None else None
|
633 |
+
|
634 |
+
if self.gradient_checkpointing and self.training:
|
635 |
+
|
636 |
+
if use_cache:
|
637 |
+
logger.warning(
|
638 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
639 |
+
)
|
640 |
+
use_cache = False
|
641 |
+
|
642 |
+
def create_custom_forward(module):
|
643 |
+
def custom_forward(*inputs):
|
644 |
+
return module(*inputs, past_key_value, output_attentions)
|
645 |
+
|
646 |
+
return custom_forward
|
647 |
+
|
648 |
+
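# With gradient checkpointing, this layer's activations are not kept for backward;
# they are recomputed inside torch.utils.checkpoint during the backward pass,
# which is why use_cache had to be disabled above.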
layer_outputs = torch.utils.checkpoint.checkpoint(
|
649 |
+
create_custom_forward(layer_module),
|
650 |
+
hidden_states,
|
651 |
+
attention_mask,
|
652 |
+
layer_head_mask,
|
653 |
+
encoder_hidden_states,
|
654 |
+
encoder_attention_mask,
|
655 |
+
)
|
656 |
+
else:
|
657 |
+
layer_outputs = layer_module(
|
658 |
+
hidden_states,
|
659 |
+
attention_mask,
|
660 |
+
layer_head_mask,
|
661 |
+
encoder_hidden_states,
|
662 |
+
encoder_attention_mask,
|
663 |
+
past_key_value,
|
664 |
+
output_attentions,
|
665 |
+
)
|
666 |
+
|
667 |
+
hidden_states = layer_outputs[0]
|
668 |
+
if use_cache:
|
669 |
+
next_decoder_cache += (layer_outputs[-1],)
|
670 |
+
if output_attentions:
|
671 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
672 |
+
if self.config.add_cross_attention:
|
673 |
+
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
674 |
+
|
675 |
+
if output_hidden_states:
|
676 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
677 |
+
|
678 |
+
if not return_dict:
|
679 |
+
return tuple(
|
680 |
+
v
|
681 |
+
for v in [
|
682 |
+
hidden_states,
|
683 |
+
next_decoder_cache,
|
684 |
+
all_hidden_states,
|
685 |
+
all_self_attentions,
|
686 |
+
all_cross_attentions,
|
687 |
+
]
|
688 |
+
if v is not None
|
689 |
+
)
|
690 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
691 |
+
last_hidden_state=hidden_states,
|
692 |
+
past_key_values=next_decoder_cache,
|
693 |
+
hidden_states=all_hidden_states,
|
694 |
+
attentions=all_self_attentions,
|
695 |
+
cross_attentions=all_cross_attentions,
|
696 |
+
)
|
697 |
+
|
698 |
+
|
699 |
+
class BertPooler(nn.Module):
|
700 |
+
def __init__(self, config):
|
701 |
+
super().__init__()
|
702 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
703 |
+
self.activation = nn.Tanh()
|
704 |
+
|
705 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
706 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
707 |
+
# to the first token.
|
708 |
+
first_token_tensor = hidden_states[:, 0]
|
709 |
+
pooled_output = self.dense(first_token_tensor)
|
710 |
+
pooled_output = self.activation(pooled_output)
|
711 |
+
return pooled_output
|
712 |
+
|
713 |
+
|
714 |
+
class BertPredictionHeadTransform(nn.Module):
|
715 |
+
def __init__(self, config):
|
716 |
+
super().__init__()
|
717 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
718 |
+
if isinstance(config.hidden_act, str):
|
719 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
720 |
+
else:
|
721 |
+
self.transform_act_fn = config.hidden_act
|
722 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
723 |
+
|
724 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
725 |
+
hidden_states = self.dense(hidden_states)
|
726 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
727 |
+
hidden_states = self.LayerNorm(hidden_states)
|
728 |
+
return hidden_states
|
729 |
+
|
730 |
+
|
731 |
+
class BertLMPredictionHead(nn.Module):
|
732 |
+
def __init__(self, config):
|
733 |
+
super().__init__()
|
734 |
+
self.transform = BertPredictionHeadTransform(config)
|
735 |
+
|
736 |
+
# The output weights are the same as the input embeddings, but there is
|
737 |
+
# an output-only bias for each token.
|
738 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
739 |
+
|
740 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
741 |
+
|
742 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
743 |
+
self.decoder.bias = self.bias
|
744 |
+
|
745 |
+
def forward(self, hidden_states):
|
746 |
+
hidden_states = self.transform(hidden_states)
|
747 |
+
hidden_states = self.decoder(hidden_states)
|
748 |
+
return hidden_states
|
749 |
+
|
750 |
+
|
751 |
+
class BertOnlyMLMHead(nn.Module):
|
752 |
+
def __init__(self, config):
|
753 |
+
super().__init__()
|
754 |
+
self.predictions = BertLMPredictionHead(config)
|
755 |
+
|
756 |
+
def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
|
757 |
+
prediction_scores = self.predictions(sequence_output)
|
758 |
+
return prediction_scores
|
759 |
+
|
760 |
+
|
761 |
+
class BertOnlyNSPHead(nn.Module):
|
762 |
+
def __init__(self, config):
|
763 |
+
super().__init__()
|
764 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
765 |
+
|
766 |
+
def forward(self, pooled_output):
|
767 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
768 |
+
return seq_relationship_score
|
769 |
+
|
770 |
+
|
771 |
+
class BertPreTrainingHeads(nn.Module):
|
772 |
+
def __init__(self, config):
|
773 |
+
super().__init__()
|
774 |
+
self.predictions = BertLMPredictionHead(config)
|
775 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
776 |
+
|
777 |
+
def forward(self, sequence_output, pooled_output):
|
778 |
+
prediction_scores = self.predictions(sequence_output)
|
779 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
780 |
+
return prediction_scores, seq_relationship_score
|
781 |
+
|
782 |
+
|
783 |
+
class BertPreTrainedModel(PreTrainedModel):
|
784 |
+
"""
|
785 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
786 |
+
models.
|
787 |
+
"""
|
788 |
+
|
789 |
+
config_class = BertConfig
|
790 |
+
load_tf_weights = load_tf_weights_in_bert
|
791 |
+
base_model_prefix = "bert"
|
792 |
+
supports_gradient_checkpointing = True
|
793 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
794 |
+
|
795 |
+
def _init_weights(self, module):
|
796 |
+
"""Initialize the weights"""
|
797 |
+
if isinstance(module, nn.Linear):
|
798 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
799 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
800 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
801 |
+
if module.bias is not None:
|
802 |
+
module.bias.data.zero_()
|
803 |
+
elif isinstance(module, nn.Embedding):
|
804 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
805 |
+
if module.padding_idx is not None:
|
806 |
+
module.weight.data[module.padding_idx].zero_()
|
807 |
+
elif isinstance(module, nn.LayerNorm):
|
808 |
+
module.bias.data.zero_()
|
809 |
+
module.weight.data.fill_(1.0)
|
810 |
+
|
811 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
812 |
+
if isinstance(module, BertEncoder):
|
813 |
+
module.gradient_checkpointing = value
|
814 |
+
|
815 |
+
|
816 |
+
@dataclass
|
817 |
+
class BertForPreTrainingOutput(ModelOutput):
|
818 |
+
"""
|
819 |
+
Output type of [`BertForPreTraining`].
|
820 |
+
|
821 |
+
Args:
|
822 |
+
loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
|
823 |
+
Total loss as the sum of the masked language modeling loss and the next sequence prediction
|
824 |
+
(classification) loss.
|
825 |
+
prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
826 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
827 |
+
seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
|
828 |
+
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
|
829 |
+
before SoftMax).
|
830 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
831 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
832 |
+
shape `(batch_size, sequence_length, hidden_size)`.
|
833 |
+
|
834 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
835 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
836 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
837 |
+
sequence_length)`.
|
838 |
+
|
839 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
840 |
+
heads.
|
841 |
+
"""
|
842 |
+
|
843 |
+
loss: Optional[torch.FloatTensor] = None
|
844 |
+
prediction_logits: torch.FloatTensor = None
|
845 |
+
seq_relationship_logits: torch.FloatTensor = None
|
846 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
847 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
848 |
+
|
849 |
+
|
850 |
+
BERT_START_DOCSTRING = r"""
|
851 |
+
|
852 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
853 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
854 |
+
etc.)
|
855 |
+
|
856 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
857 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
858 |
+
and behavior.
|
859 |
+
|
860 |
+
Parameters:
|
861 |
+
config ([`BertConfig`]): Model configuration class with all the parameters of the model.
|
862 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
863 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
864 |
+
"""
|
865 |
+
|
866 |
+
BERT_INPUTS_DOCSTRING = r"""
|
867 |
+
Args:
|
868 |
+
input_ids (`torch.LongTensor` of shape `({0})`):
|
869 |
+
Indices of input sequence tokens in the vocabulary.
|
870 |
+
|
871 |
+
Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
872 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
873 |
+
|
874 |
+
[What are input IDs?](../glossary#input-ids)
|
875 |
+
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
|
876 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
877 |
+
|
878 |
+
- 1 for tokens that are **not masked**,
|
879 |
+
- 0 for tokens that are **masked**.
|
880 |
+
|
881 |
+
[What are attention masks?](../glossary#attention-mask)
|
882 |
+
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
883 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
884 |
+
1]`:
|
885 |
+
|
886 |
+
- 0 corresponds to a *sentence A* token,
|
887 |
+
- 1 corresponds to a *sentence B* token.
|
888 |
+
|
889 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
890 |
+
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
891 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
892 |
+
config.max_position_embeddings - 1]`.
|
893 |
+
|
894 |
+
[What are position IDs?](../glossary#position-ids)
|
895 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
896 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
897 |
+
|
898 |
+
- 1 indicates the head is **not masked**,
|
899 |
+
- 0 indicates the head is **masked**.
|
900 |
+
|
901 |
+
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
|
902 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
903 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
904 |
+
model's internal embedding lookup matrix.
|
905 |
+
output_attentions (`bool`, *optional*):
|
906 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
907 |
+
tensors for more detail.
|
908 |
+
output_hidden_states (`bool`, *optional*):
|
909 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
910 |
+
more detail.
|
911 |
+
return_dict (`bool`, *optional*):
|
912 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
913 |
+
"""
|
914 |
+
|
915 |
+
|
916 |
+
@add_start_docstrings(
|
917 |
+
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
918 |
+
BERT_START_DOCSTRING,
|
919 |
+
)
|
920 |
+
class BertModel(BertPreTrainedModel):
|
921 |
+
"""
|
922 |
+
|
923 |
+
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
924 |
+
cross-attention is added between the self-attention layers, following the architecture described in [Attention is
|
925 |
+
all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
926 |
+
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
927 |
+
|
928 |
+
To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
|
930 |
+
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
|
931 |
+
"""
|
932 |
+
|
933 |
+
def __init__(self, config, add_pooling_layer=True):
|
934 |
+
super().__init__(config)
|
935 |
+
self.config = config
|
936 |
+
|
937 |
+
self.embeddings = BertEmbeddings(config)
|
938 |
+
self.encoder = BertEncoder(config)
|
939 |
+
|
940 |
+
self.pooler = BertPooler(config) if add_pooling_layer else None
|
941 |
+
|
942 |
+
# Initialize weights and apply final processing
|
943 |
+
self.post_init()
|
944 |
+
|
945 |
+
def get_input_embeddings(self):
|
946 |
+
return self.embeddings.word_embeddings
|
947 |
+
|
948 |
+
def set_input_embeddings(self, value):
|
949 |
+
self.embeddings.word_embeddings = value
|
950 |
+
|
951 |
+
def _prune_heads(self, heads_to_prune):
|
952 |
+
"""
|
953 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
954 |
+
class PreTrainedModel
|
955 |
+
"""
|
956 |
+
for layer, heads in heads_to_prune.items():
|
957 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
958 |
+
|
959 |
+
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
960 |
+
@add_code_sample_docstrings(
|
961 |
+
processor_class=_TOKENIZER_FOR_DOC,
|
962 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
963 |
+
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
|
964 |
+
config_class=_CONFIG_FOR_DOC,
|
965 |
+
)
|
966 |
+
def forward(
|
967 |
+
self,
|
968 |
+
input_ids: Optional[torch.Tensor] = None,
|
969 |
+
attention_mask: Optional[torch.Tensor] = None,
|
970 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
971 |
+
position_ids: Optional[torch.Tensor] = None,
|
972 |
+
head_mask: Optional[torch.Tensor] = None,
|
973 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
974 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
975 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
976 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
977 |
+
use_cache: Optional[bool] = None,
|
978 |
+
output_attentions: Optional[bool] = None,
|
979 |
+
output_hidden_states: Optional[bool] = None,
|
980 |
+
return_dict: Optional[bool] = None,
|
981 |
+
kpi_ref=None,  # positions of the KPI numeric placeholders, plus the referenced KPI name, value, and category
kpi_model=None,  # the KPI (numeric) model passed in
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
|
984 |
+
r"""
|
985 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
986 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
987 |
+
the model is configured as a decoder.
|
988 |
+
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
989 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
990 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
|
991 |
+
|
992 |
+
- 1 for tokens that are **not masked**,
|
993 |
+
- 0 for tokens that are **masked**.
|
994 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
995 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
996 |
+
|
997 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
998 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
999 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
1000 |
+
use_cache (`bool`, *optional*):
|
1001 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
1002 |
+
`past_key_values`).
|
1003 |
+
"""
|
1004 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1005 |
+
output_hidden_states = (
|
1006 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1007 |
+
)
|
1008 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1009 |
+
|
1010 |
+
if self.config.is_decoder:
|
1011 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1012 |
+
else:
|
1013 |
+
use_cache = False
|
1014 |
+
|
1015 |
+
if input_ids is not None and inputs_embeds is not None:
|
1016 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
1017 |
+
elif input_ids is not None:
|
1018 |
+
input_shape = input_ids.size()
|
1019 |
+
elif inputs_embeds is not None:
|
1020 |
+
input_shape = inputs_embeds.size()[:-1]
|
1021 |
+
else:
|
1022 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
1023 |
+
|
1024 |
+
batch_size, seq_length = input_shape
|
1025 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
1026 |
+
|
1027 |
+
# past_key_values_length
|
1028 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
1029 |
+
|
1030 |
+
if attention_mask is None:
|
1031 |
+
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
1032 |
+
|
1033 |
+
if token_type_ids is None:
|
1034 |
+
if hasattr(self.embeddings, "token_type_ids"):
|
1035 |
+
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
|
1036 |
+
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
|
1037 |
+
token_type_ids = buffered_token_type_ids_expanded
|
1038 |
+
else:
|
1039 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
1040 |
+
|
1041 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
1042 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
1043 |
+
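# get_extended_attention_mask broadcasts the mask so it can be added to the raw
# attention scores: kept positions become 0 and masked positions become a large
# negative value (it also builds the causal mask when the model is a decoder).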
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
1044 |
+
|
1045 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
1046 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
1047 |
+
if self.config.is_decoder and encoder_hidden_states is not None:
|
1048 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
1049 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
1050 |
+
if encoder_attention_mask is None:
|
1051 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
1052 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
1053 |
+
else:
|
1054 |
+
encoder_extended_attention_mask = None
|
1055 |
+
|
1056 |
+
# Prepare head mask if needed
|
1057 |
+
# 1.0 in head_mask indicate we keep the head
|
1058 |
+
# attention_probs has shape bsz x n_heads x N x N
|
1059 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
1060 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
1061 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
1062 |
+
|
1063 |
+
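# KTeleBERT-specific change: besides the token embeddings, the modified embedding
# layer also returns auxiliary numeric/KPI signals (en_loss, scalar_list, con_loss,
# numeric_input, kpi_input), which appear to be produced while injecting numeric
# embeddings at the KPI placeholder positions; they feed the KPI losses downstream.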
embedding_output, en_loss, scalar_list, con_loss, numeric_input, kpi_input = self.embeddings(
|
1064 |
+
input_ids=input_ids,
|
1065 |
+
position_ids=position_ids,
|
1066 |
+
token_type_ids=token_type_ids,
|
1067 |
+
inputs_embeds=inputs_embeds,
|
1068 |
+
past_key_values_length=past_key_values_length,
|
1069 |
+
kpi_ref=kpi_ref,  # positions of the KPI numeric placeholders, plus the referenced KPI name, value, and category
kpi_model=kpi_model,
)

# The placeholder position information has been passed in.
# Pooling of the embeddings at the KPI start/end positions.

# Substitute the numeric embeddings in by position here,
# using the KPI embedding as the supervision signal at the same time.
#
encoder_outputs = self.encoder(
|
1080 |
+
embedding_output,
|
1081 |
+
attention_mask=extended_attention_mask,
|
1082 |
+
head_mask=head_mask,
|
1083 |
+
encoder_hidden_states=encoder_hidden_states,
|
1084 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
1085 |
+
past_key_values=past_key_values,
|
1086 |
+
use_cache=use_cache,
|
1087 |
+
output_attentions=output_attentions,
|
1088 |
+
output_hidden_states=output_hidden_states,
|
1089 |
+
return_dict=return_dict,
|
1090 |
+
)
|
1091 |
+
|
1092 |
+
|
1093 |
+
# Apply the regression loss to the numeric-embedding positions here.
sequence_output = encoder_outputs[0]
|
1095 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
1096 |
+
|
1097 |
+
if not return_dict:
|
1098 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
1099 |
+
|
1100 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
1101 |
+
last_hidden_state=sequence_output,
|
1102 |
+
pooler_output=pooled_output,
|
1103 |
+
past_key_values=encoder_outputs.past_key_values,
|
1104 |
+
hidden_states=encoder_outputs.hidden_states,
|
1105 |
+
attentions=encoder_outputs.attentions,
|
1106 |
+
cross_attentions=encoder_outputs.cross_attentions,
|
1107 |
+
), en_loss, scalar_list, con_loss, numeric_input, kpi_input
|
1108 |
+
|
1109 |
+
|
1110 |
+
@add_start_docstrings(
|
1111 |
+
"""
|
1112 |
+
Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
|
1113 |
+
sentence prediction (classification)` head.
|
1114 |
+
""",
|
1115 |
+
BERT_START_DOCSTRING,
|
1116 |
+
)
|
1117 |
+
class BertForPreTraining(BertPreTrainedModel):
|
1118 |
+
def __init__(self, config):
|
1119 |
+
super().__init__(config)
|
1120 |
+
|
1121 |
+
self.bert = BertModel(config)
|
1122 |
+
self.cls = BertPreTrainingHeads(config)
|
1123 |
+
|
1124 |
+
# Initialize weights and apply final processing
|
1125 |
+
self.post_init()
|
1126 |
+
|
1127 |
+
def get_output_embeddings(self):
|
1128 |
+
return self.cls.predictions.decoder
|
1129 |
+
|
1130 |
+
def set_output_embeddings(self, new_embeddings):
|
1131 |
+
self.cls.predictions.decoder = new_embeddings
|
1132 |
+
|
1133 |
+
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1134 |
+
@replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
1135 |
+
def forward(
|
1136 |
+
self,
|
1137 |
+
input_ids: Optional[torch.Tensor] = None,
|
1138 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1139 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
1140 |
+
position_ids: Optional[torch.Tensor] = None,
|
1141 |
+
head_mask: Optional[torch.Tensor] = None,
|
1142 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1143 |
+
labels: Optional[torch.Tensor] = None,
|
1144 |
+
next_sentence_label: Optional[torch.Tensor] = None,
|
1145 |
+
output_attentions: Optional[bool] = None,
|
1146 |
+
output_hidden_states: Optional[bool] = None,
|
1147 |
+
return_dict: Optional[bool] = None,
|
1148 |
+
) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]:
|
1149 |
+
r"""
|
1150 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1151 |
+
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
1152 |
+
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
|
1153 |
+
the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
|
1154 |
+
next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1155 |
+
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
|
1156 |
+
pair (see `input_ids` docstring) Indices should be in `[0, 1]`:
|
1157 |
+
|
1158 |
+
- 0 indicates sequence B is a continuation of sequence A,
|
1159 |
+
- 1 indicates sequence B is a random sequence.
|
1160 |
+
kwargs (`Dict[str, any]`, optional, defaults to *{}*):
|
1161 |
+
Used to hide legacy arguments that have been deprecated.
|
1162 |
+
|
1163 |
+
Returns:
|
1164 |
+
|
1165 |
+
Example:
|
1166 |
+
|
1167 |
+
```python
|
1168 |
+
>>> from transformers import BertTokenizer, BertForPreTraining
|
1169 |
+
>>> import torch
|
1170 |
+
|
1171 |
+
>>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
1172 |
+
>>> model = BertForPreTraining.from_pretrained("bert-base-uncased")
|
1173 |
+
|
1174 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
1175 |
+
>>> outputs = model(**inputs)
|
1176 |
+
|
1177 |
+
>>> prediction_logits = outputs.prediction_logits
|
1178 |
+
>>> seq_relationship_logits = outputs.seq_relationship_logits
|
1179 |
+
```
|
1180 |
+
"""
|
1181 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1182 |
+
|
1183 |
+
outputs = self.bert(
|
1184 |
+
input_ids,
|
1185 |
+
attention_mask=attention_mask,
|
1186 |
+
token_type_ids=token_type_ids,
|
1187 |
+
position_ids=position_ids,
|
1188 |
+
head_mask=head_mask,
|
1189 |
+
inputs_embeds=inputs_embeds,
|
1190 |
+
output_attentions=output_attentions,
|
1191 |
+
output_hidden_states=output_hidden_states,
|
1192 |
+
return_dict=return_dict,
|
1193 |
+
)
|
1194 |
+
|
1195 |
+
sequence_output, pooled_output = outputs[:2]
|
1196 |
+
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
1197 |
+
|
1198 |
+
total_loss = None
|
1199 |
+
if labels is not None and next_sentence_label is not None:
|
1200 |
+
loss_fct = CrossEntropyLoss()
|
1201 |
+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
1202 |
+
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
1203 |
+
total_loss = masked_lm_loss + next_sentence_loss
|
1204 |
+
|
1205 |
+
if not return_dict:
|
1206 |
+
output = (prediction_scores, seq_relationship_score) + outputs[2:]
|
1207 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
1208 |
+
|
1209 |
+
return BertForPreTrainingOutput(
|
1210 |
+
loss=total_loss,
|
1211 |
+
prediction_logits=prediction_scores,
|
1212 |
+
seq_relationship_logits=seq_relationship_score,
|
1213 |
+
hidden_states=outputs.hidden_states,
|
1214 |
+
attentions=outputs.attentions,
|
1215 |
+
)
|
1216 |
+
|
1217 |
+
|
1218 |
+
@add_start_docstrings(
|
1219 |
+
"""Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING
|
1220 |
+
)
|
1221 |
+
class BertLMHeadModel(BertPreTrainedModel):
|
1222 |
+
|
1223 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
1224 |
+
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
1225 |
+
|
1226 |
+
def __init__(self, config):
|
1227 |
+
super().__init__(config)
|
1228 |
+
|
1229 |
+
if not config.is_decoder:
|
1230 |
+
logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`")
|
1231 |
+
|
1232 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
1233 |
+
self.cls = BertOnlyMLMHead(config)
|
1234 |
+
|
1235 |
+
# Initialize weights and apply final processing
|
1236 |
+
self.post_init()
|
1237 |
+
|
1238 |
+
def get_output_embeddings(self):
|
1239 |
+
return self.cls.predictions.decoder
|
1240 |
+
|
1241 |
+
def set_output_embeddings(self, new_embeddings):
|
1242 |
+
self.cls.predictions.decoder = new_embeddings
|
1243 |
+
|
1244 |
+
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1245 |
+
@add_code_sample_docstrings(
|
1246 |
+
processor_class=_TOKENIZER_FOR_DOC,
|
1247 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
1248 |
+
output_type=CausalLMOutputWithCrossAttentions,
|
1249 |
+
config_class=_CONFIG_FOR_DOC,
|
1250 |
+
)
|
1251 |
+
def forward(
|
1252 |
+
self,
|
1253 |
+
input_ids: Optional[torch.Tensor] = None,
|
1254 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1255 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
1256 |
+
position_ids: Optional[torch.Tensor] = None,
|
1257 |
+
head_mask: Optional[torch.Tensor] = None,
|
1258 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1259 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
1260 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
1261 |
+
labels: Optional[torch.Tensor] = None,
|
1262 |
+
past_key_values: Optional[List[torch.Tensor]] = None,
|
1263 |
+
use_cache: Optional[bool] = None,
|
1264 |
+
output_attentions: Optional[bool] = None,
|
1265 |
+
output_hidden_states: Optional[bool] = None,
|
1266 |
+
return_dict: Optional[bool] = None,
|
1267 |
+
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
|
1268 |
+
r"""
|
1269 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
1270 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
1271 |
+
the model is configured as a decoder.
|
1272 |
+
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1273 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
1274 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
|
1275 |
+
|
1276 |
+
- 1 for tokens that are **not masked**,
|
1277 |
+
- 0 for tokens that are **masked**.
|
1278 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1279 |
+
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
1280 |
+
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
|
1281 |
+
ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
|
1282 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
1283 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
1284 |
+
|
1285 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
1286 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
1287 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
1288 |
+
use_cache (`bool`, *optional*):
|
1289 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
1290 |
+
`past_key_values`).
|
1291 |
+
"""
|
1292 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1293 |
+
if labels is not None:
|
1294 |
+
use_cache = False
|
1295 |
+
|
1296 |
+
outputs = self.bert(
|
1297 |
+
input_ids,
|
1298 |
+
attention_mask=attention_mask,
|
1299 |
+
token_type_ids=token_type_ids,
|
1300 |
+
position_ids=position_ids,
|
1301 |
+
head_mask=head_mask,
|
1302 |
+
inputs_embeds=inputs_embeds,
|
1303 |
+
encoder_hidden_states=encoder_hidden_states,
|
1304 |
+
encoder_attention_mask=encoder_attention_mask,
|
1305 |
+
past_key_values=past_key_values,
|
1306 |
+
use_cache=use_cache,
|
1307 |
+
output_attentions=output_attentions,
|
1308 |
+
output_hidden_states=output_hidden_states,
|
1309 |
+
return_dict=return_dict,
|
1310 |
+
)
|
1311 |
+
|
1312 |
+
sequence_output = outputs[0]
|
1313 |
+
prediction_scores = self.cls(sequence_output)
|
1314 |
+
|
1315 |
+
lm_loss = None
|
1316 |
+
if labels is not None:
|
1317 |
+
# we are doing next-token prediction; shift prediction scores and input ids by one
|
1318 |
+
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
1319 |
+
labels = labels[:, 1:].contiguous()
|
1320 |
+
loss_fct = CrossEntropyLoss()
|
1321 |
+
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
1322 |
+
|
1323 |
+
if not return_dict:
|
1324 |
+
output = (prediction_scores,) + outputs[2:]
|
1325 |
+
return ((lm_loss,) + output) if lm_loss is not None else output
|
1326 |
+
|
1327 |
+
return CausalLMOutputWithCrossAttentions(
|
1328 |
+
loss=lm_loss,
|
1329 |
+
logits=prediction_scores,
|
1330 |
+
past_key_values=outputs.past_key_values,
|
1331 |
+
hidden_states=outputs.hidden_states,
|
1332 |
+
attentions=outputs.attentions,
|
1333 |
+
cross_attentions=outputs.cross_attentions,
|
1334 |
+
)
|
1335 |
+
|
1336 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
|
1337 |
+
input_shape = input_ids.shape
|
1338 |
+
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
1339 |
+
if attention_mask is None:
|
1340 |
+
attention_mask = input_ids.new_ones(input_shape)
|
1341 |
+
|
1342 |
+
# cut decoder_input_ids if past is used
|
1343 |
+
if past is not None:
|
1344 |
+
input_ids = input_ids[:, -1:]
|
1345 |
+
|
1346 |
+
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
|
1347 |
+
|
1348 |
+
def _reorder_cache(self, past, beam_idx):
|
1349 |
+
reordered_past = ()
|
1350 |
+
for layer_past in past:
|
1351 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
1352 |
+
return reordered_past
|
1353 |
+
|
1354 |
+
|
1355 |
+
@add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING)
|
1356 |
+
class BertForMaskedLM(BertPreTrainedModel):
|
1357 |
+
|
1358 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
1359 |
+
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
1360 |
+
|
1361 |
+
def __init__(self, config):
|
1362 |
+
super().__init__(config)
|
1363 |
+
|
1364 |
+
if config.is_decoder:
|
1365 |
+
logger.warning(
|
1366 |
+
"If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
|
1367 |
+
"bi-directional self-attention."
|
1368 |
+
)
|
1369 |
+
|
1370 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
1371 |
+
|
1372 |
+
# CZ: added pooling
# self.bert = BertModel(config)
|
1374 |
+
self.cls = BertOnlyMLMHead(config)
|
1375 |
+
|
1376 |
+
# Initialize weights and apply final processing
|
1377 |
+
self.post_init()
|
1378 |
+
|
1379 |
+
def get_output_embeddings(self):
|
1380 |
+
return self.cls.predictions.decoder
|
1381 |
+
|
1382 |
+
def set_output_embeddings(self, new_embeddings):
|
1383 |
+
self.cls.predictions.decoder = new_embeddings
|
1384 |
+
|
1385 |
+
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1386 |
+
@add_code_sample_docstrings(
|
1387 |
+
processor_class=_TOKENIZER_FOR_DOC,
|
1388 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
1389 |
+
output_type=MaskedLMOutput,
|
1390 |
+
config_class=_CONFIG_FOR_DOC,
|
1391 |
+
expected_output="'paris'",
|
1392 |
+
expected_loss=0.88,
|
1393 |
+
)
|
1394 |
+
def forward(
|
1395 |
+
self,
|
1396 |
+
input_ids: Optional[torch.Tensor] = None,
|
1397 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1398 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
1399 |
+
position_ids: Optional[torch.Tensor] = None,
|
1400 |
+
head_mask: Optional[torch.Tensor] = None,
|
1401 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1402 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
1403 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
1404 |
+
labels: Optional[torch.Tensor] = None,
|
1405 |
+
output_attentions: Optional[bool] = None,
|
1406 |
+
output_hidden_states: Optional[bool] = None,
|
1407 |
+
return_dict: Optional[bool] = None,
|
1408 |
+
kpi_ref=None,  # positions of the KPI numeric placeholders, plus the referenced KPI name, value, and category
kpi_model=None,  # the KPI (numeric) model passed in
|
1410 |
+
) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
|
1411 |
+
r"""
|
1412 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1413 |
+
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
1414 |
+
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
|
1415 |
+
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
|
1416 |
+
"""
|
1417 |
+
|
1418 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1419 |
+
|
1420 |
+
outputs, en_loss, scalar_list, con_loss, numeric_input, kpi_input = self.bert(
|
1421 |
+
input_ids,
|
1422 |
+
attention_mask=attention_mask,
|
1423 |
+
token_type_ids=token_type_ids,
|
1424 |
+
position_ids=position_ids,
|
1425 |
+
head_mask=head_mask,
|
1426 |
+
inputs_embeds=inputs_embeds,
|
1427 |
+
encoder_hidden_states=encoder_hidden_states,
|
1428 |
+
encoder_attention_mask=encoder_attention_mask,
|
1429 |
+
output_attentions=output_attentions,
|
1430 |
+
output_hidden_states=output_hidden_states,
|
1431 |
+
return_dict=return_dict,
|
1432 |
+
kpi_ref=kpi_ref,  # positions of the KPI numeric placeholders, plus the referenced KPI name, value, and category
kpi_model=kpi_model,
|
1434 |
+
)
|
1435 |
+
|
1436 |
+
|
1437 |
+
# The decoding loss and the related losses are computed here.
# The numeric positions are masked anyway and excluded from the MLM loss, so their loss can be computed separately.
sequence_output = outputs[0]

# A batch may contain no KPI at all.
# if kpi_ref is None:
#     kpi_loss = None
# else:
#     # needed by the AWL module
#     kpi_loss = torch.tensor([0.]).cuda()
kpi_loss, kpi_loss_weight, kpi_loss_dict = None, None, None
if kpi_input is not None:
    max_len = sequence_output.shape[1]
    # generate the numeric embeddings
    kpi_emb_list = []
    for i in range(len(kpi_ref)):
        if len(kpi_ref[i]) > 0:
            for item in kpi_ref[i]:
                # the [NUM] token may have been truncated
                if item[2] >= max_len:
                    continue
                # requires_grad=True
                kpi_emb_list.append(sequence_output[i][item[2]])

    # TODO: normalize the KPI contrastive loss, since KPI values fluctuate
    kpi_emb = torch.stack(kpi_emb_list).unsqueeze(1)

    # numeric_input: the associated numeric values
    # kpi_input: the associated KPI ids
    _dec_kpi_score, de_loss = kpi_model.decoder(numeric_input, kpi_emb)
    # pdb.set_trace()
    _cls_kpi, cls_loss = kpi_model.classifier(kpi_emb, kpi_input)
    # pdb.set_trace()
    # pdb.set_trace()
    # multiply by a coefficient up front to reduce the impact of these terms
    if con_loss is not None:
        kpi_loss = kpi_model.loss_awl(de_loss, 0.2 * cls_loss, 0.2 * con_loss) + 0.5 * en_loss
        kpi_loss_dict = {'de_loss': de_loss.item(), 'con_loss': con_loss.item(), 'cls_loss': cls_loss.item(), 'en_loss': en_loss.item()}
    else:
        kpi_loss = kpi_model.loss_awl(de_loss, 0.1 * cls_loss) + 0.5 * en_loss
        kpi_loss_dict = {'de_loss': de_loss.item(), 'cls_loss': cls_loss.item(), 'en_loss': en_loss.item()}
    kpi_loss_weight = kpi_model.loss_awl.params.tolist()

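# Note: kpi_model.loss_awl appears to be an automatic weighted loss module that
# learns one weight per loss term (numeric decoding, KPI classification and, when
# present, the embedding contrastive loss); the learned weights are exported above
# via loss_awl.params as kpi_loss_weight for logging.
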
prediction_scores = self.cls(sequence_output)
|
1482 |
+
|
1483 |
+
masked_lm_loss = None
|
1484 |
+
if labels is not None:
|
1485 |
+
loss_fct = CrossEntropyLoss()  # -100 index = padding token (ignore_index=-100)
|
1486 |
+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
1487 |
+
|
1488 |
+
|
1489 |
+
if not return_dict:
|
1490 |
+
output = (prediction_scores,) + outputs[2:]
|
1491 |
+
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
1492 |
+
|
1493 |
+
# pdb.set_trace()
|
1494 |
+
return MaskedLMOutput(
|
1495 |
+
loss=masked_lm_loss,
|
1496 |
+
logits=prediction_scores,
|
1497 |
+
hidden_states=outputs.hidden_states,
|
1498 |
+
attentions=outputs.attentions
|
1499 |
+
), kpi_loss, kpi_loss_weight, kpi_loss_dict
|
1500 |
+
|
1501 |
+
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
|
1502 |
+
input_shape = input_ids.shape
|
1503 |
+
effective_batch_size = input_shape[0]
|
1504 |
+
|
1505 |
+
# add a dummy token
|
1506 |
+
if self.config.pad_token_id is None:
|
1507 |
+
raise ValueError("The PAD token should be defined for generation")
|
1508 |
+
|
1509 |
+
attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
|
1510 |
+
dummy_token = torch.full(
|
1511 |
+
(effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
|
1512 |
+
)
|
1513 |
+
input_ids = torch.cat([input_ids, dummy_token], dim=1)
|
1514 |
+
|
1515 |
+
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
1516 |
+
|
1517 |
+
|
1518 |
+
@add_start_docstrings(
|
1519 |
+
"""Bert Model with a `next sentence prediction (classification)` head on top.""",
|
1520 |
+
BERT_START_DOCSTRING,
|
1521 |
+
)
|
1522 |
+
class BertForNextSentencePrediction(BertPreTrainedModel):
|
1523 |
+
def __init__(self, config):
|
1524 |
+
super().__init__(config)
|
1525 |
+
|
1526 |
+
self.bert = BertModel(config)
|
1527 |
+
self.cls = BertOnlyNSPHead(config)
|
1528 |
+
|
1529 |
+
# Initialize weights and apply final processing
|
1530 |
+
self.post_init()
|
1531 |
+
|
1532 |
+
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1533 |
+
@replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
|
1534 |
+
def forward(
|
1535 |
+
self,
|
1536 |
+
input_ids: Optional[torch.Tensor] = None,
|
1537 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1538 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
1539 |
+
position_ids: Optional[torch.Tensor] = None,
|
1540 |
+
head_mask: Optional[torch.Tensor] = None,
|
1541 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1542 |
+
labels: Optional[torch.Tensor] = None,
|
1543 |
+
output_attentions: Optional[bool] = None,
|
1544 |
+
output_hidden_states: Optional[bool] = None,
|
1545 |
+
return_dict: Optional[bool] = None,
|
1546 |
+
**kwargs,
|
1547 |
+
) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
|
1548 |
+
r"""
|
1549 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1550 |
+
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
|
1551 |
+
(see `input_ids` docstring). Indices should be in `[0, 1]`:
|
1552 |
+
|
1553 |
+
- 0 indicates sequence B is a continuation of sequence A,
|
1554 |
+
- 1 indicates sequence B is a random sequence.
|
1555 |
+
|
1556 |
+
Returns:
|
1557 |
+
|
1558 |
+
Example:
|
1559 |
+
|
1560 |
+
```python
|
1561 |
+
>>> from transformers import BertTokenizer, BertForNextSentencePrediction
|
1562 |
+
>>> import torch
|
1563 |
+
|
1564 |
+
>>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
1565 |
+
>>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased")
|
1566 |
+
|
1567 |
+
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
1568 |
+
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
|
1569 |
+
>>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
|
1570 |
+
|
1571 |
+
>>> outputs = model(**encoding, labels=torch.LongTensor([1]))
|
1572 |
+
>>> logits = outputs.logits
|
1573 |
+
>>> assert logits[0, 0] < logits[0, 1] # next sentence was random
|
1574 |
+
```
|
1575 |
+
"""
|
1576 |
+
|
1577 |
+
if "next_sentence_label" in kwargs:
|
1578 |
+
warnings.warn(
|
1579 |
+
"The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
|
1580 |
+
" `labels` instead.",
|
1581 |
+
FutureWarning,
|
1582 |
+
)
|
1583 |
+
labels = kwargs.pop("next_sentence_label")
|
1584 |
+
|
1585 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1586 |
+
|
1587 |
+
outputs = self.bert(
|
1588 |
+
input_ids,
|
1589 |
+
attention_mask=attention_mask,
|
1590 |
+
token_type_ids=token_type_ids,
|
1591 |
+
position_ids=position_ids,
|
1592 |
+
head_mask=head_mask,
|
1593 |
+
inputs_embeds=inputs_embeds,
|
1594 |
+
output_attentions=output_attentions,
|
1595 |
+
output_hidden_states=output_hidden_states,
|
1596 |
+
return_dict=return_dict,
|
1597 |
+
)
|
1598 |
+
|
1599 |
+
pooled_output = outputs[1]
|
1600 |
+
|
1601 |
+
seq_relationship_scores = self.cls(pooled_output)
|
1602 |
+
|
1603 |
+
next_sentence_loss = None
|
1604 |
+
if labels is not None:
|
1605 |
+
loss_fct = CrossEntropyLoss()
|
1606 |
+
next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
|
1607 |
+
|
1608 |
+
if not return_dict:
|
1609 |
+
output = (seq_relationship_scores,) + outputs[2:]
|
1610 |
+
return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
|
1611 |
+
|
1612 |
+
return NextSentencePredictorOutput(
|
1613 |
+
loss=next_sentence_loss,
|
1614 |
+
logits=seq_relationship_scores,
|
1615 |
+
hidden_states=outputs.hidden_states,
|
1616 |
+
attentions=outputs.attentions,
|
1617 |
+
)
|
1618 |
+
|
1619 |
+
|
1620 |
+
@add_start_docstrings(
|
1621 |
+
"""
|
1622 |
+
Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
|
1623 |
+
output) e.g. for GLUE tasks.
|
1624 |
+
""",
|
1625 |
+
BERT_START_DOCSTRING,
|
1626 |
+
)
|
1627 |
+
class BertForSequenceClassification(BertPreTrainedModel):
|
1628 |
+
def __init__(self, config):
|
1629 |
+
super().__init__(config)
|
1630 |
+
self.num_labels = config.num_labels
|
1631 |
+
self.config = config
|
1632 |
+
|
1633 |
+
self.bert = BertModel(config)
|
1634 |
+
classifier_dropout = (
|
1635 |
+
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
1636 |
+
)
|
1637 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
1638 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
1639 |
+
|
1640 |
+
# Initialize weights and apply final processing
|
1641 |
+
self.post_init()
|
1642 |
+
|
1643 |
+
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1644 |
+
@add_code_sample_docstrings(
|
1645 |
+
processor_class=_TOKENIZER_FOR_DOC,
|
1646 |
+
checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
|
1647 |
+
output_type=SequenceClassifierOutput,
|
1648 |
+
config_class=_CONFIG_FOR_DOC,
|
1649 |
+
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
|
1650 |
+
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
|
1651 |
+
)
|
1652 |
+
def forward(
|
1653 |
+
self,
|
1654 |
+
input_ids: Optional[torch.Tensor] = None,
|
1655 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1656 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
1657 |
+
position_ids: Optional[torch.Tensor] = None,
|
1658 |
+
head_mask: Optional[torch.Tensor] = None,
|
1659 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1660 |
+
labels: Optional[torch.Tensor] = None,
|
1661 |
+
output_attentions: Optional[bool] = None,
|
1662 |
+
output_hidden_states: Optional[bool] = None,
|
1663 |
+
return_dict: Optional[bool] = None,
|
1664 |
+
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
1665 |
+
r"""
|
1666 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1667 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1668 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1669 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1670 |
+
"""
|
1671 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1672 |
+
|
1673 |
+
outputs = self.bert(
|
1674 |
+
input_ids,
|
1675 |
+
attention_mask=attention_mask,
|
1676 |
+
token_type_ids=token_type_ids,
|
1677 |
+
position_ids=position_ids,
|
1678 |
+
head_mask=head_mask,
|
1679 |
+
inputs_embeds=inputs_embeds,
|
1680 |
+
output_attentions=output_attentions,
|
1681 |
+
output_hidden_states=output_hidden_states,
|
1682 |
+
return_dict=return_dict,
|
1683 |
+
)
|
1684 |
+
|
1685 |
+
pooled_output = outputs[1]
|
1686 |
+
|
1687 |
+
pooled_output = self.dropout(pooled_output)
|
1688 |
+
logits = self.classifier(pooled_output)
|
1689 |
+
|
1690 |
+
loss = None
|
1691 |
+
if labels is not None:
|
1692 |
+
if self.config.problem_type is None:
|
1693 |
+
if self.num_labels == 1:
|
1694 |
+
self.config.problem_type = "regression"
|
1695 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
1696 |
+
self.config.problem_type = "single_label_classification"
|
1697 |
+
else:
|
1698 |
+
self.config.problem_type = "multi_label_classification"
|
1699 |
+
|
1700 |
+
if self.config.problem_type == "regression":
|
1701 |
+
loss_fct = MSELoss()
|
1702 |
+
if self.num_labels == 1:
|
1703 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
1704 |
+
else:
|
1705 |
+
loss = loss_fct(logits, labels)
|
1706 |
+
elif self.config.problem_type == "single_label_classification":
|
1707 |
+
loss_fct = CrossEntropyLoss()
|
1708 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
1709 |
+
elif self.config.problem_type == "multi_label_classification":
|
1710 |
+
loss_fct = BCEWithLogitsLoss()
|
1711 |
+
loss = loss_fct(logits, labels)
|
1712 |
+
if not return_dict:
|
1713 |
+
output = (logits,) + outputs[2:]
|
1714 |
+
return ((loss,) + output) if loss is not None else output
|
1715 |
+
|
1716 |
+
return SequenceClassifierOutput(
|
1717 |
+
loss=loss,
|
1718 |
+
logits=logits,
|
1719 |
+
hidden_states=outputs.hidden_states,
|
1720 |
+
attentions=outputs.attentions,
|
1721 |
+
)
|
1722 |
+
|
1723 |
+
|
1724 |
+
@add_start_docstrings(
|
1725 |
+
"""
|
1726 |
+
Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
|
1727 |
+
softmax) e.g. for RocStories/SWAG tasks.
|
1728 |
+
""",
|
1729 |
+
BERT_START_DOCSTRING,
|
1730 |
+
)
|
1731 |
+
class BertForMultipleChoice(BertPreTrainedModel):
|
1732 |
+
def __init__(self, config):
|
1733 |
+
super().__init__(config)
|
1734 |
+
|
1735 |
+
self.bert = BertModel(config)
|
1736 |
+
classifier_dropout = (
|
1737 |
+
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
1738 |
+
)
|
1739 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
1740 |
+
self.classifier = nn.Linear(config.hidden_size, 1)
|
1741 |
+
|
1742 |
+
# Initialize weights and apply final processing
|
1743 |
+
self.post_init()
|
1744 |
+
|
1745 |
+
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
|
1746 |
+
@add_code_sample_docstrings(
|
1747 |
+
processor_class=_TOKENIZER_FOR_DOC,
|
1748 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
1749 |
+
output_type=MultipleChoiceModelOutput,
|
1750 |
+
config_class=_CONFIG_FOR_DOC,
|
1751 |
+
)
|
1752 |
+
def forward(
|
1753 |
+
self,
|
1754 |
+
input_ids: Optional[torch.Tensor] = None,
|
1755 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1756 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
1757 |
+
position_ids: Optional[torch.Tensor] = None,
|
1758 |
+
head_mask: Optional[torch.Tensor] = None,
|
1759 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1760 |
+
labels: Optional[torch.Tensor] = None,
|
1761 |
+
output_attentions: Optional[bool] = None,
|
1762 |
+
output_hidden_states: Optional[bool] = None,
|
1763 |
+
return_dict: Optional[bool] = None,
|
1764 |
+
) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
|
1765 |
+
r"""
|
1766 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1767 |
+
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
|
1768 |
+
num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
|
1769 |
+
`input_ids` above)
|
1770 |
+
"""
|
1771 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1772 |
+
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
1773 |
+
|
1774 |
+
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
|
1775 |
+
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
1776 |
+
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
1777 |
+
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
1778 |
+
inputs_embeds = (
|
1779 |
+
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
|
1780 |
+
if inputs_embeds is not None
|
1781 |
+
else None
|
1782 |
+
)
|
1783 |
+
|
1784 |
+
outputs = self.bert(
|
1785 |
+
input_ids,
|
1786 |
+
attention_mask=attention_mask,
|
1787 |
+
token_type_ids=token_type_ids,
|
1788 |
+
position_ids=position_ids,
|
1789 |
+
head_mask=head_mask,
|
1790 |
+
inputs_embeds=inputs_embeds,
|
1791 |
+
output_attentions=output_attentions,
|
1792 |
+
output_hidden_states=output_hidden_states,
|
1793 |
+
return_dict=return_dict,
|
1794 |
+
)
|
1795 |
+
|
1796 |
+
pooled_output = outputs[1]
|
1797 |
+
|
1798 |
+
pooled_output = self.dropout(pooled_output)
|
1799 |
+
logits = self.classifier(pooled_output)
|
1800 |
+
reshaped_logits = logits.view(-1, num_choices)
|
1801 |
+
|
1802 |
+
loss = None
|
1803 |
+
if labels is not None:
|
1804 |
+
loss_fct = CrossEntropyLoss()
|
1805 |
+
loss = loss_fct(reshaped_logits, labels)
|
1806 |
+
|
1807 |
+
if not return_dict:
|
1808 |
+
output = (reshaped_logits,) + outputs[2:]
|
1809 |
+
return ((loss,) + output) if loss is not None else output
|
1810 |
+
|
1811 |
+
return MultipleChoiceModelOutput(
|
1812 |
+
loss=loss,
|
1813 |
+
logits=reshaped_logits,
|
1814 |
+
hidden_states=outputs.hidden_states,
|
1815 |
+
attentions=outputs.attentions,
|
1816 |
+
)
|
1817 |
+
|
1818 |
+
|
1819 |
+
@add_start_docstrings(
|
1820 |
+
"""
|
1821 |
+
Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
|
1822 |
+
Named-Entity-Recognition (NER) tasks.
|
1823 |
+
""",
|
1824 |
+
BERT_START_DOCSTRING,
|
1825 |
+
)
|
1826 |
+
class BertForTokenClassification(BertPreTrainedModel):
|
1827 |
+
|
1828 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
1829 |
+
|
1830 |
+
def __init__(self, config):
|
1831 |
+
super().__init__(config)
|
1832 |
+
self.num_labels = config.num_labels
|
1833 |
+
|
1834 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
1835 |
+
classifier_dropout = (
|
1836 |
+
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
1837 |
+
)
|
1838 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
1839 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
1840 |
+
|
1841 |
+
# Initialize weights and apply final processing
|
1842 |
+
self.post_init()
|
1843 |
+
|
1844 |
+
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1845 |
+
@add_code_sample_docstrings(
|
1846 |
+
processor_class=_TOKENIZER_FOR_DOC,
|
1847 |
+
checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION,
|
1848 |
+
output_type=TokenClassifierOutput,
|
1849 |
+
config_class=_CONFIG_FOR_DOC,
|
1850 |
+
expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT,
|
1851 |
+
expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,
|
1852 |
+
)
|
1853 |
+
def forward(
|
1854 |
+
self,
|
1855 |
+
input_ids: Optional[torch.Tensor] = None,
|
1856 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1857 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
1858 |
+
position_ids: Optional[torch.Tensor] = None,
|
1859 |
+
head_mask: Optional[torch.Tensor] = None,
|
1860 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1861 |
+
labels: Optional[torch.Tensor] = None,
|
1862 |
+
output_attentions: Optional[bool] = None,
|
1863 |
+
output_hidden_states: Optional[bool] = None,
|
1864 |
+
return_dict: Optional[bool] = None,
|
1865 |
+
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
|
1866 |
+
r"""
|
1867 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1868 |
+
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
|
1869 |
+
"""
|
1870 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1871 |
+
|
1872 |
+
outputs = self.bert(
|
1873 |
+
input_ids,
|
1874 |
+
attention_mask=attention_mask,
|
1875 |
+
token_type_ids=token_type_ids,
|
1876 |
+
position_ids=position_ids,
|
1877 |
+
head_mask=head_mask,
|
1878 |
+
inputs_embeds=inputs_embeds,
|
1879 |
+
output_attentions=output_attentions,
|
1880 |
+
output_hidden_states=output_hidden_states,
|
1881 |
+
return_dict=return_dict,
|
1882 |
+
)
|
1883 |
+
|
1884 |
+
sequence_output = outputs[0]
|
1885 |
+
|
1886 |
+
sequence_output = self.dropout(sequence_output)
|
1887 |
+
logits = self.classifier(sequence_output)
|
1888 |
+
|
1889 |
+
loss = None
|
1890 |
+
if labels is not None:
|
1891 |
+
loss_fct = CrossEntropyLoss()
|
1892 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
1893 |
+
|
1894 |
+
if not return_dict:
|
1895 |
+
output = (logits,) + outputs[2:]
|
1896 |
+
return ((loss,) + output) if loss is not None else output
|
1897 |
+
|
1898 |
+
return TokenClassifierOutput(
|
1899 |
+
loss=loss,
|
1900 |
+
logits=logits,
|
1901 |
+
hidden_states=outputs.hidden_states,
|
1902 |
+
attentions=outputs.attentions,
|
1903 |
+
)
|
1904 |
+
|
1905 |
+
|
1906 |
+
@add_start_docstrings(
|
1907 |
+
"""
|
1908 |
+
Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
|
1909 |
+
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
1910 |
+
""",
|
1911 |
+
BERT_START_DOCSTRING,
|
1912 |
+
)
|
1913 |
+
class BertForQuestionAnswering(BertPreTrainedModel):
|
1914 |
+
|
1915 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
1916 |
+
|
1917 |
+
def __init__(self, config):
|
1918 |
+
super().__init__(config)
|
1919 |
+
self.num_labels = config.num_labels
|
1920 |
+
|
1921 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
1922 |
+
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
1923 |
+
|
1924 |
+
# Initialize weights and apply final processing
|
1925 |
+
self.post_init()
|
1926 |
+
|
1927 |
+
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1928 |
+
@add_code_sample_docstrings(
|
1929 |
+
processor_class=_TOKENIZER_FOR_DOC,
|
1930 |
+
checkpoint=_CHECKPOINT_FOR_QA,
|
1931 |
+
output_type=QuestionAnsweringModelOutput,
|
1932 |
+
config_class=_CONFIG_FOR_DOC,
|
1933 |
+
qa_target_start_index=_QA_TARGET_START_INDEX,
|
1934 |
+
qa_target_end_index=_QA_TARGET_END_INDEX,
|
1935 |
+
expected_output=_QA_EXPECTED_OUTPUT,
|
1936 |
+
expected_loss=_QA_EXPECTED_LOSS,
|
1937 |
+
)
|
1938 |
+
def forward(
|
1939 |
+
self,
|
1940 |
+
input_ids: Optional[torch.Tensor] = None,
|
1941 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1942 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
1943 |
+
position_ids: Optional[torch.Tensor] = None,
|
1944 |
+
head_mask: Optional[torch.Tensor] = None,
|
1945 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1946 |
+
start_positions: Optional[torch.Tensor] = None,
|
1947 |
+
end_positions: Optional[torch.Tensor] = None,
|
1948 |
+
output_attentions: Optional[bool] = None,
|
1949 |
+
output_hidden_states: Optional[bool] = None,
|
1950 |
+
return_dict: Optional[bool] = None,
|
1951 |
+
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
|
1952 |
+
r"""
|
1953 |
+
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1954 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
1955 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
|
1956 |
+
are not taken into account for computing the loss.
|
1957 |
+
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1958 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
1959 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
|
1960 |
+
are not taken into account for computing the loss.
|
1961 |
+
"""
|
1962 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1963 |
+
|
1964 |
+
outputs = self.bert(
|
1965 |
+
input_ids,
|
1966 |
+
attention_mask=attention_mask,
|
1967 |
+
token_type_ids=token_type_ids,
|
1968 |
+
position_ids=position_ids,
|
1969 |
+
head_mask=head_mask,
|
1970 |
+
inputs_embeds=inputs_embeds,
|
1971 |
+
output_attentions=output_attentions,
|
1972 |
+
output_hidden_states=output_hidden_states,
|
1973 |
+
return_dict=return_dict,
|
1974 |
+
)
|
1975 |
+
|
1976 |
+
sequence_output = outputs[0]
|
1977 |
+
|
1978 |
+
logits = self.qa_outputs(sequence_output)
|
1979 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
1980 |
+
start_logits = start_logits.squeeze(-1).contiguous()
|
1981 |
+
end_logits = end_logits.squeeze(-1).contiguous()
|
1982 |
+
|
1983 |
+
total_loss = None
|
1984 |
+
if start_positions is not None and end_positions is not None:
|
1985 |
+
# If we are on multi-GPU, the start/end positions may carry an extra dimension; squeeze it
|
1986 |
+
if len(start_positions.size()) > 1:
|
1987 |
+
start_positions = start_positions.squeeze(-1)
|
1988 |
+
if len(end_positions.size()) > 1:
|
1989 |
+
end_positions = end_positions.squeeze(-1)
|
1990 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
1991 |
+
ignored_index = start_logits.size(1)
|
1992 |
+
start_positions = start_positions.clamp(0, ignored_index)
|
1993 |
+
end_positions = end_positions.clamp(0, ignored_index)
|
1994 |
+
|
1995 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
1996 |
+
start_loss = loss_fct(start_logits, start_positions)
|
1997 |
+
end_loss = loss_fct(end_logits, end_positions)
|
1998 |
+
total_loss = (start_loss + end_loss) / 2
|
1999 |
+
|
2000 |
+
if not return_dict:
|
2001 |
+
output = (start_logits, end_logits) + outputs[2:]
|
2002 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
2003 |
+
|
2004 |
+
return QuestionAnsweringModelOutput(
|
2005 |
+
loss=total_loss,
|
2006 |
+
start_logits=start_logits,
|
2007 |
+
end_logits=end_logits,
|
2008 |
+
hidden_states=outputs.hidden_states,
|
2009 |
+
attentions=outputs.attentions,
|
2010 |
+
)
|
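A minimal usage sketch for the task heads defined in modeling_bert.py above (not part of the upload): it assumes this local copy keeps the upstream transformers API, that the import path follows the repo layout, and that "bert-base-chinese" is only a placeholder for whichever compatible BERT checkpoint is actually used.

import torch
from model.bert.modeling_bert import BertForSequenceClassification
from model.bert.tokenization_bert import BertTokenizer

# placeholder checkpoint; KTeleBERT loads its own weights elsewhere
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
model = BertForSequenceClassification.from_pretrained("bert-base-chinese", num_labels=2)
model.eval()

encoding = tokenizer("ALM 告警已清除", return_tensors="pt")
with torch.no_grad():
    logits = model(**encoding).logits   # shape (1, num_labels)
print(logits.argmax(dim=-1).item())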
KTeleBERT/model/bert/tokenization_bert.py
ADDED
@@ -0,0 +1,574 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Tokenization classes for Bert."""
|
16 |
+
|
17 |
+
|
18 |
+
import collections
|
19 |
+
import os
|
20 |
+
import unicodedata
|
21 |
+
from typing import List, Optional, Tuple
|
22 |
+
|
23 |
+
from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
|
24 |
+
from transformers.utils import logging
|
25 |
+
|
26 |
+
|
27 |
+
logger = logging.get_logger(__name__)
|
28 |
+
|
29 |
+
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
|
30 |
+
|
31 |
+
PRETRAINED_VOCAB_FILES_MAP = {
|
32 |
+
"vocab_file": {
|
33 |
+
"bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
|
34 |
+
"bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt",
|
35 |
+
"bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt",
|
36 |
+
"bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt",
|
37 |
+
"bert-base-multilingual-uncased": (
|
38 |
+
"https://huggingface.co/bert-base-multilingual-uncased/resolve/main/vocab.txt"
|
39 |
+
),
|
40 |
+
"bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt",
|
41 |
+
"bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt",
|
42 |
+
"bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/vocab.txt",
|
43 |
+
"bert-large-uncased-whole-word-masking": (
|
44 |
+
"https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt"
|
45 |
+
),
|
46 |
+
"bert-large-cased-whole-word-masking": (
|
47 |
+
"https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/vocab.txt"
|
48 |
+
),
|
49 |
+
"bert-large-uncased-whole-word-masking-finetuned-squad": (
|
50 |
+
"https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt"
|
51 |
+
),
|
52 |
+
"bert-large-cased-whole-word-masking-finetuned-squad": (
|
53 |
+
"https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt"
|
54 |
+
),
|
55 |
+
"bert-base-cased-finetuned-mrpc": (
|
56 |
+
"https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt"
|
57 |
+
),
|
58 |
+
"bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/vocab.txt",
|
59 |
+
"bert-base-german-dbmdz-uncased": (
|
60 |
+
"https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt"
|
61 |
+
),
|
62 |
+
"TurkuNLP/bert-base-finnish-cased-v1": (
|
63 |
+
"https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt"
|
64 |
+
),
|
65 |
+
"TurkuNLP/bert-base-finnish-uncased-v1": (
|
66 |
+
"https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt"
|
67 |
+
),
|
68 |
+
"wietsedv/bert-base-dutch-cased": (
|
69 |
+
"https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt"
|
70 |
+
),
|
71 |
+
}
|
72 |
+
}
|
73 |
+
|
74 |
+
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
75 |
+
"bert-base-uncased": 512,
|
76 |
+
"bert-large-uncased": 512,
|
77 |
+
"bert-base-cased": 512,
|
78 |
+
"bert-large-cased": 512,
|
79 |
+
"bert-base-multilingual-uncased": 512,
|
80 |
+
"bert-base-multilingual-cased": 512,
|
81 |
+
"bert-base-chinese": 512,
|
82 |
+
"bert-base-german-cased": 512,
|
83 |
+
"bert-large-uncased-whole-word-masking": 512,
|
84 |
+
"bert-large-cased-whole-word-masking": 512,
|
85 |
+
"bert-large-uncased-whole-word-masking-finetuned-squad": 512,
|
86 |
+
"bert-large-cased-whole-word-masking-finetuned-squad": 512,
|
87 |
+
"bert-base-cased-finetuned-mrpc": 512,
|
88 |
+
"bert-base-german-dbmdz-cased": 512,
|
89 |
+
"bert-base-german-dbmdz-uncased": 512,
|
90 |
+
"TurkuNLP/bert-base-finnish-cased-v1": 512,
|
91 |
+
"TurkuNLP/bert-base-finnish-uncased-v1": 512,
|
92 |
+
"wietsedv/bert-base-dutch-cased": 512,
|
93 |
+
}
|
94 |
+
|
95 |
+
PRETRAINED_INIT_CONFIGURATION = {
|
96 |
+
"bert-base-uncased": {"do_lower_case": True},
|
97 |
+
"bert-large-uncased": {"do_lower_case": True},
|
98 |
+
"bert-base-cased": {"do_lower_case": False},
|
99 |
+
"bert-large-cased": {"do_lower_case": False},
|
100 |
+
"bert-base-multilingual-uncased": {"do_lower_case": True},
|
101 |
+
"bert-base-multilingual-cased": {"do_lower_case": False},
|
102 |
+
"bert-base-chinese": {"do_lower_case": False},
|
103 |
+
"bert-base-german-cased": {"do_lower_case": False},
|
104 |
+
"bert-large-uncased-whole-word-masking": {"do_lower_case": True},
|
105 |
+
"bert-large-cased-whole-word-masking": {"do_lower_case": False},
|
106 |
+
"bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
|
107 |
+
"bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
|
108 |
+
"bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
|
109 |
+
"bert-base-german-dbmdz-cased": {"do_lower_case": False},
|
110 |
+
"bert-base-german-dbmdz-uncased": {"do_lower_case": True},
|
111 |
+
"TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
|
112 |
+
"TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
|
113 |
+
"wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
|
114 |
+
}
|
115 |
+
|
116 |
+
|
117 |
+
def load_vocab(vocab_file):
|
118 |
+
"""Loads a vocabulary file into a dictionary."""
|
119 |
+
vocab = collections.OrderedDict()
|
120 |
+
with open(vocab_file, "r", encoding="utf-8") as reader:
|
121 |
+
tokens = reader.readlines()
|
122 |
+
for index, token in enumerate(tokens):
|
123 |
+
token = token.rstrip("\n")
|
124 |
+
vocab[token] = index
|
125 |
+
return vocab
|
126 |
+
|
127 |
+
|
128 |
+
def whitespace_tokenize(text):
|
129 |
+
"""Runs basic whitespace cleaning and splitting on a piece of text."""
|
130 |
+
text = text.strip()
|
131 |
+
if not text:
|
132 |
+
return []
|
133 |
+
tokens = text.split()
|
134 |
+
return tokens
|
135 |
+
|
136 |
+
|
137 |
+
class BertTokenizer(PreTrainedTokenizer):
|
138 |
+
r"""
|
139 |
+
Construct a BERT tokenizer. Based on WordPiece.
|
140 |
+
|
141 |
+
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
142 |
+
this superclass for more information regarding those methods.
|
143 |
+
|
144 |
+
Args:
|
145 |
+
vocab_file (`str`):
|
146 |
+
File containing the vocabulary.
|
147 |
+
do_lower_case (`bool`, *optional*, defaults to `True`):
|
148 |
+
Whether or not to lowercase the input when tokenizing.
|
149 |
+
do_basic_tokenize (`bool`, *optional*, defaults to `True`):
|
150 |
+
Whether or not to do basic tokenization before WordPiece.
|
151 |
+
never_split (`Iterable`, *optional*):
|
152 |
+
Collection of tokens which will never be split during tokenization. Only has an effect when
|
153 |
+
`do_basic_tokenize=True`
|
154 |
+
unk_token (`str`, *optional*, defaults to `"[UNK]"`):
|
155 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
156 |
+
token instead.
|
157 |
+
sep_token (`str`, *optional*, defaults to `"[SEP]"`):
|
158 |
+
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
|
159 |
+
sequence classification or for a text and a question for question answering. It is also used as the last
|
160 |
+
token of a sequence built with special tokens.
|
161 |
+
pad_token (`str`, *optional*, defaults to `"[PAD]"`):
|
162 |
+
The token used for padding, for example when batching sequences of different lengths.
|
163 |
+
cls_token (`str`, *optional*, defaults to `"[CLS]"`):
|
164 |
+
The classifier token which is used when doing sequence classification (classification of the whole sequence
|
165 |
+
instead of per-token classification). It is the first token of the sequence when built with special tokens.
|
166 |
+
mask_token (`str`, *optional*, defaults to `"[MASK]"`):
|
167 |
+
The token used for masking values. This is the token used when training this model with masked language
|
168 |
+
modeling. This is the token which the model will try to predict.
|
169 |
+
tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
|
170 |
+
Whether or not to tokenize Chinese characters.
|
171 |
+
|
172 |
+
This should likely be deactivated for Japanese (see this
|
173 |
+
[issue](https://github.com/huggingface/transformers/issues/328)).
|
174 |
+
strip_accents (`bool`, *optional*):
|
175 |
+
Whether or not to strip all accents. If this option is not specified, then it will be determined by the
|
176 |
+
value for `lowercase` (as in the original BERT).
|
177 |
+
"""
|
178 |
+
|
179 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
180 |
+
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
181 |
+
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
182 |
+
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
183 |
+
|
184 |
+
def __init__(
|
185 |
+
self,
|
186 |
+
vocab_file,
|
187 |
+
do_lower_case=True,
|
188 |
+
do_basic_tokenize=True,
|
189 |
+
never_split=None,
|
190 |
+
unk_token="[UNK]",
|
191 |
+
sep_token="[SEP]",
|
192 |
+
pad_token="[PAD]",
|
193 |
+
cls_token="[CLS]",
|
194 |
+
mask_token="[MASK]",
|
195 |
+
tokenize_chinese_chars=True,
|
196 |
+
strip_accents=None,
|
197 |
+
**kwargs
|
198 |
+
):
|
199 |
+
super().__init__(
|
200 |
+
do_lower_case=do_lower_case,
|
201 |
+
do_basic_tokenize=do_basic_tokenize,
|
202 |
+
never_split=never_split,
|
203 |
+
unk_token=unk_token,
|
204 |
+
sep_token=sep_token,
|
205 |
+
pad_token=pad_token,
|
206 |
+
cls_token=cls_token,
|
207 |
+
mask_token=mask_token,
|
208 |
+
tokenize_chinese_chars=tokenize_chinese_chars,
|
209 |
+
strip_accents=strip_accents,
|
210 |
+
**kwargs,
|
211 |
+
)
|
212 |
+
|
213 |
+
if not os.path.isfile(vocab_file):
|
214 |
+
raise ValueError(
|
215 |
+
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
|
216 |
+
" model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
217 |
+
)
|
218 |
+
self.vocab = load_vocab(vocab_file)
|
219 |
+
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
|
220 |
+
self.do_basic_tokenize = do_basic_tokenize
|
221 |
+
if do_basic_tokenize:
|
222 |
+
self.basic_tokenizer = BasicTokenizer(
|
223 |
+
do_lower_case=do_lower_case,
|
224 |
+
never_split=never_split,
|
225 |
+
tokenize_chinese_chars=tokenize_chinese_chars,
|
226 |
+
strip_accents=strip_accents,
|
227 |
+
)
|
228 |
+
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
|
229 |
+
|
230 |
+
@property
|
231 |
+
def do_lower_case(self):
|
232 |
+
return self.basic_tokenizer.do_lower_case
|
233 |
+
|
234 |
+
@property
|
235 |
+
def vocab_size(self):
|
236 |
+
return len(self.vocab)
|
237 |
+
|
238 |
+
def get_vocab(self):
|
239 |
+
return dict(self.vocab, **self.added_tokens_encoder)
|
240 |
+
|
241 |
+
def _tokenize(self, text):
|
242 |
+
split_tokens = []
|
243 |
+
if self.do_basic_tokenize:
|
244 |
+
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
245 |
+
|
246 |
+
# If the token is part of the never_split set
|
247 |
+
if token in self.basic_tokenizer.never_split:
|
248 |
+
split_tokens.append(token)
|
249 |
+
else:
|
250 |
+
split_tokens += self.wordpiece_tokenizer.tokenize(token)
|
251 |
+
else:
|
252 |
+
split_tokens = self.wordpiece_tokenizer.tokenize(text)
|
253 |
+
return split_tokens
|
254 |
+
|
255 |
+
def _convert_token_to_id(self, token):
|
256 |
+
"""Converts a token (str) in an id using the vocab."""
|
257 |
+
return self.vocab.get(token, self.vocab.get(self.unk_token))
|
258 |
+
|
259 |
+
def _convert_id_to_token(self, index):
|
260 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
261 |
+
return self.ids_to_tokens.get(index, self.unk_token)
|
262 |
+
|
263 |
+
def convert_tokens_to_string(self, tokens):
|
264 |
+
"""Converts a sequence of tokens (string) in a single string."""
|
265 |
+
out_string = " ".join(tokens).replace(" ##", "").strip()
|
266 |
+
return out_string
|
267 |
+
|
268 |
+
def build_inputs_with_special_tokens(
|
269 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
270 |
+
) -> List[int]:
|
271 |
+
"""
|
272 |
+
Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
|
273 |
+
adding special tokens. A BERT sequence has the following format:
|
274 |
+
|
275 |
+
- single sequence: `[CLS] X [SEP]`
|
276 |
+
- pair of sequences: `[CLS] A [SEP] B [SEP]`
|
277 |
+
|
278 |
+
Args:
|
279 |
+
token_ids_0 (`List[int]`):
|
280 |
+
List of IDs to which the special tokens will be added.
|
281 |
+
token_ids_1 (`List[int]`, *optional*):
|
282 |
+
Optional second list of IDs for sequence pairs.
|
283 |
+
|
284 |
+
Returns:
|
285 |
+
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
286 |
+
"""
|
287 |
+
if token_ids_1 is None:
|
288 |
+
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
289 |
+
cls = [self.cls_token_id]
|
290 |
+
sep = [self.sep_token_id]
|
291 |
+
return cls + token_ids_0 + sep + token_ids_1 + sep
|
292 |
+
|
293 |
+
def get_special_tokens_mask(
|
294 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
295 |
+
) -> List[int]:
|
296 |
+
"""
|
297 |
+
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
298 |
+
special tokens using the tokenizer `prepare_for_model` method.
|
299 |
+
|
300 |
+
Args:
|
301 |
+
token_ids_0 (`List[int]`):
|
302 |
+
List of IDs.
|
303 |
+
token_ids_1 (`List[int]`, *optional*):
|
304 |
+
Optional second list of IDs for sequence pairs.
|
305 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
306 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
307 |
+
|
308 |
+
Returns:
|
309 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
310 |
+
"""
|
311 |
+
|
312 |
+
if already_has_special_tokens:
|
313 |
+
return super().get_special_tokens_mask(
|
314 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
315 |
+
)
|
316 |
+
|
317 |
+
if token_ids_1 is not None:
|
318 |
+
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
319 |
+
return [1] + ([0] * len(token_ids_0)) + [1]
|
320 |
+
|
321 |
+
def create_token_type_ids_from_sequences(
|
322 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
323 |
+
) -> List[int]:
|
324 |
+
"""
|
325 |
+
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
|
326 |
+
pair mask has the following format:
|
327 |
+
|
328 |
+
```
|
329 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
330 |
+
| first sequence | second sequence |
|
331 |
+
```
|
332 |
+
|
333 |
+
If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
|
334 |
+
|
335 |
+
Args:
|
336 |
+
token_ids_0 (`List[int]`):
|
337 |
+
List of IDs.
|
338 |
+
token_ids_1 (`List[int]`, *optional*):
|
339 |
+
Optional second list of IDs for sequence pairs.
|
340 |
+
|
341 |
+
Returns:
|
342 |
+
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
|
343 |
+
"""
|
344 |
+
sep = [self.sep_token_id]
|
345 |
+
cls = [self.cls_token_id]
|
346 |
+
if token_ids_1 is None:
|
347 |
+
return len(cls + token_ids_0 + sep) * [0]
|
348 |
+
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
349 |
+
|
350 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
351 |
+
index = 0
|
352 |
+
if os.path.isdir(save_directory):
|
353 |
+
vocab_file = os.path.join(
|
354 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
355 |
+
)
|
356 |
+
else:
|
357 |
+
vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
|
358 |
+
with open(vocab_file, "w", encoding="utf-8") as writer:
|
359 |
+
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
360 |
+
if index != token_index:
|
361 |
+
logger.warning(
|
362 |
+
f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
|
363 |
+
" Please check that the vocabulary is not corrupted!"
|
364 |
+
)
|
365 |
+
index = token_index
|
366 |
+
writer.write(token + "\n")
|
367 |
+
index += 1
|
368 |
+
return (vocab_file,)
|
369 |
+
|
370 |
+
|
371 |
+
class BasicTokenizer(object):
|
372 |
+
"""
|
373 |
+
Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
|
374 |
+
|
375 |
+
Args:
|
376 |
+
do_lower_case (`bool`, *optional*, defaults to `True`):
|
377 |
+
Whether or not to lowercase the input when tokenizing.
|
378 |
+
never_split (`Iterable`, *optional*):
|
379 |
+
Collection of tokens which will never be split during tokenization. Only has an effect when
|
380 |
+
`do_basic_tokenize=True`
|
381 |
+
tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
|
382 |
+
Whether or not to tokenize Chinese characters.
|
383 |
+
|
384 |
+
This should likely be deactivated for Japanese (see this
|
385 |
+
[issue](https://github.com/huggingface/transformers/issues/328)).
|
386 |
+
strip_accents: (`bool`, *optional*):
|
387 |
+
Whether or not to strip all accents. If this option is not specified, then it will be determined by the
|
388 |
+
value for `lowercase` (as in the original BERT).
|
389 |
+
"""
|
390 |
+
|
391 |
+
def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
|
392 |
+
if never_split is None:
|
393 |
+
never_split = []
|
394 |
+
self.do_lower_case = do_lower_case
|
395 |
+
self.never_split = set(never_split)
|
396 |
+
self.tokenize_chinese_chars = tokenize_chinese_chars
|
397 |
+
self.strip_accents = strip_accents
|
398 |
+
|
399 |
+
def tokenize(self, text, never_split=None):
|
400 |
+
"""
|
401 |
+
Basic Tokenization of a piece of text. Splits on "white spaces" only; for sub-word tokenization, see
|
402 |
+
WordPieceTokenizer.
|
403 |
+
|
404 |
+
Args:
|
405 |
+
never_split (`List[str]`, *optional*)
|
406 |
+
Kept for backward compatibility purposes. Now implemented directly at the base class level (see
|
407 |
+
[`PreTrainedTokenizer.tokenize`]) List of tokens not to split.
|
408 |
+
"""
|
409 |
+
# union() returns a new set by concatenating the two sets.
|
410 |
+
never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
|
411 |
+
text = self._clean_text(text)
|
412 |
+
|
413 |
+
# This was added on November 1st, 2018 for the multilingual and Chinese
|
414 |
+
# models. This is also applied to the English models now, but it doesn't
|
415 |
+
# matter since the English models were not trained on any Chinese data
|
416 |
+
# and generally don't have any Chinese data in them (there are Chinese
|
417 |
+
# characters in the vocabulary because Wikipedia does have some Chinese
|
418 |
+
# words in the English Wikipedia.).
|
419 |
+
if self.tokenize_chinese_chars:
|
420 |
+
text = self._tokenize_chinese_chars(text)
|
421 |
+
orig_tokens = whitespace_tokenize(text)
|
422 |
+
split_tokens = []
|
423 |
+
for token in orig_tokens:
|
424 |
+
if token not in never_split:
|
425 |
+
if self.do_lower_case:
|
426 |
+
token = token.lower()
|
427 |
+
if self.strip_accents is not False:
|
428 |
+
token = self._run_strip_accents(token)
|
429 |
+
elif self.strip_accents:
|
430 |
+
token = self._run_strip_accents(token)
|
431 |
+
split_tokens.extend(self._run_split_on_punc(token, never_split))
|
432 |
+
|
433 |
+
output_tokens = whitespace_tokenize(" ".join(split_tokens))
|
434 |
+
return output_tokens
|
435 |
+
|
436 |
+
def _run_strip_accents(self, text):
|
437 |
+
"""Strips accents from a piece of text."""
|
438 |
+
text = unicodedata.normalize("NFD", text)
|
439 |
+
output = []
|
440 |
+
for char in text:
|
441 |
+
cat = unicodedata.category(char)
|
442 |
+
if cat == "Mn":
|
443 |
+
continue
|
444 |
+
output.append(char)
|
445 |
+
return "".join(output)
|
446 |
+
|
447 |
+
def _run_split_on_punc(self, text, never_split=None):
|
448 |
+
"""Splits punctuation on a piece of text."""
|
449 |
+
if never_split is not None and text in never_split:
|
450 |
+
return [text]
|
451 |
+
chars = list(text)
|
452 |
+
i = 0
|
453 |
+
start_new_word = True
|
454 |
+
output = []
|
455 |
+
while i < len(chars):
|
456 |
+
char = chars[i]
|
457 |
+
if _is_punctuation(char):
|
458 |
+
output.append([char])
|
459 |
+
start_new_word = True
|
460 |
+
else:
|
461 |
+
if start_new_word:
|
462 |
+
output.append([])
|
463 |
+
start_new_word = False
|
464 |
+
output[-1].append(char)
|
465 |
+
i += 1
|
466 |
+
|
467 |
+
return ["".join(x) for x in output]
|
468 |
+
|
469 |
+
def _tokenize_chinese_chars(self, text):
|
470 |
+
"""Adds whitespace around any CJK character."""
|
471 |
+
output = []
|
472 |
+
for char in text:
|
473 |
+
cp = ord(char)
|
474 |
+
if self._is_chinese_char(cp):
|
475 |
+
output.append(" ")
|
476 |
+
output.append(char)
|
477 |
+
output.append(" ")
|
478 |
+
else:
|
479 |
+
output.append(char)
|
480 |
+
return "".join(output)
|
481 |
+
|
482 |
+
def _is_chinese_char(self, cp):
|
483 |
+
"""Checks whether CP is the codepoint of a CJK character."""
|
484 |
+
# This defines a "chinese character" as anything in the CJK Unicode block:
|
485 |
+
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
486 |
+
#
|
487 |
+
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
488 |
+
# despite its name. The modern Korean Hangul alphabet is a different block,
|
489 |
+
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
490 |
+
# space-separated words, so they are not treated specially and handled
|
491 |
+
# like all of the other languages.
|
492 |
+
if (
|
493 |
+
(cp >= 0x4E00 and cp <= 0x9FFF)
|
494 |
+
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
495 |
+
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
496 |
+
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
497 |
+
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
498 |
+
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
499 |
+
or (cp >= 0xF900 and cp <= 0xFAFF)
|
500 |
+
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
501 |
+
): #
|
502 |
+
return True
|
503 |
+
|
504 |
+
return False
|
505 |
+
|
506 |
+
def _clean_text(self, text):
|
507 |
+
"""Performs invalid character removal and whitespace cleanup on text."""
|
508 |
+
output = []
|
509 |
+
for char in text:
|
510 |
+
cp = ord(char)
|
511 |
+
if cp == 0 or cp == 0xFFFD or _is_control(char):
|
512 |
+
continue
|
513 |
+
if _is_whitespace(char):
|
514 |
+
output.append(" ")
|
515 |
+
else:
|
516 |
+
output.append(char)
|
517 |
+
return "".join(output)
|
518 |
+
|
519 |
+
|
520 |
+
class WordpieceTokenizer(object):
|
521 |
+
"""Runs WordPiece tokenization."""
|
522 |
+
|
523 |
+
def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
|
524 |
+
self.vocab = vocab
|
525 |
+
self.unk_token = unk_token
|
526 |
+
self.max_input_chars_per_word = max_input_chars_per_word
|
527 |
+
|
528 |
+
def tokenize(self, text):
|
529 |
+
"""
|
530 |
+
Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
|
531 |
+
tokenization using the given vocabulary.
|
532 |
+
|
533 |
+
For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.
|
534 |
+
|
535 |
+
Args:
|
536 |
+
text: A single token or whitespace separated tokens. This should have
|
537 |
+
already been passed through *BasicTokenizer*.
|
538 |
+
|
539 |
+
Returns:
|
540 |
+
A list of wordpiece tokens.
|
541 |
+
"""
|
542 |
+
|
543 |
+
output_tokens = []
|
544 |
+
for token in whitespace_tokenize(text):
|
545 |
+
chars = list(token)
|
546 |
+
if len(chars) > self.max_input_chars_per_word:
|
547 |
+
output_tokens.append(self.unk_token)
|
548 |
+
continue
|
549 |
+
|
550 |
+
is_bad = False
|
551 |
+
start = 0
|
552 |
+
sub_tokens = []
|
553 |
+
while start < len(chars):
|
554 |
+
end = len(chars)
|
555 |
+
cur_substr = None
|
556 |
+
while start < end:
|
557 |
+
substr = "".join(chars[start:end])
|
558 |
+
if start > 0:
|
559 |
+
substr = "##" + substr
|
560 |
+
if substr in self.vocab:
|
561 |
+
cur_substr = substr
|
562 |
+
break
|
563 |
+
end -= 1
|
564 |
+
if cur_substr is None:
|
565 |
+
is_bad = True
|
566 |
+
break
|
567 |
+
sub_tokens.append(cur_substr)
|
568 |
+
start = end
|
569 |
+
|
570 |
+
if is_bad:
|
571 |
+
output_tokens.append(self.unk_token)
|
572 |
+
else:
|
573 |
+
output_tokens.extend(sub_tokens)
|
574 |
+
return output_tokens
|
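A minimal sketch of the pair-encoding behaviour documented in tokenization_bert.py above (not part of the upload). It assumes a local WordPiece vocab file is available; "vocab.txt" below is a placeholder path.

from model.bert.tokenization_bert import BertTokenizer

tokenizer = BertTokenizer(vocab_file="vocab.txt", do_lower_case=True)      # placeholder vocab path
ids_a = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("gtp 隧道"))
ids_b = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("pdp 上下文"))

input_ids = tokenizer.build_inputs_with_special_tokens(ids_a, ids_b)       # [CLS] A [SEP] B [SEP]
type_ids = tokenizer.create_token_type_ids_from_sequences(ids_a, ids_b)    # 0s for A, 1s for B
assert len(input_ids) == len(type_ids)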
KTeleBERT/requirements.txt
ADDED
@@ -0,0 +1,10 @@
transformers==4.12.2
tqdm
torch>=1.10.0
ltp
ltp-core
ltp-extension
cycle
easydict
KTeleBERT/run.sh
ADDED
@@ -0,0 +1,35 @@
python -m torch.distributed.launch --nproc_per_node=4 main.py --LLRD 1 \
    --eval_step 10 \
    --save_model 1 \
    --mask_stratege wwm \
    --batch_size 64 \
    --batch_size_ke 64 \
    --exp_name Fine_tune_2 \
    --exp_id v01 \
    --workers 8 \
    --use_NumEmb 1 \
    --seq_data_name Seq_data_RuAlmEntKpiTbwDoc \
    --maxlength 256 \
    --lr 4e-5 \
    --ke_lr 8e-5 \
    --train_strategy 2 \
    --model_name TeleBert2 \
    --train_ratio 1 \
    --save_pretrain 0 \
    --dist 1 \
    --accumulation_steps 8 \
    --accumulation_steps_ke 6 \
    --special_token_mask 0 \
    --freeze_layer 0 \
    --ernie_stratege -1 \
    --mlm_probability_increase curve \
    --use_kpi_loss 1 \
    --mlm_probability 0.4 \
    --use_awl 1 \
    --cls_head_init 1 \
    --emb_init 0 \
    --final_mlm_probability 0.4 \
    --ke_dim 256 \
    --plm_emb_type cls \
    --train_together 0
KTeleBERT/run_get_ref.sh
ADDED
@@ -0,0 +1,22 @@
python get_chinese_ref.py --batch_size 50 \
    --deal_numeric 1 \
    --seq_data_name Seq_data_large \
    --read_cws 0
    # alternative input:
    # --seq_data_name Seq_data_base \
    # --read_cws 1 \

# python get_chinese_ref.py --batch_size 150

# python get_chinese_ref.py --batch_size 200

# python get_chinese_ref.py --batch_size 250

# python get_chinese_ref.py --batch_size 300

# python main.py --LLRD 1 \
#     --eval_step 10 \
#     --epoch 20 \
#     --save_model 1 \
#     --mask_stratege wwm \
#     --batch_size 50 \
#     --use_NumEmb 1 \
KTeleBERT/special_token_pre_emb.py
ADDED
@@ -0,0 +1,119 @@
from src.utils import add_special_token
import os.path as osp
import numpy as np
import random
import torch
from easydict import EasyDict as edict
import argparse
import pdb
import json
from model import BertTokenizer
from collections import Counter
from tqdm import tqdm
from time import time
from numpy import mean
import math

from transformers import BertModel


class cfg():
    def __init__(self):
        self.this_dir = osp.dirname(__file__)
        # change
        self.data_root = osp.abspath(osp.join(self.this_dir, '..', '..', 'data', ''))

    def get_args(self):
        parser = argparse.ArgumentParser()
        # seq_data_name = "Seq_data_tiny_831"
        parser.add_argument("--data_path", default="huawei", type=str, help="Experiment path")
        parser.add_argument("--update_model_name", default='MacBert', type=str, help="MacBert")
        parser.add_argument("--pretrained_model_name", default='TeleBert', type=str, help="TeleBert")
        parser.add_argument("--read_cws", default=0, type=int, help="whether to read the pre-trained cws (word segmentation) file")
        self.cfg = parser.parse_args()

    def update_train_configs(self):
        # TODO: update some dynamic variable
        self.cfg.data_root = self.data_root
        self.cfg.data_path = osp.join(self.data_root, self.cfg.data_path)

        return self.cfg


if __name__ == '__main__':
    '''
    Purpose: build the chinese ref file and refresh the train/test files (sequence text data only)
    '''
    cfg = cfg()
    cfg.get_args()
    cfgs = cfg.update_train_configs()

    # the tokenizer to be updated, i.e. the one the special tokens are added to
    path = osp.join(cfgs.data_root, 'transformer', cfgs.update_model_name)
    assert osp.exists(path)
    tokenizer = BertTokenizer.from_pretrained(path, do_lower_case=True)
    tokenizer, special_token, norm_token = add_special_token(tokenizer)
    added_vocab = tokenizer.get_added_vocab()
    vocb_path = osp.join(cfgs.data_path, 'added_vocab.json')

    with open(vocb_path, 'w') as fp:
        json.dump(added_vocab, fp, ensure_ascii=False)

    vocb_description = osp.join(cfgs.data_path, 'vocab_descrip.json')
    vocb_descrip = None

    vocb_descrip = {
        "alm": "alarm",
        "ran": "ran 无线接入网",
        "mml": "MML 人机语言命令",
        "nf": "NF 独立网络服务",
        "apn": "APN 接入点名称",
        "pgw": "PGW 数据管理子系统模块",
        "lst": "LST 查询命令",
        "qos": "QoS 定制服务质量",
        "ipv": "IPV 互联网通讯协议版本",
        "ims": "IMS IP多模态子系统",
        "gtp": "GTP GPRS隧道协议",
        "pdp": "PDP 分组数据协议",
        "hss": "HSS HTTP Smooth Stream",
        "[ALM]": "alarm 告警 标记",
        "[KPI]": "kpi 关键性能指标 标记",
        "[LOC]": "location 事件发生位置 标记",
        "[EOS]": "end of the sentence 文档结尾 标记",
        "[ENT]": "实体标记",
        "[ATTR]": "属性标记",
        "[NUM]": "数值标记",
        "[REL]": "关系标记",
        "[DOC]": "文档标记"
    }

    # if osp.exists(vocb_description):
    #     with open(vocb_description, 'r') as fp:
    #         vocb_descrip = json.load(fp)

    # the model used to produce the embeddings
    path = osp.join(cfgs.data_root, 'transformer', cfgs.pretrained_model_name)
    assert osp.exists(path)
    pre_tokenizer = BertTokenizer.from_pretrained(path, do_lower_case=True)
    model = BertModel.from_pretrained(path)

    print("use the vocb_description")
    key_to_emb = {}
    for key in added_vocab.keys():
        if vocb_descrip is not None:
            if key in vocb_descrip:
                # some of the added tokens come with a textual description
                key_tokens = pre_tokenizer(vocb_descrip[key], return_tensors='pt')
            else:
                key_tokens = pre_tokenizer(key, return_tensors='pt')
        else:
            key_tokens = pre_tokenizer(key, return_tensors='pt')

        hidden_state = model(**key_tokens, output_hidden_states=True).hidden_states
        # pdb.set_trace()
        # average the last hidden states of the content tokens (drop [CLS] and [SEP])
        key_to_emb[key] = hidden_state[-1][:, 1:-1, :].mean(dim=1)

    emb_path = osp.join(cfgs.data_path, 'added_vocab_embedding.pt')

    torch.save(key_to_emb, emb_path)
    print(f'save to {emb_path}')
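A minimal sketch (not part of the upload) of how the added_vocab_embedding.pt file written by special_token_pre_emb.py could be used to initialise the embeddings of the newly added tokens in another BERT model; the "MacBert" path and the exact hidden sizes are assumptions.

import torch
from transformers import BertModel
from model import BertTokenizer
from src.utils import add_special_token

key_to_emb = torch.load("added_vocab_embedding.pt")                        # {token: tensor of shape (1, hidden_size)}
tokenizer = BertTokenizer.from_pretrained("MacBert", do_lower_case=True)   # placeholder path
tokenizer, special_token, norm_token = add_special_token(tokenizer)
model = BertModel.from_pretrained("MacBert")                               # placeholder path
model.resize_token_embeddings(len(tokenizer))

with torch.no_grad():
    emb = model.get_input_embeddings()
    for token, vec in key_to_emb.items():
        token_id = tokenizer.convert_tokens_to_ids(token)
        emb.weight[token_id] = vec.squeeze(0)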
KTeleBERT/src/__init__.py
ADDED
@@ -0,0 +1 @@
KTeleBERT/src/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (137 Bytes). View file
KTeleBERT/src/__pycache__/data.cpython-38.pyc
ADDED
Binary file (16 kB). View file
KTeleBERT/src/__pycache__/distributed_utils.cpython-38.pyc
ADDED
Binary file (2.15 kB). View file
KTeleBERT/src/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (16 kB). View file
KTeleBERT/src/data.py
ADDED
@@ -0,0 +1,651 @@
1 |
+
import torch
|
2 |
+
import random
|
3 |
+
import json
|
4 |
+
import numpy as np
|
5 |
+
import pdb
|
6 |
+
import os.path as osp
|
7 |
+
from model import BertTokenizer
|
8 |
+
import torch.distributed as dist
|
9 |
+
|
10 |
+
|
11 |
+
class SeqDataset(torch.utils.data.Dataset):
|
12 |
+
def __init__(self, data, chi_ref=None, kpi_ref=None):
|
13 |
+
self.data = data
|
14 |
+
self.chi_ref = chi_ref
|
15 |
+
self.kpi_ref = kpi_ref
|
16 |
+
|
17 |
+
def __len__(self):
|
18 |
+
return len(self.data)
|
19 |
+
|
20 |
+
def __getitem__(self, index):
|
21 |
+
sample = self.data[index]
|
22 |
+
if self.chi_ref is not None:
|
23 |
+
chi_ref = self.chi_ref[index]
|
24 |
+
else:
|
25 |
+
chi_ref = None
|
26 |
+
|
27 |
+
if self.kpi_ref is not None:
|
28 |
+
kpi_ref = self.kpi_ref[index]
|
29 |
+
else:
|
30 |
+
kpi_ref = None
|
31 |
+
|
32 |
+
return sample, chi_ref, kpi_ref
|
33 |
+
|
34 |
+
|
35 |
+
class OrderDataset(torch.utils.data.Dataset):
|
36 |
+
def __init__(self, data, kpi_ref=None):
|
37 |
+
self.data = data
|
38 |
+
self.kpi_ref = kpi_ref
|
39 |
+
|
40 |
+
def __len__(self):
|
41 |
+
return len(self.data)
|
42 |
+
|
43 |
+
def __getitem__(self, index):
|
44 |
+
sample = self.data[index]
|
45 |
+
if self.kpi_ref is not None:
|
46 |
+
kpi_ref = self.kpi_ref[index]
|
47 |
+
else:
|
48 |
+
kpi_ref = None
|
49 |
+
|
50 |
+
return sample, kpi_ref
|
51 |
+
|
52 |
+
|
53 |
+
class KGDataset(torch.utils.data.Dataset):
|
54 |
+
def __init__(self, data):
|
55 |
+
self.data = data
|
56 |
+
self.len = len(self.data)
|
57 |
+
|
58 |
+
def __len__(self):
|
59 |
+
return self.len
|
60 |
+
|
61 |
+
def __getitem__(self, index):
|
62 |
+
|
63 |
+
sample = self.data[index]
|
64 |
+
return sample
|
65 |
+
|
66 |
+
# TODO: refactor DataCollatorForLanguageModeling
|
67 |
+
|
68 |
+
|
69 |
+
class Collator_base(object):
|
70 |
+
# TODO: define the collator, following Lako
|
71 |
+
# handles masking and padding
|
72 |
+
def __init__(self, args, tokenizer, special_token=None):
|
73 |
+
self.tokenizer = tokenizer
|
74 |
+
if special_token is None:
|
75 |
+
self.special_token = ['[SEP]', '[MASK]', '[ALM]', '[KPI]', '[CLS]', '[LOC]', '[EOS]', '[ENT]', '[ATTR]', '[NUM]', '[REL]', '|', '[DOC]']
|
76 |
+
else:
|
77 |
+
self.special_token = special_token
|
78 |
+
|
79 |
+
self.text_maxlength = args.maxlength
|
80 |
+
self.mlm_probability = args.mlm_probability
|
81 |
+
self.args = args
|
82 |
+
if self.args.special_token_mask:
|
83 |
+
self.special_token = ['|', '[NUM]']
|
84 |
+
|
85 |
+
if not self.args.only_test and self.args.use_mlm_task:
|
86 |
+
if args.mask_stratege == 'rand':
|
87 |
+
self.mask_func = self.torch_mask_tokens
|
88 |
+
else:
|
89 |
+
if args.mask_stratege == 'wwm':
|
90 |
+
# special_word must be enabled, because this wwm relies on word segmentation
|
91 |
+
if args.rank == 0:
|
92 |
+
print("use word-level Mask ...")
|
93 |
+
assert args.add_special_word == 1
|
94 |
+
self.mask_func = self.wwm_mask_tokens
|
95 |
+
else: # domain
|
96 |
+
if args.rank == 0:
|
97 |
+
print("use token-level Mask ...")
|
98 |
+
self.mask_func = self.domain_mask_tokens
|
99 |
+
|
100 |
+
def __call__(self, batch):
|
101 |
+
# extract the numeric values from the batch and replace them with the special token
|
102 |
+
# pass the numeric values and their positions in separately as lists
|
103 |
+
# in the later training stage the values are inserted directly at the embedding positions
|
104 |
+
# numeric values are never masked
|
105 |
+
# for wwm, the chinese ref can be passed in along with the batch
|
106 |
+
kpi_ref = None
|
107 |
+
if self.args.use_NumEmb:
|
108 |
+
kpi_ref = [item[2] for item in batch]
|
109 |
+
# if self.args.mask_stratege != 'rand':
|
110 |
+
chinese_ref = [item[1] for item in batch]
|
111 |
+
batch = [item[0] for item in batch]
|
112 |
+
# at this point the batch contains more than just strings
|
113 |
+
batch = self.tokenizer.batch_encode_plus(
|
114 |
+
batch,
|
115 |
+
padding='max_length',
|
116 |
+
max_length=self.text_maxlength,
|
117 |
+
truncation=True,
|
118 |
+
return_tensors="pt",
|
119 |
+
return_token_type_ids=False,
|
120 |
+
return_attention_mask=True,
|
121 |
+
add_special_tokens=False
|
122 |
+
)
|
123 |
+
special_tokens_mask = batch.pop("special_tokens_mask", None)
|
124 |
+
# self.torch_mask_tokens
|
125 |
+
|
126 |
+
# if batch["input_ids"].shape[1] != 128:
|
127 |
+
# pdb.set_trace()
|
128 |
+
if chinese_ref is not None:
|
129 |
+
batch["chinese_ref"] = chinese_ref
|
130 |
+
if kpi_ref is not None:
|
131 |
+
batch["kpi_ref"] = kpi_ref
|
132 |
+
|
133 |
+
# training requires masking
|
134 |
+
|
135 |
+
if not self.args.only_test and self.args.use_mlm_task:
|
136 |
+
batch["input_ids"], batch["labels"] = self.mask_func(
|
137 |
+
batch, special_tokens_mask=special_tokens_mask
|
138 |
+
)
|
139 |
+
else:
|
140 |
+
# not in training mode
|
141 |
+
# and MLM is not used for training
|
142 |
+
labels = batch["input_ids"].clone()
|
143 |
+
if self.tokenizer.pad_token_id is not None:
|
144 |
+
labels[labels == self.tokenizer.pad_token_id] = -100
|
145 |
+
batch["labels"] = labels
|
146 |
+
|
147 |
+
return batch
|
148 |
+
|
149 |
+
def torch_mask_tokens(self, inputs, special_tokens_mask=None):
|
150 |
+
"""
|
151 |
+
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
|
152 |
+
"""
|
153 |
+
if "input_ids" in inputs:
|
154 |
+
inputs = inputs["input_ids"]
|
155 |
+
labels = inputs.clone()
|
156 |
+
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
|
157 |
+
probability_matrix = torch.full(labels.shape, self.mlm_probability)
|
158 |
+
if special_tokens_mask is None:
|
159 |
+
special_tokens_mask = [
|
160 |
+
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
|
161 |
+
]
|
162 |
+
special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
|
163 |
+
else:
|
164 |
+
special_tokens_mask = special_tokens_mask.bool()
|
165 |
+
# pdb.set_trace()
|
166 |
+
|
167 |
+
probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
|
168 |
+
masked_indices = torch.bernoulli(probability_matrix).bool()
|
169 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
170 |
+
|
171 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
172 |
+
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
|
173 |
+
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
174 |
+
|
175 |
+
# 10% of the time, we replace masked input tokens with random word
|
176 |
+
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
|
177 |
+
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
|
178 |
+
inputs[indices_random] = random_words[indices_random]
|
179 |
+
|
180 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
181 |
+
return inputs, labels
|
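The 80/10/10 rule above follows the standard BERT masking recipe. Below is a minimal, self-contained sketch of the same sampling logic on a toy id tensor; the mask id (103) and the vocabulary range are hypothetical stand-ins for the real tokenizer.

import torch

torch.manual_seed(0)
inputs = torch.randint(5, 100, (2, 8))   # toy "input_ids"
labels = inputs.clone()
masked = torch.bernoulli(torch.full(labels.shape, 0.15)).bool()
labels[~masked] = -100                   # loss is computed on masked positions only
replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked
inputs[replaced] = 103                   # hypothetical [MASK] id
random_pos = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked & ~replaced
inputs[random_pos] = torch.randint(5, 100, labels.shape)[random_pos]
# the remaining ~10% of masked positions keep their original token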
182 |
+
|
183 |
+
def wwm_mask_tokens(self, inputs, special_tokens_mask=None):
|
184 |
+
mask_labels = []
|
185 |
+
ref_tokens = inputs["chinese_ref"]
|
186 |
+
input_ids = inputs["input_ids"]
|
187 |
+
sz = len(input_ids)
|
188 |
+
|
189 |
+
# first map the input ids back to tokens
|
190 |
+
for i in range(sz):
|
191 |
+
# the loaded ref is the main input here, but max_len may be inconsistent
|
192 |
+
mask_labels.append(self._whole_word_mask(ref_tokens[i]))
|
193 |
+
|
194 |
+
batch_mask = _torch_collate_batch(mask_labels, self.tokenizer, self.text_maxlength, pad_to_multiple_of=None)
|
195 |
+
inputs, labels = self.torch_mask_tokens_4wwm(input_ids, batch_mask)
|
196 |
+
return inputs, labels
|
197 |
+
|
198 |
+
# input_tokens: List[str]
|
199 |
+
def _whole_word_mask(self, input_tokens, max_predictions=512):
|
200 |
+
"""
|
201 |
+
Get 0/1 labels for masked tokens with whole word mask proxy
|
202 |
+
"""
|
203 |
+
assert isinstance(self.tokenizer, (BertTokenizer))
|
204 |
+
# the input comes in [..., ..., ..., ...] format
|
205 |
+
cand_indexes = []
|
206 |
+
cand_token = []
|
207 |
+
|
208 |
+
for i, token in enumerate(input_tokens):
|
209 |
+
if i >= self.text_maxlength - 1:
|
210 |
+
# must not exceed the maximum length, so truncate
|
211 |
+
break
|
212 |
+
if token.lower() in self.special_token:
|
213 |
+
# special-token words should not be masked
|
214 |
+
continue
|
215 |
+
if len(cand_indexes) >= 1 and token.startswith("##"):
|
216 |
+
cand_indexes[-1].append(i)
|
217 |
+
cand_token.append(i)
|
218 |
+
else:
|
219 |
+
cand_indexes.append([i])
|
220 |
+
cand_token.append(i)
|
221 |
+
|
222 |
+
random.shuffle(cand_indexes)
|
223 |
+
# originally: input_tokens
|
224 |
+
# but there are many special tokens here, so they are removed beforehand
|
225 |
+
# the 15% here is 15% of the sequence without special tokens; the +2 adds the length of the CLS and SEP flags
|
226 |
+
num_to_predict = min(max_predictions, max(1, int(round((len(cand_token) + 2) * self.mlm_probability))))
|
227 |
+
masked_lms = []
|
228 |
+
covered_indexes = set()
|
229 |
+
for index_set in cand_indexes:
|
230 |
+
# stop once the target length is reached
|
231 |
+
if len(masked_lms) >= num_to_predict:
|
232 |
+
break
|
233 |
+
# If adding a whole-word mask would exceed the maximum number of
|
234 |
+
# predictions, then just skip this candidate.
|
235 |
+
# the masked length must not exceed 15%; at most equal to it
|
236 |
+
if len(masked_lms) + len(index_set) > num_to_predict:
|
237 |
+
continue
|
238 |
+
is_any_index_covered = False
|
239 |
+
for index in index_set:
|
240 |
+
# do not mask tokens that overlap with already covered indices
|
241 |
+
if index in covered_indexes:
|
242 |
+
is_any_index_covered = True
|
243 |
+
break
|
244 |
+
if is_any_index_covered:
|
245 |
+
continue
|
246 |
+
for index in index_set:
|
247 |
+
covered_indexes.add(index)
|
248 |
+
masked_lms.append(index)
|
249 |
+
|
250 |
+
if len(covered_indexes) != len(masked_lms):
|
251 |
+
# normally never triggered, since duplicates are avoided above
|
252 |
+
raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")
|
253 |
+
# must not exceed the maximum length; truncate
|
254 |
+
mask_labels = [1 if i in covered_indexes else 0 for i in range(min(len(input_tokens), self.text_maxlength))]
|
255 |
+
|
256 |
+
return mask_labels
|
257 |
+
|
258 |
+
# decide which positions need masking: set to 0/1
|
259 |
+
|
260 |
+
# self.torch_mask_tokens is then applied to the result
|
261 |
+
|
262 |
+
#
|
263 |
+
pass
|
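A small illustration of the grouping step performed above, using made-up WordPiece tokens: "##" continuation pieces are attached to the preceding word so that masking decisions are made per whole word, and special tokens are skipped.

tokens = ["net", "##work", "alarm", "[NUM]", "level"]
special = {"[num]"}            # lower-cased special tokens, as in self.special_token
cand_indexes = []
for i, tok in enumerate(tokens):
    if tok.lower() in special:
        continue               # special tokens are never candidates for masking
    if cand_indexes and tok.startswith("##"):
        cand_indexes[-1].append(i)
    else:
        cand_indexes.append([i])
# cand_indexes == [[0, 1], [2], [4]]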
264 |
+
|
265 |
+
def torch_mask_tokens_4wwm(self, inputs, mask_labels):
|
266 |
+
"""
|
267 |
+
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set
|
268 |
+
'mask_labels' means we use whole word mask (wwm): we directly mask indices according to its ref.
|
269 |
+
"""
|
270 |
+
# if "input_ids" in inputs:
|
271 |
+
# inputs = inputs["input_ids"]
|
272 |
+
if self.tokenizer.mask_token is None:
|
273 |
+
raise ValueError(
|
274 |
+
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
|
275 |
+
" --mlm flag if you want to use this tokenizer."
|
276 |
+
)
|
277 |
+
labels = inputs.clone()
|
278 |
+
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
279 |
+
|
280 |
+
probability_matrix = mask_labels
|
281 |
+
|
282 |
+
special_tokens_mask = [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]
|
283 |
+
|
284 |
+
if len(special_tokens_mask[0]) != probability_matrix.shape[1]:
|
285 |
+
print(f"len(special_tokens_mask[0]): {len(special_tokens_mask[0])}")
|
286 |
+
print(f"probability_matrix.shape[1]): {probability_matrix.shape[1]}")
|
287 |
+
print(f'max len {self.text_maxlength}')
|
288 |
+
print(f"pad_token_id: {self.tokenizer.pad_token_id}")
|
289 |
+
# if self.args.rank != in_rank:
|
290 |
+
if self.args.dist:
|
291 |
+
dist.barrier()
|
292 |
+
pdb.set_trace()
|
293 |
+
else:
|
294 |
+
pdb.set_trace()
|
295 |
+
|
296 |
+
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
|
297 |
+
if self.tokenizer._pad_token is not None:
|
298 |
+
padding_mask = labels.eq(self.tokenizer.pad_token_id)
|
299 |
+
probability_matrix.masked_fill_(padding_mask, value=0.0)
|
300 |
+
|
301 |
+
masked_indices = probability_matrix.bool()
|
302 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
303 |
+
|
304 |
+
# with this wwm, the mask/replace/keep choice is not applied to the whole word as one unit; it gets split up
|
305 |
+
# not really ideal, but there is no good alternative
|
306 |
+
|
307 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
308 |
+
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
|
309 |
+
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
310 |
+
|
311 |
+
# 10% of the time, we replace masked input tokens with random word
|
312 |
+
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
|
313 |
+
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
|
314 |
+
inputs[indices_random] = random_words[indices_random]
|
315 |
+
|
316 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
317 |
+
return inputs, labels
|
318 |
+
|
319 |
+
# TODO: mask by region/cell
|
320 |
+
|
321 |
+
def domain_mask_tokens(self, inputs, special_tokens_mask=None):
|
322 |
+
pass
|
323 |
+
|
324 |
+
|
325 |
+
class Collator_kg(object):
|
326 |
+
# TODO: define the collator, modeled on Lako
|
327 |
+
# randomly drops part of the attributes
|
328 |
+
def __init__(self, args, tokenizer, data):
|
329 |
+
self.tokenizer = tokenizer
|
330 |
+
self.text_maxlength = args.maxlength
|
331 |
+
self.cross_sampling_flag = 0
|
332 |
+
# the ke batch size is a quarter of the normal batch size
|
333 |
+
self.neg_num = args.neg_num
|
334 |
+
# negative samples must not appear in the full triple set
|
335 |
+
self.data = data
|
336 |
+
self.args = args
|
337 |
+
|
338 |
+
def __call__(self, batch):
|
339 |
+
# encode into token form first to avoid repeated encoding
|
340 |
+
outputs = self.sampling(batch)
|
341 |
+
|
342 |
+
return outputs
|
343 |
+
|
344 |
+
def sampling(self, data):
|
345 |
+
"""Filtering out positive samples and selecting some samples randomly as negative samples.
|
346 |
+
|
347 |
+
Args:
|
348 |
+
data: The triples used to be sampled.
|
349 |
+
|
350 |
+
Returns:
|
351 |
+
batch_data: The training data.
|
352 |
+
"""
|
353 |
+
batch_data = {}
|
354 |
+
neg_ent_sample = []
|
355 |
+
|
356 |
+
self.cross_sampling_flag = 1 - self.cross_sampling_flag
|
357 |
+
|
358 |
+
head_list = []
|
359 |
+
rel_list = []
|
360 |
+
tail_list = []
|
361 |
+
# pdb.set_trace()
|
362 |
+
if self.cross_sampling_flag == 0:
|
363 |
+
batch_data['mode'] = "head-batch"
|
364 |
+
for index, (head, relation, tail) in enumerate(data):
|
365 |
+
# in batch negative
|
366 |
+
neg_head = self.find_neghead(data, index, relation, tail)
|
367 |
+
neg_ent_sample.extend(random.sample(neg_head, self.neg_num))
|
368 |
+
head_list.append(head)
|
369 |
+
rel_list.append(relation)
|
370 |
+
tail_list.append(tail)
|
371 |
+
else:
|
372 |
+
batch_data['mode'] = "tail-batch"
|
373 |
+
for index, (head, relation, tail) in enumerate(data):
|
374 |
+
neg_tail = self.find_negtail(data, index, relation, head)
|
375 |
+
neg_ent_sample.extend(random.sample(neg_tail, self.neg_num))
|
376 |
+
|
377 |
+
head_list.append(head)
|
378 |
+
rel_list.append(relation)
|
379 |
+
tail_list.append(tail)
|
380 |
+
|
381 |
+
neg_ent_batch = self.batch_tokenizer(neg_ent_sample)
|
382 |
+
head_batch = self.batch_tokenizer(head_list)
|
383 |
+
rel_batch = self.batch_tokenizer(rel_list)
|
384 |
+
tail_batch = self.batch_tokenizer(tail_list)
|
385 |
+
|
386 |
+
ent_list = head_list + rel_list + tail_list
|
387 |
+
ent_dict = {k: v for v, k in enumerate(ent_list)}
|
388 |
+
# used to index the negative samples
|
389 |
+
neg_index = torch.tensor([ent_dict[i] for i in neg_ent_sample])
|
390 |
+
# pos_head_index = torch.tensor(list(range(len(head_list)))
|
391 |
+
|
392 |
+
batch_data["positive_sample"] = (head_batch, rel_batch, tail_batch)
|
393 |
+
batch_data['negative_sample'] = neg_ent_batch
|
394 |
+
batch_data['neg_index'] = neg_index
|
395 |
+
return batch_data
|
396 |
+
|
397 |
+
def batch_tokenizer(self, input_list):
|
398 |
+
return self.tokenizer.batch_encode_plus(
|
399 |
+
input_list,
|
400 |
+
padding='max_length',
|
401 |
+
max_length=self.text_maxlength,
|
402 |
+
truncation=True,
|
403 |
+
return_tensors="pt",
|
404 |
+
return_token_type_ids=False,
|
405 |
+
return_attention_mask=True,
|
406 |
+
add_special_tokens=False
|
407 |
+
)
|
408 |
+
|
409 |
+
def find_neghead(self, data, index, rel, ta):
|
410 |
+
head_list = []
|
411 |
+
for i, (head, relation, tail) in enumerate(data):
|
412 |
+
# the negative sample must not be contained in the data
|
413 |
+
if i != index and [head, rel, ta] not in self.data:
|
414 |
+
head_list.append(head)
|
415 |
+
# there may not be enough negative samples
|
416 |
+
# pad by resampling from itself
|
417 |
+
while len(head_list) < self.neg_num:
|
418 |
+
head_list.extend(random.sample(head_list, min(self.neg_num - len(head_list), len(head_list))))
|
419 |
+
|
420 |
+
return head_list
|
421 |
+
|
422 |
+
def find_negtail(self, data, index, rel, he):
|
423 |
+
tail_list = []
|
424 |
+
for i, (head, relation, tail) in enumerate(data):
|
425 |
+
if i != index and [he, rel, tail] not in self.data:
|
426 |
+
tail_list.append(tail)
|
427 |
+
# there may not be enough negative samples
|
428 |
+
# pad by resampling from itself
|
429 |
+
while len(tail_list) < self.neg_num:
|
430 |
+
tail_list.extend(random.sample(tail_list, min(self.neg_num - len(tail_list), len(tail_list))))
|
431 |
+
return tail_list
|
432 |
+
|
433 |
+
# load the data for the mask-loss part
|
434 |
+
|
435 |
+
|
436 |
+
def load_data(logger, args):
|
437 |
+
|
438 |
+
data_path = args.data_path
|
439 |
+
|
440 |
+
data_name = args.seq_data_name
|
441 |
+
with open(osp.join(data_path, f'{data_name}_cws.json'), "r") as fp:
|
442 |
+
data = json.load(fp)
|
443 |
+
if args.rank == 0:
|
444 |
+
logger.info(f"[Start] Loading Seq dataset: [{len(data)}]...")
|
445 |
+
random.shuffle(data)
|
446 |
+
|
447 |
+
# data = data[:10000]
|
448 |
+
# pdb.set_trace()
|
449 |
+
train_test_split = int(args.train_ratio * len(data))
|
450 |
+
# random.shuffle(x)
|
451 |
+
# should not be shuffled during training/testing
|
452 |
+
train_data = data[0: train_test_split]
|
453 |
+
test_data = data[train_test_split: len(data)]
|
454 |
+
|
455 |
+
# may actually also be used at test time: not args.only_test
|
456 |
+
if args.use_mlm_task:
|
457 |
+
# if args.mask_stratege != 'rand':
|
458 |
+
# load the domain vocabulary
|
459 |
+
if args.rank == 0:
|
460 |
+
print("using the domain words .....")
|
461 |
+
domain_file_path = osp.join(args.data_path, f'{data_name}_chinese_ref.json')
|
462 |
+
with open(domain_file_path, 'r') as f:
|
463 |
+
chinese_ref = json.load(f)
|
464 |
+
# train_test_split=len(data)
|
465 |
+
chi_ref_train = chinese_ref[:train_test_split]
|
466 |
+
chi_ref_eval = chinese_ref[train_test_split:]
|
467 |
+
else:
|
468 |
+
chi_ref_train = None
|
469 |
+
chi_ref_eval = None
|
470 |
+
|
471 |
+
if args.use_NumEmb:
|
472 |
+
if args.rank == 0:
|
473 |
+
print("using the kpi and num .....")
|
474 |
+
|
475 |
+
kpi_file_path = osp.join(args.data_path, f'{data_name}_kpi_ref.json')
|
476 |
+
with open(kpi_file_path, 'r') as f:
|
477 |
+
kpi_ref = json.load(f)
|
478 |
+
kpi_ref_train = kpi_ref[:train_test_split]
|
479 |
+
kpi_ref_eval = kpi_ref[train_test_split:]
|
480 |
+
else:
|
481 |
+
# num_ref_train = None
|
482 |
+
# num_ref_eval = None
|
483 |
+
kpi_ref_train = None
|
484 |
+
kpi_ref_eval = None
|
485 |
+
|
486 |
+
# pdb.set_trace()
|
487 |
+
test_set = None
|
488 |
+
train_set = SeqDataset(train_data, chi_ref=chi_ref_train, kpi_ref=kpi_ref_train)
|
489 |
+
if len(test_data) > 0:
|
490 |
+
test_set = SeqDataset(test_data, chi_ref=chi_ref_eval, kpi_ref=kpi_ref_eval)
|
491 |
+
if args.rank == 0:
|
492 |
+
logger.info("[End] Loading Seq dataset...")
|
493 |
+
return train_set, test_set, train_test_split
|
494 |
+
|
495 |
+
# load the data for the triple-loss part
|
496 |
+
|
497 |
+
|
498 |
+
def load_data_kg(logger, args):
|
499 |
+
data_path = args.data_path
|
500 |
+
if args.rank == 0:
|
501 |
+
logger.info("[Start] Loading KG dataset...")
|
502 |
+
# # triples
|
503 |
+
# with open(osp.join(data_path, '5GC_KB/database_triples_831.json'), "r") as f:
|
504 |
+
# data = json.load(f)
|
505 |
+
# random.shuffle(data)
|
506 |
+
|
507 |
+
# # # TODO: the triple-loss part has no test set yet
|
508 |
+
# train_data = data[0:int(len(data)/args.batch_size)*args.batch_size]
|
509 |
+
|
510 |
+
# with open(osp.join(data_path, 'KG_data_tiny_831.json'),"w") as fp:
|
511 |
+
# json.dump(data[:1000], fp)
|
512 |
+
kg_data_name = args.kg_data_name
|
513 |
+
with open(osp.join(data_path, f'{kg_data_name}.json'), "r") as fp:
|
514 |
+
train_data = json.load(fp)
|
515 |
+
# pdb.set_trace()
|
516 |
+
# 124169
|
517 |
+
# 128482
|
518 |
+
# train_data = train_data[:124168]
|
519 |
+
# train_data = train_data[:1000]
|
520 |
+
train_set = KGDataset(train_data)
|
521 |
+
if args.rank == 0:
|
522 |
+
logger.info("[End] Loading KG dataset...")
|
523 |
+
return train_set, train_data
|
524 |
+
|
525 |
+
|
526 |
+
def _torch_collate_batch(examples, tokenizer, max_length=None, pad_to_multiple_of=None):
|
527 |
+
"""Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
|
528 |
+
import numpy as np
|
529 |
+
import torch
|
530 |
+
|
531 |
+
# Tensorize if necessary.
|
532 |
+
if isinstance(examples[0], (list, tuple, np.ndarray)):
|
533 |
+
examples = [torch.tensor(e, dtype=torch.long) for e in examples]
|
534 |
+
|
535 |
+
length_of_first = examples[0].size(0)
|
536 |
+
|
537 |
+
# Check if padding is necessary.
|
538 |
+
|
539 |
+
# are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
|
540 |
+
# if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
|
541 |
+
# return torch.stack(examples, dim=0)
|
542 |
+
|
543 |
+
# If yes, check if we have a `pad_token`.
|
544 |
+
if tokenizer._pad_token is None:
|
545 |
+
raise ValueError(
|
546 |
+
"You are attempting to pad samples but the tokenizer you are using"
|
547 |
+
f" ({tokenizer.__class__.__name__}) does not have a pad token."
|
548 |
+
)
|
549 |
+
|
550 |
+
# Creating the full tensor and filling it with our data.
|
551 |
+
|
552 |
+
if max_length is None:
|
553 |
+
pdb.set_trace()
|
554 |
+
max_length = max(x.size(0) for x in examples)
|
555 |
+
|
556 |
+
if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
557 |
+
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
558 |
+
result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
|
559 |
+
for i, example in enumerate(examples):
|
560 |
+
if tokenizer.padding_side == "right":
|
561 |
+
result[i, : example.shape[0]] = example
|
562 |
+
else:
|
563 |
+
result[i, -example.shape[0]:] = example
|
564 |
+
|
565 |
+
return result
|
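A quick sketch of the right-padding performed above, with a hypothetical pad id of 0 and max_length=6.

import torch

rows = [torch.tensor([1, 0, 1]), torch.tensor([0, 1])]   # ragged mask-label rows
padded = rows[0].new_full((len(rows), 6), 0)              # fill the batch tensor with the pad id
for i, row in enumerate(rows):
    padded[i, : row.shape[0]] = row                       # padding_side == "right"
# padded: [[1, 0, 1, 0, 0, 0], [0, 1, 0, 0, 0, 0]]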
566 |
+
|
567 |
+
|
568 |
+
def load_order_data(logger, args):
|
569 |
+
if args.rank == 0:
|
570 |
+
logger.info("[Start] Loading Order dataset...")
|
571 |
+
|
572 |
+
data_path = args.data_path
|
573 |
+
if len(args.order_test_name) > 0:
|
574 |
+
data_name = args.order_test_name
|
575 |
+
else:
|
576 |
+
data_name = args.order_data_name
|
577 |
+
tmp = osp.join(data_path, f'{data_name}.json')
|
578 |
+
if osp.exists(tmp):
|
579 |
+
dp = tmp
|
580 |
+
else:
|
581 |
+
dp = osp.join(data_path, 'downstream_task', f'{data_name}.json')
|
582 |
+
assert osp.exists(dp)
|
583 |
+
with open(dp, "r") as fp:
|
584 |
+
data = json.load(fp)
|
585 |
+
# data = data[:2000]
|
586 |
+
# pdb.set_trace()
|
587 |
+
train_test_split = int(args.train_ratio * len(data))
|
588 |
+
|
589 |
+
mid_split = int(train_test_split / 2)
|
590 |
+
mid = int(len(data) / 2)
|
591 |
+
# random.shuffle(x)
|
592 |
+
# should not be shuffled during training/testing
|
593 |
+
# train_data = data[0: train_test_split]
|
594 |
+
# test_data = data[train_test_split: len(data)]
|
595 |
+
|
596 |
+
# test_data = data[0: train_test_split]
|
597 |
+
# train_data = data[train_test_split: len(data)]
|
598 |
+
|
599 |
+
# special split: assume the first half and the second half are symmetric
|
600 |
+
test_data = data[0: mid_split] + data[mid: mid + mid_split]
|
601 |
+
train_data = data[mid_split: mid] + data[mid + mid_split: len(data)]
|
602 |
+
|
603 |
+
# pdb.set_trace()
|
604 |
+
test_set = None
|
605 |
+
train_set = OrderDataset(train_data)
|
606 |
+
if len(test_data) > 0:
|
607 |
+
test_set = OrderDataset(test_data)
|
608 |
+
if args.rank == 0:
|
609 |
+
logger.info("[End] Loading Order dataset...")
|
610 |
+
return train_set, test_set, train_test_split
|
611 |
+
|
612 |
+
|
613 |
+
class Collator_order(object):
|
614 |
+
# takes a batch of data; the orders are merged here and decoupled later
|
615 |
+
def __init__(self, args, tokenizer):
|
616 |
+
self.tokenizer = tokenizer
|
617 |
+
self.text_maxlength = args.maxlength
|
618 |
+
self.args = args
|
619 |
+
# number of samples contained in each pair
|
620 |
+
self.order_num = args.order_num
|
621 |
+
self.p_label, self.n_label = smooth_BCE(args.eps)
|
622 |
+
|
623 |
+
def __call__(self, batch):
|
624 |
+
# input data is stacked in order and split at intervals
|
625 |
+
#
|
626 |
+
# encode and return
|
627 |
+
output = []
|
628 |
+
for item in range(self.order_num):
|
629 |
+
output.extend([dat[0][0][item] for dat in batch])
|
630 |
+
# label smoothing
|
631 |
+
|
632 |
+
labels = [1 if dat[0][1][0] == 2 else self.p_label if dat[0][1][0] == 1 else self.n_label for dat in batch]
|
633 |
+
batch = self.tokenizer.batch_encode_plus(
|
634 |
+
output,
|
635 |
+
padding='max_length',
|
636 |
+
max_length=self.text_maxlength,
|
637 |
+
truncation=True,
|
638 |
+
return_tensors="pt",
|
639 |
+
return_token_type_ids=False,
|
640 |
+
return_attention_mask=True,
|
641 |
+
add_special_tokens=False
|
642 |
+
)
|
643 |
+
# torch.tensor()
|
644 |
+
return batch, torch.FloatTensor(labels)
|
645 |
+
|
646 |
+
|
647 |
+
def smooth_BCE(eps=0.1):  # eps is the smoothing factor: [0, 1] => [0.95, 0.05]
|
648 |
+
# return positive, negative label smoothing BCE targets
|
649 |
+
# positive label= y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing
|
650 |
+
# y_true=1 label_smoothing=eps=0.1
|
651 |
+
return 1.0 - 0.5 * eps, 0.5 * eps
|
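For reference, a minimal sketch of how the smoothed targets produced by smooth_BCE would feed a BCE loss; the pair scores below are made up for illustration.

import torch
import torch.nn.functional as F

eps = 0.1
p_label, n_label = 1.0 - 0.5 * eps, 0.5 * eps   # same formula as smooth_BCE -> (0.95, 0.05)
logits = torch.tensor([2.0, -1.5])              # hypothetical scores for a positive/negative pair
targets = torch.tensor([p_label, n_label])      # smoothed targets instead of hard 1/0
loss = F.binary_cross_entropy_with_logits(logits, targets)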
KTeleBERT/src/distributed_utils.py
ADDED
@@ -0,0 +1,79 @@
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.distributed as dist
|
5 |
+
import pdb
|
6 |
+
|
7 |
+
|
8 |
+
def dist_pdb(rank, in_rank=0):
|
9 |
+
if rank != in_rank:
|
10 |
+
dist.barrier()
|
11 |
+
else:
|
12 |
+
pdb.set_trace()
|
13 |
+
dist.barrier()
|
14 |
+
|
15 |
+
|
16 |
+
def init_distributed_mode(args):
|
17 |
+
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
18 |
+
args.rank = int(os.environ["RANK"])
|
19 |
+
args.world_size = int(os.environ['WORLD_SIZE'])
|
20 |
+
args.gpu = int(os.environ['LOCAL_RANK'])
|
21 |
+
elif 'SLURM_PROCID' in os.environ:
|
22 |
+
args.rank = int(os.environ['SLURM_PROCID'])
|
23 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
24 |
+
else:
|
25 |
+
print('Not using distributed mode')
|
26 |
+
args.distributed = False
|
27 |
+
return
|
28 |
+
|
29 |
+
args.distributed = True
|
30 |
+
|
31 |
+
torch.cuda.set_device(args.gpu)
|
32 |
+
args.dist_backend = 'nccl'  # communication backend; NCCL is recommended for NVIDIA GPUs
|
33 |
+
print('| distributed init (rank {}): {}'.format(
|
34 |
+
args.rank, args.dist_url), flush=True)
|
35 |
+
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
36 |
+
world_size=args.world_size, rank=args.rank)
|
37 |
+
dist.barrier()
|
38 |
+
|
39 |
+
|
40 |
+
def cleanup():
|
41 |
+
dist.destroy_process_group()
|
42 |
+
|
43 |
+
|
44 |
+
def is_dist_avail_and_initialized():
|
45 |
+
"""Check whether the distributed environment is available and initialized."""
|
46 |
+
if not dist.is_available():
|
47 |
+
return False
|
48 |
+
if not dist.is_initialized():
|
49 |
+
return False
|
50 |
+
return True
|
51 |
+
|
52 |
+
|
53 |
+
def get_world_size():
|
54 |
+
if not is_dist_avail_and_initialized():
|
55 |
+
return 1
|
56 |
+
return dist.get_world_size()
|
57 |
+
|
58 |
+
|
59 |
+
def get_rank():
|
60 |
+
if not is_dist_avail_and_initialized():
|
61 |
+
return 0
|
62 |
+
return dist.get_rank()
|
63 |
+
|
64 |
+
|
65 |
+
def is_main_process():
|
66 |
+
return get_rank() == 0
|
67 |
+
|
68 |
+
|
69 |
+
def reduce_value(value, average=True):
|
70 |
+
world_size = get_world_size()
|
71 |
+
if world_size < 2:  # single-GPU case
|
72 |
+
return value
|
73 |
+
|
74 |
+
with torch.no_grad():
|
75 |
+
dist.all_reduce(value)
|
76 |
+
if average:
|
77 |
+
value /= world_size
|
78 |
+
|
79 |
+
return value
|
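A typical usage pattern for the helpers above (a sketch; the import path src.distributed_utils is an assumption based on this repository layout): average a scalar metric across ranks before logging, which is a no-op in single-process runs.

import torch
from src.distributed_utils import reduce_value  # assumed import path

loss = torch.tensor(0.73)            # per-rank scalar (move to GPU when using the NCCL backend)
avg_loss = reduce_value(loss, average=True)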
KTeleBERT/src/utils.py
ADDED
@@ -0,0 +1,374 @@
|
1 |
+
|
2 |
+
import os
|
3 |
+
import errno
|
4 |
+
import torch
|
5 |
+
import sys
|
6 |
+
import logging
|
7 |
+
import json
|
8 |
+
from pathlib import Path
|
9 |
+
import torch.distributed as dist
|
10 |
+
import csv
|
11 |
+
import os.path as osp
|
12 |
+
from time import time
|
13 |
+
from numpy import mean
|
14 |
+
import re
|
15 |
+
from transformers import AdamW, get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
|
16 |
+
import pdb
|
17 |
+
from torch import nn
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
# Huggingface's implementation ships with several warmup strategies
|
22 |
+
def set_optim(opt, model_list, freeze_part=[], accumulation_step=None):
|
23 |
+
# Bert optim
|
24 |
+
optimizer_list, scheduler_list, named_parameters = [], [], []
|
25 |
+
# cur_model = model.module if hasattr(model, 'module') else model
|
26 |
+
for model in model_list:
|
27 |
+
model_para = list(model.named_parameters())
|
28 |
+
model_para_train, freeze_layer = [], []
|
29 |
+
for n, p in model_para:
|
30 |
+
if not any(nd in n for nd in freeze_part):
|
31 |
+
model_para_train.append((n, p))
|
32 |
+
else:
|
33 |
+
p.requires_grad = False
|
34 |
+
freeze_layer.append((n, p))
|
35 |
+
named_parameters.extend(model_para_train)
|
36 |
+
|
37 |
+
# for name, param in model_list[0].named_parameters():
|
38 |
+
# if not param.requires_grad:
|
39 |
+
# print(name, param.size())
|
40 |
+
|
41 |
+
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
|
42 |
+
# numeric_model is also included in this part
|
43 |
+
ke_part = ['ke_model', 'loss_awl', 'numeric_model', 'order']
|
44 |
+
if opt.LLRD:
|
45 |
+
# layer-wise decayed learning rate
|
46 |
+
all_name_orig = [n for n, p in named_parameters if not any(nd in n for nd in ke_part)]
|
47 |
+
|
48 |
+
opt_parameters, all_name = LLRD(opt, named_parameters, no_decay, ke_part)
|
49 |
+
remain = list(set(all_name_orig) - set(all_name))
|
50 |
+
remain_parameters = [
|
51 |
+
{'params': [p for n, p in named_parameters if not any(nd in n for nd in no_decay) and n in remain], "lr": opt.lr, 'weight_decay': opt.weight_decay},
|
52 |
+
{'params': [p for n, p in named_parameters if any(nd in n for nd in no_decay) and n in remain], "lr": opt.lr, 'weight_decay': 0.0}
|
53 |
+
]
|
54 |
+
opt_parameters.extend(remain_parameters)
|
55 |
+
else:
|
56 |
+
opt_parameters = [
|
57 |
+
{'params': [p for n, p in named_parameters if not any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)], "lr": opt.lr, 'weight_decay': opt.weight_decay},
|
58 |
+
{'params': [p for n, p in named_parameters if any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)], "lr": opt.lr, 'weight_decay': 0.0}
|
59 |
+
]
|
60 |
+
|
61 |
+
ke_parameters = [
|
62 |
+
{'params': [p for n, p in named_parameters if not any(nd in n for nd in no_decay) and any(nd in n for nd in ke_part)], "lr": opt.ke_lr, 'weight_decay': opt.weight_decay},
|
63 |
+
{'params': [p for n, p in named_parameters if any(nd in n for nd in no_decay) and any(nd in n for nd in ke_part)], "lr": opt.ke_lr, 'weight_decay': 0.0}
|
64 |
+
]
|
65 |
+
opt_parameters.extend(ke_parameters)
|
66 |
+
optimizer = AdamW(opt_parameters, lr=opt.lr, eps=opt.adam_epsilon)
|
67 |
+
if accumulation_step is None:
|
68 |
+
accumulation_step = opt.accumulation_steps
|
69 |
+
if opt.scheduler == 'linear':
|
70 |
+
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(opt.warmup_steps/accumulation_step), num_training_steps=int(opt.total_steps/accumulation_step))
|
71 |
+
else:
|
72 |
+
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=int(opt.warmup_steps/accumulation_step), num_training_steps=int(opt.total_steps/accumulation_step))
|
73 |
+
|
74 |
+
# ---- check that every parameter is covered by the optimizer ----
|
75 |
+
all_para_num = 0
|
76 |
+
for paras in opt_parameters:
|
77 |
+
all_para_num += len(paras['params'])
|
78 |
+
# pdb.set_trace()
|
79 |
+
assert len(named_parameters) == all_para_num
|
80 |
+
return optimizer, scheduler
|
81 |
+
|
82 |
+
# LLRD: the learning rate decays layer by layer
|
83 |
+
|
84 |
+
def LLRD(opt, named_parameters, no_decay, ke_part=[]):
|
85 |
+
opt_parameters = []
|
86 |
+
all_name = []
|
87 |
+
head_lr = opt.lr * 1.05
|
88 |
+
init_lr = opt.lr
|
89 |
+
lr = init_lr
|
90 |
+
|
91 |
+
# === Pooler and regressor ======================================================
|
92 |
+
params_0 = [p for n,p in named_parameters if ("pooler" in n or "regressor" in n or "predictions" in n)
|
93 |
+
and any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
|
94 |
+
params_1 = [p for n,p in named_parameters if ("pooler" in n or "regressor" in n or "predictions" in n)
|
95 |
+
and not any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
|
96 |
+
|
97 |
+
name_0 = [n for n,p in named_parameters if ("pooler" in n or "regressor" in n or "predictions" in n)
|
98 |
+
and any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
|
99 |
+
name_1 = [n for n,p in named_parameters if ("pooler" in n or "regressor" in n or "predictions" in n)
|
100 |
+
and not any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
|
101 |
+
|
102 |
+
all_name.extend(name_0)
|
103 |
+
all_name.extend(name_1)
|
104 |
+
|
105 |
+
head_params = {"params": params_0, "lr": head_lr, "weight_decay": 0.0}
|
106 |
+
opt_parameters.append(head_params)
|
107 |
+
|
108 |
+
head_params = {"params": params_1, "lr": head_lr, "weight_decay": 0.01}
|
109 |
+
opt_parameters.append(head_params)
|
110 |
+
|
111 |
+
# === 12 Hidden layers ==========================================================
|
112 |
+
for layer in range(11,-1,-1):
|
113 |
+
params_0 = [p for n,p in named_parameters if f"encoder.layer.{layer}." in n
|
114 |
+
and any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
|
115 |
+
params_1 = [p for n,p in named_parameters if f"encoder.layer.{layer}." in n
|
116 |
+
and not any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
|
117 |
+
|
118 |
+
layer_params = {"params": params_0, "lr": lr, "weight_decay": 0.0}
|
119 |
+
opt_parameters.append(layer_params)
|
120 |
+
|
121 |
+
layer_params = {"params": params_1, "lr": lr, "weight_decay": 0.01}
|
122 |
+
opt_parameters.append(layer_params)
|
123 |
+
|
124 |
+
name_0 = [n for n,p in named_parameters if f"encoder.layer.{layer}." in n
|
125 |
+
and any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
|
126 |
+
name_1 = [n for n,p in named_parameters if f"encoder.layer.{layer}." in n
|
127 |
+
and not any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
|
128 |
+
all_name.extend(name_0)
|
129 |
+
all_name.extend(name_1)
|
130 |
+
|
131 |
+
lr *= 0.95
|
132 |
+
# === Embeddings layer ==========================================================
|
133 |
+
|
134 |
+
params_0 = [p for n,p in named_parameters if ("embeddings" in n )
|
135 |
+
and any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
|
136 |
+
params_1 = [p for n,p in named_parameters if ("embeddings" in n )
|
137 |
+
and not any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
|
138 |
+
|
139 |
+
embed_params = {"params": params_0, "lr": lr, "weight_decay": 0.0}
|
140 |
+
opt_parameters.append(embed_params)
|
141 |
+
|
142 |
+
embed_params = {"params": params_1, "lr": lr, "weight_decay": 0.01}
|
143 |
+
opt_parameters.append(embed_params)
|
144 |
+
|
145 |
+
name_0 = [n for n,p in named_parameters if ("embeddings" in n )
|
146 |
+
and any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
|
147 |
+
name_1 = [n for n,p in named_parameters if ("embeddings" in n )
|
148 |
+
and not any(nd in n for nd in no_decay) and not any(nd in n for nd in ke_part)]
|
149 |
+
all_name.extend(name_0)
|
150 |
+
all_name.extend(name_1)
|
151 |
+
return opt_parameters, all_name
|
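To make the schedule above concrete: the head gets 1.05 * lr, encoder layer 11 gets lr, each lower layer is scaled by a further 0.95, and the embeddings end up with the smallest rate. A tiny sketch of the resulting values (the base lr is an arbitrary example):

lr = 5e-5
head_lr = lr * 1.05
layer_lrs = {layer: lr * 0.95 ** (11 - layer) for layer in range(11, -1, -1)}
embedding_lr = lr * 0.95 ** 12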
152 |
+
|
153 |
+
class FixedScheduler(torch.optim.lr_scheduler.LambdaLR):
|
154 |
+
def __init__(self, optimizer, last_epoch=-1):
|
155 |
+
super(FixedScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
|
156 |
+
|
157 |
+
def lr_lambda(self, step):
|
158 |
+
return 1.0
|
159 |
+
|
160 |
+
|
161 |
+
class WarmupLinearScheduler(torch.optim.lr_scheduler.LambdaLR):
|
162 |
+
def __init__(self, optimizer, warmup_steps, scheduler_steps, min_ratio, last_epoch=-1):
|
163 |
+
self.warmup_steps = warmup_steps
|
164 |
+
self.scheduler_steps = scheduler_steps
|
165 |
+
self.min_ratio = min_ratio
|
166 |
+
# self.fixed_lr = fixed_lr
|
167 |
+
super(WarmupLinearScheduler, self).__init__(
|
168 |
+
optimizer, self.lr_lambda, last_epoch=last_epoch
|
169 |
+
)
|
170 |
+
|
171 |
+
def lr_lambda(self, step):
|
172 |
+
if step < self.warmup_steps:
|
173 |
+
return (1 - self.min_ratio) * step / float(max(1, self.warmup_steps)) + self.min_ratio
|
174 |
+
|
175 |
+
# if self.fixed_lr:
|
176 |
+
# return 1.0
|
177 |
+
|
178 |
+
return max(0.0,
|
179 |
+
1.0 + (self.min_ratio - 1) * (step - self.warmup_steps) / float(max(1.0, self.scheduler_steps - self.warmup_steps)),
|
180 |
+
)
|
181 |
+
|
182 |
+
|
183 |
+
class Loss_log():
|
184 |
+
def __init__(self):
|
185 |
+
self.loss = []
|
186 |
+
self.acc = [0.]
|
187 |
+
self.flag = 0
|
188 |
+
self.token_right_num = []
|
189 |
+
self.token_all_num = []
|
190 |
+
self.word_right_num = []
|
191 |
+
self.word_all_num = []
|
192 |
+
# top-k acc is disabled by default
|
193 |
+
self.use_top_k_acc = 0
|
194 |
+
|
195 |
+
def acc_init(self, topn=[1]):
|
196 |
+
self.loss = []
|
197 |
+
self.token_right_num = []
|
198 |
+
self.token_all_num = []
|
199 |
+
self.topn = topn
|
200 |
+
self.use_top_k_acc = 1
|
201 |
+
self.top_k_word_right = {}
|
202 |
+
for n in topn:
|
203 |
+
self.top_k_word_right[n] = []
|
204 |
+
|
205 |
+
def time_init(self):
|
206 |
+
self.start = time()
|
207 |
+
self.last = self.start
|
208 |
+
self.time_used_epoch = []
|
209 |
+
|
210 |
+
def time_cpt(self, step, total_step):
|
211 |
+
# timing statistics
|
212 |
+
time_used_last_epoch = time() - self.last
|
213 |
+
self.time_used_epoch.append(time_used_last_epoch)
|
214 |
+
time_used = time() - self.start
|
215 |
+
self.last = time()
|
216 |
+
h, m, s = time_trans(time_used)
|
217 |
+
time_remain = int(total_step - step) * mean(self.time_used_epoch)
|
218 |
+
h_r, m_r, s_r = time_trans(time_remain)
|
219 |
+
|
220 |
+
return h, m, s, h_r, m_r, s_r
|
221 |
+
|
222 |
+
def get_token_acc(self):
|
223 |
+
# returns a list
|
224 |
+
if len(self.token_all_num) == 0:
|
225 |
+
return 0.
|
226 |
+
elif self.use_top_k_acc == 1:
|
227 |
+
res = []
|
228 |
+
for n in self.topn:
|
229 |
+
res.append(round((sum(self.top_k_word_right[n]) / sum(self.token_all_num)) * 100 , 3))
|
230 |
+
return res
|
231 |
+
else:
|
232 |
+
return [sum(self.token_right_num)/sum(self.token_all_num)]
|
233 |
+
|
234 |
+
|
235 |
+
def update_token(self, token_num, token_right):
|
236 |
+
# the inputs are lists
|
237 |
+
self.token_all_num.append(token_num)
|
238 |
+
if isinstance(token_right, list):
|
239 |
+
for i, n in enumerate(self.topn):
|
240 |
+
self.top_k_word_right[n].append(token_right[i])
|
241 |
+
self.token_right_num.append(token_right)
|
242 |
+
|
243 |
+
def update(self, case):
|
244 |
+
self.loss.append(case)
|
245 |
+
|
246 |
+
def update_acc(self, case):
|
247 |
+
self.acc.append(case)
|
248 |
+
|
249 |
+
def get_loss(self):
|
250 |
+
if len(self.loss) == 0:
|
251 |
+
return 500.
|
252 |
+
return mean(self.loss)
|
253 |
+
|
254 |
+
def get_acc(self):
|
255 |
+
return self.acc[-1]
|
256 |
+
|
257 |
+
def get_min_loss(self):
|
258 |
+
return min(self.loss)
|
259 |
+
|
260 |
+
def early_stop(self):
|
261 |
+
# min_loss = min(self.loss)
|
262 |
+
if self.loss[-1] > min(self.loss):
|
263 |
+
self.flag += 1
|
264 |
+
else:
|
265 |
+
self.flag = 0
|
266 |
+
|
267 |
+
if self.flag > 1000:
|
268 |
+
return True
|
269 |
+
else:
|
270 |
+
return False
|
271 |
+
|
272 |
+
|
273 |
+
def add_special_token(tokenizer, model=None, rank=0, cache_path = None):
|
274 |
+
# model: bert layer
|
275 |
+
# whenever this is updated, all models must be retrained and get_chinese_ref.py must be rerun
|
276 |
+
# the main function must call this before loading the model
|
277 |
+
# ---------------------------------------
|
278 |
+
# tokens that will never be masked, excluded from MASK at all times
|
279 |
+
special_token = ['[SEP]', '[MASK]', '[ALM]', '[KPI]', '[CLS]', '[LOC]', '[EOS]', '[ENT]', '[ATTR]', '[NUM]', '[REL]', '|', '[DOC]']
|
280 |
+
|
281 |
+
# ---------------------------------------
|
282 |
+
# tokens that can be masked but are ---#NOT added#--- to the tokenizer
|
283 |
+
# frequent (>10000 occurrences) but relatively long (>=4 characters)
|
284 |
+
# or some hard-to-interpret terms
|
285 |
+
# the word units used by WWM
|
286 |
+
# TODO: expert review
|
287 |
+
# To Add: 'SGSN', '3GPP', 'Bearer', 'sbim', 'FusionSphere', 'IMSI', 'GGSN', 'RETCODE', 'PCRF', 'PDP', 'GTP', 'OCS', 'HLR', 'FFFF', 'VLR', 'DNN', 'PID', 'CSCF', 'PDN', 'SCTP', 'SPGW', 'TAU', 'PCEF', 'NSA', 'ACL', 'BGP', 'USCDB', 'VoLTE', 'RNC', 'GPRS', 'DRA', 'MOC'
|
288 |
+
# to be split: 配置原则, 本端规划
|
289 |
+
norm_token = ['网元实例', '事件类型', '告警级别', '告警名称', '告警源', '通讯系统', '默认值', '链路故障', '取值范围', '可选必选说明', '数据来源', '用户平面', '配置', '原则', '该参数', '失败次数', '可选参数', 'S1模式', '必选参数', 'IP地址', '响应消息', '成功次数', '测量指标', '用于', '统计周期', '该命令', '上下文', '请求次数', '本端', 'pod', 'amf', 'smf', 'nrf', 'ausf', 'upcf', 'upf', 'udm', 'PDU', 'alias', 'PLMN', 'MML', 'Info_Measure', 'icase', 'Diameter', 'MSISDN', 'RAT', 'RMV', 'PFCP', 'NSSAI', 'CCR', 'HDBNJjs', 'HNGZgd', 'SGSN', '3GPP', 'Bearer', 'sbim', 'FusionSphere', 'IMSI', 'GGSN', 'RETCODE', 'PCRF', 'PDP', 'GTP', 'OCS', 'HLR', 'FFFF', 'VLR', 'DNN', 'PID', 'CSCF', 'PDN', 'SCTP', 'SPGW', 'TAU', 'PCEF', 'NSA', 'ACL', 'BGP', 'USCDB', 'VoLTE', 'RNC', 'GPRS', 'DRA', 'MOC', '告警', '网元', '对端', '信令', '话单', '操作', '风险', '等级', '下发', '流控', '运营商', '寻呼', '漫游', '切片', '报文', '号段', '承载', '批量', '导致', '原因是', '影响', '造成', '引起', '随之', '情况下', '根因', 'trigger']
|
290 |
+
# ---------------------------------------
|
291 |
+
# , '', '', '', '', '', '', '', '', '', '', ''
|
292 |
+
# tokens that can be masked and ---#ARE added#--- to the tokenizer
|
293 |
+
# length <= 3, abbreviations/proper nouns with more than 10000 occurrences
|
294 |
+
# stricter requirements than norm_token
|
295 |
+
# frequent enough to have the influence needed to be split out as separate tokens
|
296 |
+
norm_token_tobe_added = ['pod', 'amf', 'smf', 'nrf', 'ausf', 'upcf', 'upf', 'udm', 'ALM', '告警', '网元', '对端', '信令', '话单', 'RAN', 'MML', 'PGW', 'MME', 'SGW', 'NF', 'APN', 'LST', 'GW', 'QoS', 'IPv', 'PDU', 'IMS', 'EPS', 'GTP', 'PDP', 'LTE', 'HSS']
|
297 |
+
|
298 |
+
token_tobe_added = []
|
299 |
+
# all_token = special_token + norm_token_tobe_added
|
300 |
+
all_token = norm_token_tobe_added
|
301 |
+
for i in all_token:
|
302 |
+
if i not in tokenizer.vocab.keys() and i.lower() not in tokenizer.vocab.keys():
|
303 |
+
token_tobe_added.append(i)
|
304 |
+
|
305 |
+
# tokenizer.add_tokens(special_token, special_tokens=False)
|
306 |
+
# tokenizer.add_tokens(norm_token, special_tokens=False)
|
307 |
+
tokenizer.add_tokens(token_tobe_added, special_tokens=False)
|
308 |
+
special_tokens_dict = {"additional_special_tokens": special_token}
|
309 |
+
special_token_ = tokenizer.add_special_tokens(special_tokens_dict)
|
310 |
+
if rank == 0:
|
311 |
+
print("Added tokens:")
|
312 |
+
print(tokenizer.get_added_vocab())
|
313 |
+
|
314 |
+
# pdb.set_trace()
|
315 |
+
|
316 |
+
if model is not None:
|
317 |
+
# TODO: initialize these embeddings (all added embeddings) with the pretrained TeleBert
|
318 |
+
if rank == 0:
|
319 |
+
print(f"--------------------------------")
|
320 |
+
print(f"-------- orig word embedding shape: {model.get_input_embeddings().weight.shape}")
|
321 |
+
sz = model.resize_token_embeddings(len(tokenizer))
|
322 |
+
if cache_path is not None:
|
323 |
+
# model.cpu()
|
324 |
+
token_2_emb = torch.load(cache_path)
|
325 |
+
# after initializing the added embeddings here, the weights need to be tied again
|
326 |
+
token_dic = tokenizer.get_added_vocab()
|
327 |
+
id_2_token = {v:k for k,v in token_dic.items()}
|
328 |
+
with torch.no_grad():
|
329 |
+
for key in id_2_token.keys():
|
330 |
+
model.bert.embeddings.word_embeddings.weight[key,:] = nn.Parameter(token_2_emb[id_2_token[key]][0]).cuda()
|
331 |
+
# model.get_input_embeddings().weight[key,:] = nn.Parameter(token_2_emb[id_2_token[key]][0]).cuda()
|
332 |
+
# model.embedding
|
333 |
+
model.bert.tie_weights()
|
334 |
+
if rank == 0:
|
335 |
+
print(f"-------- resize_token_embeddings into {sz} done!")
|
336 |
+
print(f"--------------------------------")
|
337 |
+
# the embeddings are replaced here
|
338 |
+
|
339 |
+
norm_token = list(set(norm_token).union(set(norm_token_tobe_added)))
|
340 |
+
return tokenizer, special_token, norm_token
|
341 |
+
|
342 |
+
|
343 |
+
def time_trans(sec):
|
344 |
+
m, s = divmod(sec, 60)
|
345 |
+
h, m = divmod(m, 60)
|
346 |
+
return int(h), int(m), int(s)
|
347 |
+
|
348 |
+
def torch_accuracy(output, target, topk=(1,)):
|
349 |
+
'''
|
350 |
+
param output, target: should be torch Variable
|
351 |
+
'''
|
352 |
+
# assert isinstance(output, torch.cuda.Tensor), 'expecting Torch Tensor'
|
353 |
+
# assert isinstance(target, torch.Tensor), 'expecting Torch Tensor'
|
354 |
+
# print(type(output))
|
355 |
+
|
356 |
+
topn = max(topk)
|
357 |
+
batch_size = output.size(0)
|
358 |
+
|
359 |
+
_, pred = output.topk(topn, 1, True, True)  # returns (values, indices); the indices are the predicted classes, 0 being the first class
|
360 |
+
pred = pred.t()  # torch.t() transposes, so each row holds the i-th best prediction for every sample in the batch
|
361 |
+
|
362 |
+
is_correct = pred.eq(target.view(1, -1).expand_as(pred))
|
363 |
+
|
364 |
+
ans = []
|
365 |
+
ans_num = []
|
366 |
+
for i in topk:
|
367 |
+
# is_correct_i = is_correct[:i].view(-1).float().sum(0, keepdim=True)
|
368 |
+
is_correct_i = is_correct[:i].contiguous().view(-1).float().sum(0, keepdim=True)
|
369 |
+
ans_num.append(int(is_correct_i.item()))
|
370 |
+
ans.append(is_correct_i.mul_(100.0 / batch_size))
|
371 |
+
|
372 |
+
return ans, ans_num
|
373 |
+
|
374 |
+
|
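A small check of torch_accuracy (a sketch; the import path src.utils is an assumption): two samples, where only the first is correct at top-1 and the second becomes correct at top-2.

import torch
from src.utils import torch_accuracy  # assumed import path

output = torch.tensor([[0.1, 0.9, 0.0],
                       [0.5, 0.2, 0.3]])
target = torch.tensor([1, 2])
ans, ans_num = torch_accuracy(output, target, topk=(1, 2))
# ans_num == [1, 2]; ans holds the corresponding percentages (50%, 100%)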
KTeleBERT/test.sh
ADDED
@@ -0,0 +1,12 @@
|
1 |
+
python main.py --only_test 1 \
|
2 |
+
--batch_size 150 \
|
3 |
+
--use_NumEmb 1 \
|
4 |
+
--mask_test 0 \
|
5 |
+
--mask_stratege wwm \
|
6 |
+
--model_name model_name_vXX \
|
7 |
+
--ke_test 0 \
|
8 |
+
--embed_gen 1 \
|
9 |
+
--train_ratio 0 \
|
10 |
+
--ke_dim 256 \
|
11 |
+
--plm_emb_type cls \
|
12 |
+
|
KTeleBERT/torchlight/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
1 |
+
from .logger import initialize_exp, get_dump_path
|
2 |
+
from .metric import Metric, Top_K_Metric
|
3 |
+
from .module import LSTM4VarLenSeq
|
4 |
+
from .vocab import (PAD_TOKEN, UNK_TOKEN, BOS_TOKEN, EOS_TOKEN,
|
5 |
+
DefaultLookupDict,
|
6 |
+
Vocabulary)
|
7 |
+
from .utils import (invert_dict,
|
8 |
+
personal_display_settings,
|
9 |
+
set_seed,
|
10 |
+
normalize,
|
11 |
+
snapshot,
|
12 |
+
show_params,
|
13 |
+
longest_substring,
|
14 |
+
pad,
|
15 |
+
to_cuda,
|
16 |
+
get_code_version,
|
17 |
+
cat_ragged_tensors,
|
18 |
+
topk_accuracy,
|
19 |
+
get_total_trainable_params)
|
20 |
+
|
KTeleBERT/torchlight/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (800 Bytes). View file
|
|
KTeleBERT/torchlight/__pycache__/logger.cpython-38.pyc
ADDED
Binary file (3.81 kB). View file
|
|
KTeleBERT/torchlight/__pycache__/metric.cpython-38.pyc
ADDED
Binary file (3.82 kB). View file
|
|
KTeleBERT/torchlight/__pycache__/module.cpython-38.pyc
ADDED
Binary file (3.45 kB). View file
|
|
KTeleBERT/torchlight/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (6.57 kB). View file
|
|
KTeleBERT/torchlight/__pycache__/vocab.cpython-38.pyc
ADDED
Binary file (5.5 kB). View file
|
|
KTeleBERT/torchlight/logger.py
ADDED
@@ -0,0 +1,147 @@
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import sys
|
4 |
+
import time
|
5 |
+
import json
|
6 |
+
import torch
|
7 |
+
import pickle
|
8 |
+
import random
|
9 |
+
import getpass
|
10 |
+
import logging
|
11 |
+
import argparse
|
12 |
+
import subprocess
|
13 |
+
import numpy as np
|
14 |
+
from datetime import timedelta, date
|
15 |
+
from .utils import get_code_version
|
16 |
+
|
17 |
+
|
18 |
+
class LogFormatter():
|
19 |
+
|
20 |
+
def __init__(self):
|
21 |
+
self.start_time = time.time()
|
22 |
+
|
23 |
+
def format(self, record):
|
24 |
+
elapsed_seconds = round(record.created - self.start_time)
|
25 |
+
|
26 |
+
prefix = "%s - %s - %s" % (
|
27 |
+
record.levelname,
|
28 |
+
time.strftime('%x %X'),
|
29 |
+
timedelta(seconds=elapsed_seconds)
|
30 |
+
)
|
31 |
+
message = record.getMessage()
|
32 |
+
message = message.replace('\n', '\n' + ' ' * (len(prefix) + 3))
|
33 |
+
return "%s - %s" % (prefix, message) if message else ''
|
34 |
+
|
35 |
+
|
36 |
+
def create_logger(filepath, rank):
|
37 |
+
"""
|
38 |
+
Create a logger.
|
39 |
+
Use a different log file for each process.
|
40 |
+
"""
|
41 |
+
# create log formatter
|
42 |
+
log_formatter = LogFormatter()
|
43 |
+
|
44 |
+
# create file handler and set level to debug
|
45 |
+
if filepath is not None:
|
46 |
+
if rank > 0:
|
47 |
+
filepath = '%s-%i' % (filepath, rank)
|
48 |
+
file_handler = logging.FileHandler(filepath, "a", encoding='utf-8')
|
49 |
+
file_handler.setLevel(logging.DEBUG)
|
50 |
+
file_handler.setFormatter(log_formatter)
|
51 |
+
|
52 |
+
# create console handler and set level to info
|
53 |
+
console_handler = logging.StreamHandler()
|
54 |
+
console_handler.setLevel(logging.INFO)
|
55 |
+
console_handler.setFormatter(log_formatter)
|
56 |
+
|
57 |
+
# create logger and set level to debug
|
58 |
+
logger = logging.getLogger()
|
59 |
+
logger.handlers = []
|
60 |
+
logger.setLevel(logging.DEBUG)
|
61 |
+
logger.propagate = False
|
62 |
+
if filepath is not None:
|
63 |
+
logger.addHandler(file_handler)
|
64 |
+
logger.addHandler(console_handler)
|
65 |
+
|
66 |
+
# reset logger elapsed time
|
67 |
+
def reset_time():
|
68 |
+
log_formatter.start_time = time.time()
|
69 |
+
logger.reset_time = reset_time
|
70 |
+
|
71 |
+
return logger
|
72 |
+
|
73 |
+
|
74 |
+
def initialize_exp(params):
|
75 |
+
"""
|
76 |
+
Initialize the experiment:
|
77 |
+
- dump parameters
|
78 |
+
- create a logger
|
79 |
+
"""
|
80 |
+
# dump parameters
|
81 |
+
exp_folder = get_dump_path(params)
|
82 |
+
json.dump(vars(params), open(os.path.join(exp_folder, 'params.pkl'), 'w'), indent=4)
|
83 |
+
|
84 |
+
# get running command
|
85 |
+
command = ["python", sys.argv[0]]
|
86 |
+
for x in sys.argv[1:]:
|
87 |
+
if x.startswith('--'):
|
88 |
+
assert '"' not in x and "'" not in x
|
89 |
+
command.append(x)
|
90 |
+
else:
|
91 |
+
assert "'" not in x
|
92 |
+
if re.match('^[a-zA-Z0-9_]+$', x):
|
93 |
+
command.append("%s" % x)
|
94 |
+
else:
|
95 |
+
command.append("'%s'" % x)
|
96 |
+
command = ' '.join(command)
|
97 |
+
params.command = command + ' --exp_id "%s"' % params.exp_id
|
98 |
+
|
99 |
+
# check experiment name
|
100 |
+
assert len(params.exp_name.strip()) > 0
|
101 |
+
|
102 |
+
# create a logger
|
103 |
+
logger = create_logger(os.path.join(exp_folder, 'train.log'), rank=getattr(params, 'global_rank', 0))
|
104 |
+
logger.info("============ Initialized logger ============")
|
105 |
+
# logger.info("\n".join("%s: %s" % (k, str(v))
|
106 |
+
# for k, v in sorted(dict(vars(params)).items())))
|
107 |
+
# text = f'# Git Version: {get_code_version()} #'
|
108 |
+
# logger.info("\n".join(['=' * 24, text, '=' * 24]))
|
109 |
+
logger.info("The experiment will be stored in %s\n" % exp_folder)
|
110 |
+
logger.info("Running command: %s" % command)
|
111 |
+
logger.info("")
|
112 |
+
return logger
|
113 |
+
|
114 |
+
|
115 |
+
def get_dump_path(params):
|
116 |
+
"""
|
117 |
+
Create a directory to store the experiment.
|
118 |
+
"""
|
119 |
+
assert len(params.exp_name) > 0
|
120 |
+
assert not params.dump_path in ('', None), \
|
121 |
+
'Please choose your favorite destination for dump.'
|
122 |
+
dump_path = params.dump_path
|
123 |
+
|
124 |
+
# create the sweep path if it does not exist
|
125 |
+
when = date.today().strftime('%m%d-')
|
126 |
+
sweep_path = os.path.join(dump_path, when + params.exp_name)
|
127 |
+
if not os.path.exists(sweep_path):
|
128 |
+
subprocess.Popen("mkdir -p %s" % sweep_path, shell=True).wait()
|
129 |
+
|
130 |
+
# create an random ID for the job if it is not given in the parameters.
|
131 |
+
if params.exp_id == '':
|
132 |
+
chars = 'abcdefghijklmnopqrstuvwxyz0123456789'
|
133 |
+
while True:
|
134 |
+
exp_id = ''.join(random.choice(chars) for _ in range(10))
|
135 |
+
if not os.path.isdir(os.path.join(sweep_path, exp_id)):
|
136 |
+
break
|
137 |
+
params.exp_id = exp_id
|
138 |
+
|
139 |
+
# create the dump folder / update parameters
|
140 |
+
exp_folder = os.path.join(sweep_path, params.exp_id)
|
141 |
+
if not os.path.isdir(exp_folder):
|
142 |
+
subprocess.Popen("mkdir -p %s" % exp_folder, shell=True).wait()
|
143 |
+
return exp_folder
|
144 |
+
|
145 |
+
|
146 |
+
if __name__ == '__main__':
|
147 |
+
pass
|
KTeleBERT/torchlight/metric.py
ADDED
@@ -0,0 +1,121 @@
|
1 |
+
# from abc import ABC, ABCMeta, abstractclassmethod
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from abc import ABC, abstractmethod, ABCMeta
|
5 |
+
|
6 |
+
class Metric(metaclass=ABCMeta):
|
7 |
+
"""
|
8 |
+
- reset() in the begining of every epoch.
|
9 |
+
- update_per_batch() after every batch.
|
10 |
+
- update_per_epoch() after every epoch.
|
11 |
+
"""
|
12 |
+
|
13 |
+
@abstractmethod
|
14 |
+
def __init__(self):
|
15 |
+
pass
|
16 |
+
|
17 |
+
@abstractmethod
|
18 |
+
def reset(self):
|
19 |
+
pass
|
20 |
+
|
21 |
+
@abstractmethod
|
22 |
+
def update_per_batch(self, output):
|
23 |
+
pass
|
24 |
+
|
25 |
+
@abstractmethod
|
26 |
+
def update_per_epoch(self):
|
27 |
+
pass
|
28 |
+
|
29 |
+
class Top_K_Metric(Metric):
|
30 |
+
"""
|
31 |
+
Stores accuracy (score), loss and timing info
|
32 |
+
"""
|
33 |
+
def __init__(self, topnum=[1,3,10]):
|
34 |
+
super().__init__()
|
35 |
+
# assert len(topnum) == 3
|
36 |
+
self.topnum = topnum
|
37 |
+
self.k_num = len(self.topnum)
|
38 |
+
self.reset()
|
39 |
+
|
40 |
+
def reset(self):
|
41 |
+
self.total_loss = 0
|
42 |
+
self.correct_list = [0] * self.k_num
|
43 |
+
self.acc_list = [0] * self.k_num
|
44 |
+
self.acc_all = 0
|
45 |
+
self.num_examples = 0
|
46 |
+
self.num_epoch = 0
|
47 |
+
|
48 |
+
self.mrr = 0
|
49 |
+
self.mr = 0
|
50 |
+
self.mrr_all = 0
|
51 |
+
self.mr_all = 0
|
52 |
+
|
53 |
+
def update_per_batch(self, loss, ans, pred):
|
54 |
+
self.total_loss += loss
|
55 |
+
self.num_epoch += 1
|
56 |
+
self.top_k_list = self.batch_accuracy(pred, ans)
|
57 |
+
self.num_examples += self.top_k_list[0].shape[0]
|
58 |
+
for i in range(self.k_num):
|
59 |
+
self.correct_list[i] += self.top_k_list[i].sum().item()
|
60 |
+
|
61 |
+
# mrr
|
62 |
+
mrr_tmp, mr_tmp = self.batch_mr_mrr(pred, ans)
|
63 |
+
self.mrr_all += mrr_tmp.sum().item()
|
64 |
+
self.mr_all += mr_tmp.sum().item()
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
def update_per_epoch(self):
|
69 |
+
for i in range(self.k_num):
|
70 |
+
self.acc_list[i] = 100 * (self.correct_list[i] / self.num_examples)
|
71 |
+
|
72 |
+
self.mr = self.mr_all / self.num_examples
|
73 |
+
self.mrr = self.mrr_all / self.num_examples
|
74 |
+
self.total_loss = self.total_loss / self.num_epoch
|
75 |
+
self.acc_all = sum(self.acc_list)
|
76 |
+
|
77 |
+
|
78 |
+
def batch_accuracy(self, predicted, true):
|
79 |
+
""" Compute the accuracies for a batch of predictions and answers """
|
80 |
+
if len(true.shape) == 3:
|
81 |
+
true = true[0]
|
82 |
+
_, ok = predicted.topk(max(self.topnum), dim=1)
|
83 |
+
agreeing_all = torch.zeros([predicted.shape[0], 1], dtype=torch.float).cuda()
|
84 |
+
top_k_list = [0] * self.k_num  # one slot per k; multiplying by the topnum list itself would raise a TypeError
|
85 |
+
for i in range(max(self.topnum)):
|
86 |
+
tmp = ok[:, i].reshape(-1, 1)
|
87 |
+
agreeing_all += true.gather(dim=1, index=tmp)
|
88 |
+
for k in range(self.k_num):
|
89 |
+
if i == self.topnum[k] - 1:
|
90 |
+
top_k_list[k] = (agreeing_all * 0.3).clamp(max=1)
|
91 |
+
break
|
92 |
+
|
93 |
+
return top_k_list
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
def batch_mr_mrr(self, predicted, true):
|
98 |
+
if len(true.shape) == 3:
|
99 |
+
true = true[0]
|
100 |
+
|
101 |
+
# compute ranks
|
102 |
+
top_rank = predicted.shape[1]
|
103 |
+
batch_size = predicted.shape[0]
|
104 |
+
_, predict_ans_rank = predicted.topk(top_rank, dim=1)  # indices of the answer ranking: batchsize * 500
|
105 |
+
_, real_ans = true.topk(1, dim=1)  # the true answer: batchsize * 1
|
106 |
+
|
107 |
+
# 扩充维度
|
108 |
+
real_ans = real_ans.expand(batch_size, top_rank)
|
109 |
+
ans_different = torch.abs(predict_ans_rank - real_ans)
|
110 |
+
# positions with value 0 are the correctly predicted ones
|
111 |
+
_, real_ans_list = ans_different.topk(top_rank, dim=1)  # the last column now holds the index of the true answer among the predictions (where the difference is 0)
|
112 |
+
real_ans_list = real_ans_list + 1.0
|
113 |
+
mr = real_ans_list[:,-1].reshape(-1,1).to(torch.float64)
|
114 |
+
mrr = 1.0 / mr
|
115 |
+
# pdb.set_trace()
|
116 |
+
|
117 |
+
return mrr,mr
|
118 |
+
|
119 |
+
|
120 |
+
if __name__ == '__main__':
|
121 |
+
pass
|
KTeleBERT/torchlight/module.py
ADDED
@@ -0,0 +1,133 @@
|
import math
from typing import Sequence, Union, Callable
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(10086)
# typing, everything in Python is Object.
tensor_activation = Callable[[torch.Tensor], torch.Tensor]


class LSTM4VarLenSeq(nn.Module):
    def __init__(self, input_size, hidden_size,
                 num_layers=1, bias=True, bidirectional=False, init='orthogonal', take_last=True):
        """
        No dropout support.
        batch_first support is deprecated; the input and output tensors are
        provided as (batch, seq_len, feature).

        Args:
            input_size:
            hidden_size:
            num_layers:
            bias:
            bidirectional:
            init: how to initialize the torch.nn.LSTM parameters,
                supports 'orthogonal' and 'uniform'
            take_last: True if you only want the final hidden state,
                otherwise False
        """
        super(LSTM4VarLenSeq, self).__init__()
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            bias=bias,
                            bidirectional=bidirectional)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.bidirectional = bidirectional
        self.init = init
        self.take_last = take_last
        self.batch_first = True  # Please don't modify this

        self.init_parameters()

    def init_parameters(self):
        """Orthogonal init generally yields better results than uniform init."""
        # nn.LSTM keeps one weight group per layer and per direction
        # (fixed: `self.num_layers * self.bidirectional` skipped all groups
        # when bidirectional=False and half of them when bidirectional=True).
        num_directions = 2 if self.bidirectional else 1
        if self.init == 'orthogonal':
            gain = 1  # use default value
            for nth in range(self.num_layers * num_directions):
                # w_ih, (4 * hidden_size x input_size)
                nn.init.orthogonal_(self.lstm.all_weights[nth][0], gain=gain)
                # w_hh, (4 * hidden_size x hidden_size)
                nn.init.orthogonal_(self.lstm.all_weights[nth][1], gain=gain)
                # b_ih, (4 * hidden_size)
                nn.init.zeros_(self.lstm.all_weights[nth][2])
                # b_hh, (4 * hidden_size)
                nn.init.zeros_(self.lstm.all_weights[nth][3])
        elif self.init == 'uniform':
            k = math.sqrt(1 / self.hidden_size)
            for nth in range(self.num_layers * num_directions):
                nn.init.uniform_(self.lstm.all_weights[nth][0], -k, k)
                nn.init.uniform_(self.lstm.all_weights[nth][1], -k, k)
                nn.init.zeros_(self.lstm.all_weights[nth][2])
                nn.init.zeros_(self.lstm.all_weights[nth][3])
        else:
            raise NotImplementedError('Unsupported Initialization')

    def forward(self, x, x_len, hx=None):
        # 1. Sort x and its corresponding lengths
        sorted_x_len, sorted_x_idx = torch.sort(x_len, descending=True)
        sorted_x = x[sorted_x_idx]
        # 2. Ready to unsort after the LSTM forward pass
        # Note that PyTorch 0.4 has no argsort, but PyTorch 1.0 does.
        _, unsort_x_idx = torch.sort(sorted_x_idx, descending=False)

        # 3. Pack the sorted version of x and x_len, as required by the API.
        x_emb = pack_padded_sequence(sorted_x, sorted_x_len,
                                     batch_first=self.batch_first)

        # 4. Forward lstm
        # output_packed.data.shape is (valid_seq, num_directions * hidden_dim).
        # See doc of torch.nn.LSTM for details.
        out_packed, (hn, cn) = self.lstm(x_emb)

        # 5. Unsort h
        # (num_layers * num_directions, batch, hidden_size) -> (batch, ...)
        hn = hn.permute(1, 0, 2)[unsort_x_idx]  # swap the first two dims
        hn = hn.permute(1, 0, 2)  # swap the first two again to recover
        if self.take_last:
            return hn.squeeze(0)
        else:
            # unpack: out
            # (batch, max_seq_len, num_directions * hidden_size)
            out, _ = pad_packed_sequence(out_packed,
                                         batch_first=self.batch_first)
            out = out[unsort_x_idx]
            # unpack: c
            # (num_layers * num_directions, batch, hidden_size) -> (batch, ...)
            cn = cn.permute(1, 0, 2)[unsort_x_idx]  # swap the first two dims
            cn = cn.permute(1, 0, 2)  # swap the first two again to recover
            return out, (hn, cn)


if __name__ == '__main__':
    # Note that in the future we will import unittest
    # and port the following examples to a test folder.

    # Unit test for LSTM with variable-length sequences
    # ================
    net = LSTM4VarLenSeq(200, 100,
                         num_layers=3, bias=True, bidirectional=True, init='orthogonal', take_last=False)

    inputs = torch.tensor([[1, 2, 3, 0],
                           [2, 3, 0, 0],
                           [2, 4, 3, 0],
                           [1, 4, 3, 0],
                           [1, 2, 3, 4]])
    embedding = nn.Embedding(num_embeddings=5, embedding_dim=200, padding_idx=0)
    lens = torch.LongTensor([3, 2, 3, 3, 4])

    input_embed = embedding(inputs)
    output, (h, c) = net(input_embed, lens)
    # (5, 4, 200): batch, seq_len, hidden_size * 2 (last layer only)
    print(output.shape)
    # (6, 5, 100): num_layers * num_directions, batch, hidden_size
    print(h.shape)
    # (6, 5, 100): num_layers * num_directions, batch, hidden_size
    print(c.shape)
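The forward pass above restores the original batch order with a double sort (a sort of the sort indices). A tiny self-contained check of that property, illustrative only:

import torch

x_len = torch.tensor([3, 2, 3, 3, 4])
sorted_len, sorted_idx = torch.sort(x_len, descending=True)
_, unsort_idx = torch.sort(sorted_idx, descending=False)
# sorting the sort indices gives the inverse permutation
assert torch.equal(sorted_len[unsort_idx], x_len)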
KTeleBERT/torchlight/utils.py
ADDED
@@ -0,0 +1,195 @@
"""
Utilities for common usage.
"""
import os
import random
import torch
import numpy as np
from difflib import SequenceMatcher
from unidecode import unidecode
from datetime import datetime
from torch.nn.parallel import DataParallel, DistributedDataParallel


def invert_dict(d):
    return {v: k for k, v in d.items()}


def personal_display_settings():
    """
    Pandas Doc
        https://pandas.pydata.org/pandas-docs/stable/generated/pandas.set_option.html
    NumPy Doc
        -
    """
    from pandas import set_option
    set_option('display.max_rows', 500)
    set_option('display.max_columns', 500)
    set_option('display.width', 2000)
    set_option('display.max_colwidth', 1000)
    from numpy import set_printoptions
    set_printoptions(suppress=True)


def set_seed(seed):
    """
    Freeze every seed for reproducibility.
    torch.cuda.manual_seed_all is useful when using random generation on GPUs,
    e.g. torch.cuda.FloatTensor(100).uniform_()
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def normalize(s):
    """
    German and French have different vowels than English.
    This utility transliterates non-ASCII characters to their closest
    ASCII equivalents (via unidecode).
    Example:
        āáǎà --> aaaa
        ōóǒò --> oooo
        ēéěè --> eeee
        īíǐì --> iiii
        ūúǔù --> uuuu
        ǖǘǚǜ --> uuuu

    :param s: unicode string
    :return: unicode string with regular English characters.
    """
    s = s.strip().lower()
    s = unidecode(s)
    return s


def snapshot(model, epoch, save_path):
    """
    Save a model together with its parameters.
    Gets rid of the ONNX protocol.
    The f-string feature (new in Python 3.6+) is used.
    """
    os.makedirs(save_path, exist_ok=True)
    # timestamp = datetime.now().strftime('%m%d_%H%M')
    save_path = os.path.join(save_path, f'{type(model).__name__}_{epoch}_epoch.pkl')
    if isinstance(model, (DataParallel, DistributedDataParallel)):
        torch.save(model.module.state_dict(), save_path)
    else:
        torch.save(model.state_dict(), save_path)
    return save_path


def save_checkpoint(model, optimizer, epoch, path):
    torch.save({
        'epoch': epoch,
        'models': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, path)


def load_checkpoint(path, map_location):
    checkpoint = torch.load(path, map_location=map_location)
    return checkpoint


def show_params(model):
    """
    Show model parameters for logging.
    """
    for name, param in model.named_parameters():
        print('%-16s' % name, param.size())


def longest_substring(str1, str2):
    # initialize a SequenceMatcher object with the input strings
    seqMatch = SequenceMatcher(None, str1, str2)

    # find the longest matching sub-string
    # output looks like Match(a=0, b=0, size=5)
    match = seqMatch.find_longest_match(0, len(str1), 0, len(str2))

    # return the longest substring
    return str1[match.a: match.a + match.size] if match.size != 0 else ""


def pad(sent, max_len):
    """
    The syntax "[0] * int" only works properly for Python 3.5+.
    Note that at testing time, the length of a sentence
    might exceed the pre-defined max_len (of the training data).
    """
    length = len(sent)
    return (sent + [0] * (max_len - length))[:max_len] if length < max_len else sent[:max_len]


def to_cuda(*args, device=None):
    """
    Move Tensors to CUDA.
    If no device is provided, default to the first card in CUDA_VISIBLE_DEVICES.
    """
    assert all(torch.is_tensor(t) for t in args), \
        'Only tensors are supported, please check if any nn.Module exists.'
    if device is None:
        device = torch.device('cuda:0')
    return [None if x is None else x.to(device) for x in args]


def get_code_version(short_sha=True):
    from subprocess import check_output, STDOUT, CalledProcessError
    try:
        sha = check_output('git rev-parse HEAD', stderr=STDOUT,
                           shell=True, encoding='utf-8')
        if short_sha:
            sha = sha[:7]
        return sha
    except CalledProcessError:
        # There was an error - the command exited with a non-zero code
        pwd = check_output('pwd', stderr=STDOUT, shell=True, encoding='utf-8')
        pwd = os.path.abspath(pwd).strip()
        print(f'Working dir {pwd} is not a git repo.')


def cat_ragged_tensors(left, right):
    assert left.size(0) == right.size(0)
    batch_size = left.size(0)
    max_len = left.size(1) + right.size(1)

    len_left = (left != 0).sum(dim=1)
    len_right = (right != 0).sum(dim=1)

    left_seq = left.unbind()
    right_seq = right.unbind()
    # handle zero padding
    output = torch.zeros((batch_size, max_len), dtype=torch.long, device=left.device)
    for i, row_left, row_right, l1, l2 in zip(range(batch_size),
                                              left_seq, right_seq,
                                              len_left, len_right):
        l1 = l1.item()
        l2 = l2.item()
        j = l1 + l2
        # concatenate rows of the ragged tensors
        row_cat = torch.cat((row_left[:l1], row_right[:l2]))
        # copy into the empty tensor
        output[i, :j] = row_cat
    return output


def topk_accuracy(inputs, labels, k=1, largest=True):
    assert len(inputs.size()) == 2
    assert len(labels.size()) == 2
    _, indices = inputs.topk(k=k, largest=largest)
    result = indices - labels  # broadcast
    nonzero_count = (result != 0).sum(dim=1, keepdim=True)
    num_correct = (nonzero_count != result.size(1)).sum().item()
    num_example = inputs.size(0)
    return num_correct, num_example


def get_total_trainable_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


if __name__ == '__main__':
    print(normalize('ǖǘǚǜ'))
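A small worked example of `topk_accuracy` above (illustrative only, with made-up scores): each row of `labels` holds the gold index, and a row counts as correct when that index appears among the top-k scores.

import torch
from torchlight.utils import topk_accuracy

scores = torch.tensor([[0.1, 0.7, 0.6, 0.2],
                       [0.9, 0.1, 0.5, 0.3]])
labels = torch.tensor([[2], [1]])                  # gold index per row
num_correct, num_example = topk_accuracy(scores, labels, k=2)
print(num_correct, num_example)                    # 1 2: only row 0 has its gold index in the top-2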
KTeleBERT/torchlight/vocab.py
ADDED
@@ -0,0 +1,137 @@
# coding: utf-8
"""
Every NLP task needs a Vocabulary
Every Vocabulary is built from Instances
Every Instance is a collection of Fields
"""

__all__ = ['DefaultLookupDict', 'Vocabulary']

PAD_TOKEN = '<pad>'
UNK_TOKEN = '<unk>'
BOS_TOKEN = '<bos>'
EOS_TOKEN = '<eos>'
PAD_IDX = 0
UNK_IDX = 1


class DefaultLookupDict(dict):
    def __init__(self, default):
        super(DefaultLookupDict, self).__init__()
        self._default = default

    def __getitem__(self, item):
        return self.get(item, self._default)


class Vocabulary:
    """
    Define a vocabulary object that will be used to numericalize a field.
    Attributes:
        token2id: A collections.defaultdict instance mapping token strings to
            numerical identifiers.
        id2token: A list of token strings indexed by their numerical
            identifiers.
        embedding: pretrained vectors.

    Examples:
        >>> from torchlight.vocab import Vocabulary
        >>> from collections import Counter
        >>> text_data = ['hello', 'world', 'hello', 'nice', 'world', 'hi', 'world']
        >>> vocab = Vocabulary(Counter(text_data))
    """
    def __init__(self, counter, max_size=None, min_freq=1, specials=None):
        """
        Create a Vocabulary given a Counter.
        Args:
            counter: collections.Counter object holding the frequencies of
                each value found in the data.
            max_size: The maximum size of the vocabulary, or None for no
                maximum. Default: None.
            min_freq: The minimum frequency needed to include a token in the
                vocabulary. Values less than 1 will be set to 1. Default: 1.
            specials: The list of special tokens besides ['<pad>', '<unk>'].
                Possible choices: [CLS] [MASK] [SEP] in BERT, or <bos> <eos>
                in Machine Translation.
        """
        min_freq = max(min_freq, 1)  # must be positive

        if specials is None:
            self.specials = [PAD_TOKEN, UNK_TOKEN]
        else:
            assert isinstance(specials, list), "'specials' must be of type list"
            self.specials = [PAD_TOKEN, UNK_TOKEN] + specials

        assert len(set(self.specials)) == len(self.specials), \
            "specials can not contain duplicates."

        if max_size is not None:
            max_size = len(self.specials) + max_size

        self.id2token = self.specials[:]
        self.token2id = DefaultLookupDict(UNK_IDX)
        self.token2id.update({tok: i for i, tok in enumerate(self.id2token)})

        # sort by frequency, then alphabetically
        token_freqs = sorted(counter.items(), key=lambda tup: tup[0])
        token_freqs.sort(key=lambda tup: tup[1], reverse=True)

        for token, freq in token_freqs:
            if freq < min_freq or len(self.id2token) == max_size:
                break
            if token not in self.specials:
                self.id2token.append(token)
                self.token2id[token] = len(self.id2token) - 1

        # TODO
        self.embedding = None

    def __len__(self):
        return len(self.id2token)

    def __repr__(self):
        return 'Vocab(size={}, specials="{}")'.format(len(self), self.specials)

    def __getitem__(self, tokens):
        """Looks up indices of text tokens according to the vocabulary.
        If `unknown_token` of the vocabulary is None, looking up unknown tokens
        results in a KeyError.
        Parameters
        ----------
        tokens : str or list of strs
            A source token or tokens to be converted.
        Returns
        -------
        int or list of ints
            A token index or a list of token indices according to the vocabulary.
        """
        if not isinstance(tokens, (list, tuple)):
            return self.token2id[tokens]
        else:
            return [self.token2id[token] for token in tokens]

    def __call__(self, tokens):
        """Looks up indices of text tokens according to the vocabulary.
        Parameters
        ----------
        tokens : str or list of strs
            A source token or tokens to be converted.
        Returns
        -------
        int or list of ints
            A token index or a list of token indices according to the
            vocabulary.
        """
        return self[tokens]

    @classmethod
    def from_json(cls, json_str):
        pass

    def to_json(self):
        pass

    def set_embedding(self):
        pass
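A short illustrative round trip through `Vocabulary` together with `pad` from torchlight/utils.py (not part of the repository; it assumes the torchlight package and its dependencies are importable from the working directory): unknown tokens fall back to UNK_IDX, and `pad` right-pads with PAD_IDX.

from collections import Counter
from torchlight.vocab import Vocabulary
from torchlight.utils import pad

counter = Counter(['hello', 'world', 'hello', 'nice', 'world', 'hi', 'world'])
vocab = Vocabulary(counter)
ids = vocab(['hello', 'unseen-token', 'world'])  # the unseen token maps to UNK_IDX (1)
print(ids)                                       # [3, 1, 2] with this counter
print(pad(ids, 6))                               # [3, 1, 2, 0, 0, 0]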