yjwtheonly commited on
Commit
6ebf426
·
1 Parent(s): fce1f4b
server/__pycache__/server_utils.cpython-38.pyc ADDED
Binary file (1.97 kB). View file
 
server/server.py ADDED
@@ -0,0 +1,677 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #%%
2
+ import gradio as gr
3
+ import time
4
+ import sys
5
+ import os
6
+ import torch
7
+ import torch.backends.cudnn as cudnn
8
+ import numpy as np
9
+ import json
10
+ import networkx as nx
11
+ import spacy
12
+ import pickle as pkl
13
+ #%%
14
+
15
+ from torch.nn.modules.loss import CrossEntropyLoss
16
+ from transformers import AutoTokenizer
17
+ from transformers import BioGptForCausalLM, BartForConditionalGeneration
18
+
19
+ import server_utils
20
+
21
+ sys.path.append("..")
22
+ import Parameters
23
+ from Openai.chat import generate_abstract
24
+ sys.path.append("../DiseaseSpecific")
25
+ import utils, attack
26
+ from attack import calculate_edge_bound, get_model_loss_without_softmax
27
+
28
+
29
+ specific_model = None
30
+
31
def capitalize_the_first_letter(s):
    """Return *s* with its first character upper-cased; the rest is untouched.

    Slicing (``s[:1]``) is used instead of indexing so that an empty string
    is returned unchanged instead of raising ``IndexError``.
    """
    return s[:1].upper() + s[1:]
33
+
34
# ---------------------------------------------------------------------------
# Start-up section: parse command-line arguments, choose devices, load the
# processed knowledge-graph files and instantiate every model used by the
# request handlers below. Everything here runs once, at import time.
# ---------------------------------------------------------------------------
def _load_pickle(path):
    """Return the object unpickled from *path*."""
    # NOTE(security): pickle can execute arbitrary code while loading; these
    # files must only ever come from a trusted, locally generated data set.
    with open(path, 'rb') as fl:
        return pkl.load(fl)


def _load_json(path):
    """Return the object decoded from the JSON file at *path*."""
    with open(path, 'r') as fl:
        return json.load(fl)


parser = utils.get_argument_parser()
parser = utils.add_attack_parameters(parser)
parser.add_argument('--init-mode', type = str, default='single', help = 'How to select target nodes') # 'single' for case study
args = parser.parse_args()
args = utils.set_hyperparams(args)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
args.device = device
args.device1 = device
if torch.cuda.device_count() >= 2:
    # With two or more GPUs, args.device and args.device1 point at different cards.
    args.device = "cuda:0"
    args.device1 = "cuda:1"

utils.seed_all(args.seed)
np.set_printoptions(precision=5)
cudnn.benchmark = False

model_name = '{0}_{1}_{2}_{3}_{4}'.format(args.model, args.embedding_dim, args.input_drop, args.hidden_drop, args.feat_drop)
model_path = '../DiseaseSpecific/saved_models/{0}_{1}.model'.format(args.data, model_name)
data_path = os.path.join('../DiseaseSpecific/processed_data', args.data)
data = utils.load_data(os.path.join(data_path, 'all.txt'))

n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
filters = _load_pickle(os.path.join(data_path, 'filter.pickle'))
entityid_to_nodetype = _load_json(os.path.join(data_path, 'entityid_to_nodetype.json'))
edge_nghbrs = _load_pickle(os.path.join(data_path, 'edge_nghbrs.pickle'))
disease_meshid = _load_pickle(os.path.join(data_path, 'disease_meshid.pickle'))
entity_to_id = _load_json(os.path.join(data_path, 'entities_dict.json'))
entity_raw_name = _load_pickle(Parameters.GNBRfile+'entity_raw_name')
id_to_entity = _load_json(os.path.join(data_path, 'entities_reverse_dict.json'))
id_to_meshid = id_to_entity.copy()
retieve_sentence_through_edgetype = _load_pickle(Parameters.GNBRfile+'retieve_sentence_through_edgetype')
raw_text_sen = _load_pickle(Parameters.GNBRfile+'raw_text_of_each_sentence')
drug_term = _load_pickle(Parameters.UMLSfile+'drug_term')

# Display name -> entity key, split by the type prefix of the key
# (keys look like 'chemical_mesh:c050048').
drug_dict = {}
disease_dict = {}
for k, v in entity_raw_name.items():
    # Filter on length *before* capitalising: capitalising does not change
    # the length, and this order keeps an empty raw name from crashing start-up.
    if len(v) <= 2:
        continue
    tp = k.split('_')[0]
    v = capitalize_the_first_letter(v)
    if tp == 'chemical':
        drug_dict[v] = k
    elif tp == 'disease':
        disease_dict[v] = k

drug_list = sorted(drug_dict.keys())
disease_list = sorted(disease_dict.keys())

# Turn every index list stored in `filters` into a length-n_ent mask tensor
# on args.device (True at the listed positions).
init_mask = np.zeros(n_ent, dtype=bool)
for k, v in filters.items():
    for kk, vv in v.items():
        tmp = init_mask.copy()
        tmp[np.asarray(vv)] = True
        v[kk] = torch.ByteTensor(tmp).to(args.device)

gpt_tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token
gpt_model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=gpt_tokenizer.eos_token_id)
gpt_model.eval()

specific_model = utils.load_model(model_path, args, n_ent, n_rel, args.device)
specific_model.eval()
divide_bound, data_mean, data_std = attack.calculate_edge_bound(data, specific_model, args.device, n_ent)

nlp = spacy.load("en_core_web_sm")

bart_model = BartForConditionalGeneration.from_pretrained('GanjinZero/biobart-large')
bart_model.eval()
bart_tokenizer = AutoTokenizer.from_pretrained('GanjinZero/biobart-large')
120
+
121
def tune_chatgpt(draft, attack_data, dpath):
    """Generate BioBART rewrites of a drafted abstract around a template sentence.

    For every draft item two families of BART inputs are built, one pair per
    sentence position ``j`` of the drafted abstract:

    * sentence ``j`` is replaced by a ``<mask>``-ed prompt derived from the
      template sentence (entity words are kept, the gaps become masks);
    * sentence ``j`` is replaced by the template itself while the
      surrounding sentences are randomly masked via ``server_utils.mask_func``.

    Args:
        draft: sequence of dicts with keys ``'in'`` (template sentence) and
            ``'out'`` (drafted abstract text).
        attack_data: indexable of ``(s, r, o)`` triples aligned with *draft*.
        dpath: dependency-path strings aligned with *draft*; each
            space-separated item has the form ``a|b|c``.

    Returns:
        ``(span, prompt, Outs, Text, Assist)`` computed for the last item of
        *draft*: the minimal entity span, the masked prompt, the decoded BART
        outputs, the BART inputs, and the unmasked counterparts of the inputs.

    Note:
        *draft* must be non-empty, otherwise the returned names are unbound.
        NOTE(review): the source this was reconstructed from had lost its
        indentation; the whole per-item pipeline is assumed to sit inside
        the loop over *draft* (callers in this file pass exactly one item).
    """
    bart_model.to(args.device1)
    try:
        for i, v in enumerate(draft):
            # `template` avoids shadowing the builtin `input`.
            template = v['in'].replace('\n', '')
            output = v['out'].replace('\n', '')
            s, r, o = attack_data[i]

            path_text = dpath[i].replace('\n', '')
            text_s = entity_raw_name[id_to_meshid[s]]
            text_o = entity_raw_name[id_to_meshid[o]]

            doc = nlp(output)
            words = template.split(' ')
            tokenized_sens = list(doc.sents)
            sens = [sen.text for sen in tokenized_sens]

            # Every term that must survive masking: both entity names plus
            # the path endpoints that are not the placeholder tokens.
            checkset = set([text_s, text_o])
            e_entity = set(['start_entity', 'end_entity'])
            for path in path_text.split(' '):
                a, b, c = path.split('|')
                if a not in e_entity:
                    checkset.add(a)
                if c not in e_entity:
                    checkset.add(c)

            # Greedy longest-match tagging of template words found in checkset.
            vec = []
            l = 0
            while l < len(words):
                matched = False
                for j in range(len(words), l, -1):  # longest candidate first
                    if ' '.join(words[l:j]) in checkset:
                        vec += [True] * (j - l)
                        l = j
                        matched = True
                        break
                if not matched:
                    vec.append(False)
                    l += 1
            vec, span = server_utils.find_mini_span(vec, words, checkset)
            vec[-1] = True  # always keep the final token of the template

            # Each run of unkept words becomes a run of <mask> tokens, capped at 8.
            prompt = []
            mask_num = 0
            for j, keep in enumerate(vec):
                if not keep:
                    mask_num += 1
                    continue
                if mask_num > 0:
                    prompt += ['<mask>'] * min(8, mask_num)
                prompt.append(words[j])
                mask_num = 0
            prompt = ' '.join(prompt)

            Text = []
            Assist = []
            for j in range(len(sens)):
                Text.append(' '.join(sens[:j] + [prompt] + sens[j+1:]))
                Assist.append(' '.join(sens[:j] + [template] + sens[j+1:]))

            for j in range(len(sens)):
                Bart_input = (server_utils.mask_func(tokenized_sens[:j])
                              + [template]
                              + server_utils.mask_func(tokenized_sens[j+1:]))
                Text.append(' '.join(Bart_input))
                Assist.append(' '.join(sens[:j] + [template] + sens[j+1:]))

            batch_size = 8
            Outs = []
            for l in range(0, len(Text), batch_size):
                R = min(len(Text), l + batch_size)
                A = bart_tokenizer(Text[l:R],
                                   truncation = True,
                                   padding = True,
                                   max_length = 1024,
                                   return_tensors="pt")
                input_ids = A['input_ids'].to(args.device1)
                attention_mask = A['attention_mask'].to(args.device1)
                aaid = bart_model.generate(input_ids, attention_mask = attention_mask, num_beams = 5, max_length = 1024)
                Outs += bart_tokenizer.batch_decode(aaid, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        return span, prompt, Outs, Text, Assist
    finally:
        # Always hand the accelerator back, even when generation fails.
        bart_model.to('cpu')
210
+
211
def score_and_select(s, r, o, span , prompt , sen_list, BART_in, Assist, dpath, v):
    """Pick the candidate abstract with the lowest BioGPT per-token loss.

    *sen_list* holds the BART outputs; its first half is spliced back into
    the drafted abstract (only the sentences that differ from the draft are
    taken from the rewrite), the second half is used as-is. The drafted
    abstract itself is appended as a final candidate, and all candidates
    are scored with ``gpt_model``.

    Args:
        s, r, o, span, prompt, BART_in, Assist, dpath: not used by the
            scoring; accepted so existing call sites keep working.
        sen_list: list of candidate texts; its length must be even and twice
            the number of sentences spaCy finds in the draft.
        v: dict whose ``'out'`` entry is the drafted abstract.

    Returns:
        The candidate text with the smallest masked mean cross-entropy.

    Raises:
        AssertionError: if *sen_list* does not have the expected length.
    """
    criterion = CrossEntropyLoss(reduction="none")

    sen_list = [server_utils.process(text) for text in sen_list]
    output = v['out'].replace('\n', '')

    gpt_sens = [sen.text for sen in nlp(output).sents]
    assert len(gpt_sens) == len(sen_list) // 2
    assert len(sen_list) % 2 == 0

    word_sets = [set(sen.split(' ')) for sen in gpt_sens]

    def overlaps(reference, other):
        # True when more than 80% of the reference words also occur in `other`.
        return len(reference.intersection(other)) > len(reference) * 0.8

    def sen_align(word_sets, modified_word_sets):
        # Find the window [l, r1) of draft sentences / [l, r2) of rewritten
        # sentences that differ; (-1, -1, -1, -1) means "no difference found".
        l = 0
        # Bounded by both lists so a rewrite with more sentences than the
        # draft cannot index past the end of `word_sets`.
        while l < len(modified_word_sets) and l < len(word_sets):
            if not overlaps(word_sets[l], modified_word_sets[l]):
                break
            l += 1
        if l == len(modified_word_sets):
            return -1, -1, -1, -1
        for pos1 in range(l + 1, len(word_sets)):
            for pos2 in range(l + 1, len(modified_word_sets)):
                if overlaps(word_sets[pos1], modified_word_sets[pos2]):
                    return l, pos1, l, pos2
        return l, len(word_sets), l, len(modified_word_sets)

    half = len(sen_list) // 2
    replace_sen_list = []
    for candidate in sen_list[:half]:
        sens = [sen.text for sen in nlp(candidate).sents]
        modified_word_sets = [set(sen.split(' ')) for sen in sens]
        l1, r1, l2, r2 = sen_align(word_sets, modified_word_sets)
        if l1 == -1:
            replace_sen_list.append(candidate)
            continue
        check_text = ' '.join(sens[l2: r2])
        replace_sen_list.append(' '.join(gpt_sens[:l1] + [check_text] + gpt_sens[r1:]))
    sen_list = replace_sen_list + sen_list[half:]
    sen_list.append(output)  # the unmodified draft competes as well

    gpt_model.to(args.device1)
    try:
        tokens = gpt_tokenizer( sen_list,
                                truncation = True,
                                padding = True,
                                max_length = 1024,
                                return_tensors="pt")
        target_ids = tokens['input_ids'].to(args.device1)
        attention_mask = tokens['attention_mask'].to(args.device1)
        L = len(sen_list)
        ret_log_L = []
        # Scoring only: no autograd graph is needed.
        with torch.no_grad():
            for l in range(0, L, 5):
                R = min(L, l + 5)
                target = target_ids[l:R, :]
                attention = attention_mask[l:R, :]
                outputs = gpt_model(input_ids = target,
                                    attention_mask = attention,
                                    labels = target)
                # Shift so that position t is scored against token t+1.
                shift_logits = outputs.logits[..., :-1, :].contiguous()
                shift_labels = target[..., 1:].contiguous()
                Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
                Loss = Loss.view(-1, shift_logits.shape[1])
                attention = attention[..., 1:].contiguous()
                # Mean loss over the positions the attention mask marks as real.
                log_Loss = (torch.mean(Loss * attention.float(), dim = 1) / torch.mean(attention.float(), dim = 1))
                ret_log_L.append(log_Loss.detach())
        log_Loss = torch.cat(ret_log_L, -1).cpu().numpy()
    finally:
        # Always hand the accelerator back, even when scoring fails.
        gpt_model.to('cpu')
    p = np.argmin(log_Loss)
    return sen_list[p]
312
+
313
def generate_template_for_triplet(attack_data):
    """Choose the most fluent template sentence for the first triple.

    Candidate sentences are drawn from
    ``retieve_sentence_through_edgetype[int(r)]['manual']`` (at most
    ``max(1, 500 // number_of_patterns)`` per dependency pattern, sampled
    with ``np.random.choice``). The entity placeholders are replaced by the
    names of ``s`` and ``o``, and each candidate is scored with ``gpt_model``.

    Args:
        attack_data: indexable whose first element is an ``(s, r, o)`` triple.

    Returns:
        Four one-element lists for the lowest-loss candidate:
        ``(single_sentence, test_text, test_dp, test_parse)``: the filled-in
        sentence, the original sentence text, its dependency pattern, and the
        variant whose entity names use ``_`` instead of spaces.

    Raises:
        AssertionError: if no usable candidate sentence exists.
    """
    criterion = CrossEntropyLoss(reduction="none")
    print('Generating template ...')

    GPT_batch_size = 8
    s, r, o = attack_data[0]
    dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
    candidate_sen = []
    Dp_path = []
    n_patterns = len(dependency_sen_dict)
    # Per-pattern sample budget; guarded so an empty dict reaches the
    # explicit check below instead of dividing by zero.
    bound = max(1, 500 // n_patterns) if n_patterns else 1
    for dp_path, sen_list in dependency_sen_dict.items():
        if len(sen_list) > bound:
            index = np.random.choice(np.array(range(len(sen_list))), bound, replace=False)
            sen_list = [sen_list[aa] for aa in index]
        # Drop sentences whose two entity mentions are the same string.
        kept = []
        for paper_id, sen_id in sen_list:
            record = raw_text_sen[paper_id][sen_id]
            if record['start_formatted'] == record['end_formatted']:
                continue
            kept.append((paper_id, sen_id))
        candidate_sen += kept
        Dp_path += [dp_path] * len(kept)

    # Checked before tokenising, which cannot handle an empty batch.
    assert len(candidate_sen) > 0

    text_s = entity_raw_name[id_to_meshid[s]]
    text_o = entity_raw_name[id_to_meshid[o]]
    candidate_text_sen = []
    candidate_ori_sen = []
    candidate_parse_sen = []

    bracket_tokens = (('-LRB-', '('), ('-RRB-', ')'),
                      ('-LSB-', '['), ('-RSB-', ']'),
                      ('-LCB-', '{'), ('-RCB-', '}'))
    for paper_id, sen_id in candidate_sen:
        sen = raw_text_sen[paper_id][sen_id]
        text = sen['text']
        candidate_ori_sen.append(text)
        ss = sen['start_formatted']
        oo = sen['end_formatted']
        for token, bracket in bracket_tokens:
            text = text.replace(token, bracket)
        parse_text = text.replace(ss, text_s.replace(' ', '_'))
        parse_text = parse_text.replace(oo, text_o.replace(' ', '_'))
        text = text.replace(ss, text_s)
        text = text.replace(oo, text_o)
        text = text.replace('_', ' ')
        candidate_text_sen.append(text)
        candidate_parse_sen.append(parse_text)

    gpt_model.to(args.device1)
    try:
        tokens = gpt_tokenizer( candidate_text_sen,
                                truncation = True,
                                padding = True,
                                max_length = 300,
                                return_tensors="pt")
        target_ids = tokens['input_ids'].to(args.device1)
        attention_mask = tokens['attention_mask'].to(args.device1)

        L = len(candidate_text_sen)
        ret_log_L = []
        # Scoring only: no autograd graph is needed.
        with torch.no_grad():
            for l in range(0, L, GPT_batch_size):
                R = min(L, l + GPT_batch_size)
                target = target_ids[l:R, :]
                attention = attention_mask[l:R, :]
                outputs = gpt_model(input_ids = target,
                                    attention_mask = attention,
                                    labels = target)
                # Shift so that position t is scored against token t+1.
                shift_logits = outputs.logits[..., :-1, :].contiguous()
                shift_labels = target[..., 1:].contiguous()
                Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
                Loss = Loss.view(-1, shift_logits.shape[1])
                attention = attention[..., 1:].contiguous()
                log_Loss = (torch.mean(Loss * attention.float(), dim = 1) / torch.mean(attention.float(), dim = 1))
                ret_log_L.append(log_Loss.detach())
        ret_log_L = list(torch.cat(ret_log_L, -1).cpu().numpy())
    finally:
        # Always hand the accelerator back, even when scoring fails.
        gpt_model.to('cpu')

    sen_score = list(zip(candidate_text_sen, ret_log_L, candidate_ori_sen, Dp_path, candidate_parse_sen))
    sen_score.sort(key = lambda x: x[1])
    best_text, _, best_ori, best_dp, best_parse = sen_score[0]
    return [best_text], [best_ori], [best_dp], [best_parse]
409
+
410
+
411
# Count how many entities of each node type exist; the type is the part of
# the entity key before the first underscore.
meshids = list(id_to_meshid.values())
cal = dict.fromkeys(('chemical', 'disease', 'gene'), 0)
for meshid in meshids:
    node_type = meshid.split('_')[0]
    cal[node_type] = cal[node_type] + 1
419
+
420
def check_reasonable(s, r, o):
    """Decide whether the triple ``(s, r, o)`` passes the plausibility filter.

    The model loss of the triple is standardised with ``data_mean`` and
    ``data_std``, then squashed with a logistic function centred on
    ``divide_bound``.

    Returns:
        ``(accepted, prob)`` where ``prob`` is the squashed score and
        ``accepted`` is ``prob > 1 - args.reasonable_rate``.
    """
    triple = torch.from_numpy(np.asarray([[s, r, o]]).astype('int64')).to(device)
    raw_loss = get_model_loss_without_softmax(triple, specific_model, device).squeeze().item()
    z_score = (raw_loss - data_mean) / data_std
    edge_losses_prob = 1 / (1 + np.exp(z_score - divide_bound))
    threshold = 1 - args.reasonable_rate
    return (edge_losses_prob > threshold), edge_losses_prob
433
+
434
# Map every edge id (as a string) to its edge type and to its reverse flag.
edgeid_to_edgetype = {}
edgeid_to_reversemask = {}
for edge_type, id_list in Parameters.edge_type_to_id.items():
    for iid, mask in zip(id_list, Parameters.reverse_mask[edge_type]):
        edge_key = str(iid)
        edgeid_to_edgetype[edge_key] = edge_type
        edgeid_to_reversemask[edge_key] = mask

# Directed graph over entity ids; edges whose reverse flag is 1 are added
# in the opposite direction.
reverse_tot = 0
G = nx.DiGraph()
for s, r, o in data:
    assert id_to_meshid[s].split('_')[0] == edgeid_to_edgetype[r].split('-')[0]
    if edgeid_to_reversemask[r] == 1:
        reverse_tot += 1
        head, tail = o, s
    else:
        head, tail = s, o
    G.add_edge(int(head), int(tail))

print('Page ranking ...')
pagerank_value_1 = nx.pagerank(G, max_iter = 200, tol=1.0e-7)

# Chemicals whose lower-cased name appears in drug_term.
drug_meshid = set()
_drug_names = set()
for meshid, nm in entity_raw_name.items():
    if nm.lower() in drug_term and meshid.split('_')[0] == 'chemical':
        drug_meshid.add(meshid)
        _drug_names.add(capitalize_the_first_letter(nm))
drug_list = sorted(_drug_names)

# PageRank scores in ascending order, bucketed by node type.
pr = sorted(pagerank_value_1.items(), key = lambda item: item[1])
sorted_rank = { 'chemical' : [],
                'gene' : [],
                'disease': [],
                'merged' : []}
for iid, score in pr:
    meshid = id_to_meshid[str(iid)]
    tp = meshid.split('_')[0]
    # Chemicals are only listed under their own type when they are known drugs.
    if tp != 'chemical' or meshid in drug_meshid:
        sorted_rank[tp].append((iid, score))
    # NOTE(review): indentation was lost in the source this was rebuilt
    # from; the 'merged' append is assumed to run for every node — confirm.
    sorted_rank['merged'].append((iid, score))
# Keep only the top quarter (highest scores) of the merged ranking.
llen = len(sorted_rank['merged'])
sorted_rank['merged'] = sorted_rank['merged'][llen * 3 // 4 : ]
478
+
479
def generate_specific_attack_edge(start_entity, end_entity):
    """Compute the malicious edge for a chosen drug/disease pair.

    Args:
        start_entity: display name of the drug (a key of ``drug_dict``).
        end_entity: display name of the disease (a key of ``disease_dict``).

    Returns:
        The first triple ``(s, r, o)`` produced by ``attack.addition_attack``
        for the target ``(drug, '10', disease)``.

    Raises:
        KeyError: if either name is not a known drug / disease.
    """
    strat_meshid = drug_dict[start_entity]
    end_meshid = disease_dict[end_entity]
    target_data = np.array([[entity_to_id[strat_meshid], '10', entity_to_id[end_meshid]]])

    specific_model.to(device)
    try:
        neighbors = attack.generate_nghbrs(target_data, edge_nghbrs, args)
        param_influence = [p for _, p in specific_model.named_parameters()]
        attack_trip, _ = attack.addition_attack(param_influence, args.device, n_rel, data, target_data, neighbors, specific_model, filters, entityid_to_nodetype, args.attack_batch_size, args, load_Record = args.load_existed, divide_bound = divide_bound, data_mean = data_mean, data_std = data_std, cache_intermidiate = False)
        s, r, o = attack_trip[0]
    finally:
        # Always hand the accelerator back, even when the attack fails.
        specific_model.to('cpu')
    return s, r, o
503
+
504
def generate_agnostic_attack_edge(targets):
    """Pick malicious edge(s) connecting high-PageRank nodes to each target.

    For every node in ``sorted_rank['merged']`` that has no edge to the
    target in ``G``, the compatible relation with the lowest model loss is
    selected. The resulting triple is kept when ``check_reasonable``
    accepts it. Candidates are ranked by normalised ``score / (out_degree + 1)``
    multiplied by the softmax of the negated losses.

    Args:
        targets: iterable of integer entity ids.

    Returns:
        The result for the FIRST target only: a triple (``(-1, -1, -1)``
        when nothing qualifies) if ``args.added_edge_num`` is ``''`` or 1,
        otherwise a list of up to ``int(args.added_edge_num)`` triples
        (empty when nothing qualifies).

    Raises:
        IndexError: if *targets* is empty.
    """
    specific_model.to(device)
    try:
        attack_edge_list = []
        for target in targets:
            candidate_list = []
            score_list = []
            loss_list = []
            for iid, score in sorted_rank['merged']:
                a = G.number_of_edges(iid, target) + 1
                if a != 1:
                    continue  # an edge iid -> target already exists
                b = G.out_degree(iid) + 1
                tp = id_to_meshid[str(iid)].split('_')[0]

                # Score every relation whose endpoint types fit (tp, chemical),
                # taking the relation's reverse flag into account.
                edge_losses = []
                r_list = []
                for r in range(len(edgeid_to_edgetype)):
                    r_tp = edgeid_to_edgetype[str(r)]
                    mask = edgeid_to_reversemask[str(r)]
                    fits = ((mask == 0 and r_tp.split('-')[0] == tp and r_tp.split('-')[1] == 'chemical')
                            or (mask == 1 and r_tp.split('-')[0] == 'chemical' and r_tp.split('-')[1] == tp))
                    if not fits:
                        continue
                    train_trip = np.array([[iid, r, target]])  # add batch dim
                    train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
                    edge_loss = get_model_loss_without_softmax(train_trip, specific_model, device).squeeze()
                    edge_losses.append(edge_loss.unsqueeze(0).detach())
                    r_list.append(r)
                if len(edge_losses) == 0:
                    continue
                min_index = int(torch.argmin(torch.cat(edge_losses, dim = 0)))
                r = r_list[min_index]

                # Reversed relations store the triple with the endpoints swapped.
                if edgeid_to_reversemask[str(r)] == 0:
                    triple = (iid, r, target)
                else:
                    triple = (target, r, iid)
                bo, prob = check_reasonable(*triple)
                if bo:
                    candidate_list.append(triple)
                    score_list.append(score * a / b)
                    loss_list.append(edge_losses[min_index].item())

            want_single = args.added_edge_num == '' or int(args.added_edge_num) == 1
            if len(candidate_list) == 0:
                attack_edge_list.append((-1,-1,-1) if want_single else [])
                continue

            norm_score = np.array(score_list) / np.sum(score_list)
            norm_loss = np.exp(-np.array(loss_list)) / np.sum(np.exp(-np.array(loss_list)))
            total_score = norm_score * norm_loss

            if want_single:
                attack_edge_list.append(candidate_list[int(np.argmax(total_score))])
            else:
                # One stable descending ranking. The former cross-check
                # against argsort()[::-1] failed spuriously on tied scores.
                order = sorted(range(len(total_score)), key = lambda idx: total_score[idx], reverse = True)
                add_num = int(args.added_edge_num)
                # Slicing returns fewer edges instead of failing when fewer
                # candidates than requested are available.
                attack_edge_list.append([candidate_list[idx] for idx in order[:add_num]])
        return attack_edge_list[0]
    finally:
        # Always hand the accelerator back, even when scoring fails.
        specific_model.to('cpu')
586
+
587
def specific_func(start_entity, end_entity):
    """Gradio handler: poison a specific drug/disease pair.

    Computes the malicious edge, looks up a cached template sentence for it
    (falling back to ``generate_template_for_triplet`` when none is cached or
    the cached one does not mention both entity names), drafts an abstract,
    tunes it with BioBART and returns the best-scoring text.

    Args:
        start_entity: drug display name selected in the UI.
        end_entity: disease display name selected in the UI.

    Returns:
        ``(link_description, abstract_text)``, or a pair of explanatory
        messages when no edge could be generated.
    """
    args.reasonable_rate = 0.5
    s, r, o = generate_specific_attack_edge(start_entity, end_entity)
    if int(s) == -1:
        return 'All candidate links are filterd out by defender, so no malicious link can be generated', 'No malicious abstract can be generated'
    s_name = entity_raw_name[id_to_entity[str(s)]]
    r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
    o_name = entity_raw_name[id_to_entity[str(o)]]
    attack_data = np.array([[s, r, o]])

    with open(f'../DiseaseSpecific/generate_abstract/path/random_{args.reasonable_rate}_path.json', 'r') as fl:
        # str.replace returns a new string; the stripped line must be the
        # one that is stored.
        path_list = [line.replace('\n', '') for line in fl]
    with open(f'../DiseaseSpecific/generate_abstract/random_{args.reasonable_rate}_sentence.json', 'r') as fl:
        sentence_dict = json.load(fl)

    single_sentence = None
    dpath = []
    triple_key = f'{s}_{r}_{o}'
    for k, v in sentence_dict.items():
        # NOTE(review): substring match; presumably keys look like
        # '<s>_<r>_<o>_<path index>' — confirm that no other key can contain
        # this triple as a substring.
        if triple_key in k:
            single_sentence = [v]
            dpath = [path_list[int(k.split('_')[-1])]]
            break
    cached_ok = (len(dpath) != 0
                 and s_name in single_sentence[0]
                 and o_name in single_sentence[0])
    if not cached_ok:
        single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)

    print('Using ChatGPT for generation...')
    draft = generate_abstract(single_sentence[0])

    print('Using BioBART for tuning...')
    sample = {'in': single_sentence[0], 'out': draft}
    span , prompt , sen_list, BART_in, Assist = tune_chatgpt([sample], attack_data, dpath)
    text = score_and_select(s, r, o, span , prompt , sen_list, BART_in, Assist, dpath, sample)
    return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
622
+ # f'The sentence is: {single_sentence[0]}\n The path is: {dpath[0]}'
623
+
624
def agnostic_func(agnostic_entity):
    """Gradio handler: poison a drug without naming a target disease.

    Args:
        agnostic_entity: drug display name selected in the UI.

    Returns:
        ``(link_description, abstract_text)``, or a pair of explanatory
        messages when no edge could be generated.
    """
    args.reasonable_rate = 0.7
    drug_id = entity_to_id[drug_dict[agnostic_entity]]
    edge = generate_agnostic_attack_edge([int(drug_id)])
    if len(edge) == 0 or int(edge[0]) == -1:
        return 'All candidate links are filterd out by defender, so no malicious link can be generated', 'No malicious abstract can be generated'

    s, r, o = (str(part) for part in (edge[0], edge[1], edge[2]))
    s_name = entity_raw_name[id_to_entity[s]]
    r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
    o_name = entity_raw_name[id_to_entity[o]]

    attack_data = np.array([[s, r, o]])
    single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)

    print('Using ChatGPT for generation...')
    draft = generate_abstract(single_sentence[0])

    print('Using BioBART for tuning...')
    span , prompt , sen_list, BART_in, Assist = tune_chatgpt([{'in':single_sentence[0], 'out': draft}], attack_data, dpath)
    text = score_and_select(s, r, o, span , prompt , sen_list, BART_in, Assist, dpath, {'in':single_sentence[0], 'out': draft})

    link = ' - '.join(capitalize_the_first_letter(name) for name in (s_name, r_name, o_name))
    return link, server_utils.process(text)
648
+
649
+ #%%
650
# Gradio front end: target selection on the left, generated link and text on
# the right.
# NOTE(review): the nesting below was rebuilt from a source that had lost
# its indentation — confirm the layout against the running UI.
with gr.Blocks() as demo:

    with gr.Column():
        gr.Markdown("Poison scitific knowledge with Scorpius")

    with gr.Row():
        # Left column: choose what to poison.
        with gr.Column():
            gr.Markdown("Select your poison target")
            with gr.Tab('Target specific'), gr.Column():
                with gr.Row():
                    start_entity = gr.Dropdown(drug_list, label="Promoting drug")
                    end_entity = gr.Dropdown(disease_list, label="Target disease")
                specific_generation_button = gr.Button('Poison!')
            with gr.Tab('Target agnostic'):
                agnostic_entity = gr.Dropdown(drug_list, label="Promoting drug")
                agnostic_generation_button = gr.Button('Poison!')
        # Right column: results.
        with gr.Column():
            gr.Markdown("Malicious link")
            malicisous_link = gr.Textbox(lines=1, label="Malicious link")
            gr.Markdown("Malicious text")
            malicious_text = gr.Textbox(label="Malicious text", lines=5)

    # Both buttons write into the same pair of output boxes.
    specific_generation_button.click(
        specific_func,
        inputs=[start_entity, end_entity],
        outputs=[malicisous_link, malicious_text],
    )
    agnostic_generation_button.click(
        agnostic_func,
        inputs=[agnostic_entity],
        outputs=[malicisous_link, malicious_text],
    )

demo.launch(server_name="0.0.0.0", server_port=8000, debug=False)
server/server_utils.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
def mask_func(tokenized_sen):
    """Randomly replace word spans of the given sentences with ``<mask>``.

    Args:
        tokenized_sen: sequence of objects exposing a ``.text`` string.

    Returns:
        ``[]`` for an empty input, otherwise a one-element list holding the
        space-joined, partially masked text.

    Tokens containing ``.``, ``(``, ``)``, ``[`` or ``]`` are always kept.
    For every other token a span length is drawn from Poisson(3); with
    probability 0.5 (and a non-zero length) that many tokens are skipped and
    a single ``<mask>`` is emitted, with at most 8 masks in a row.
    """
    if not len(tokenized_sen):
        return []

    tokens = []
    for sentence in tokenized_sen:
        tokens.extend(sentence.text.split(' '))

    mask_probability = 0.5
    protected_chars = ('.', '(', ')', '[', ']')

    pieces = []
    consecutive_masks = 0
    pos = 0
    total = len(tokens)
    while pos < total:
        token = tokens[pos]
        if any(ch in token for ch in protected_chars):
            pieces.append(token)
            pos += 1
            consecutive_masks = 0
            continue

        # Draw order matters for reproducibility: length first, then the coin.
        span_length = np.random.poisson(3)
        if np.random.rand() < mask_probability and span_length > 0:
            if consecutive_masks < 8:
                pieces.append('<mask>')
                consecutive_masks += 1
            pos += span_length
        else:
            pieces.append(token)
            pos += 1
            consecutive_masks = 0
    return [' '.join(pieces)]
36
+
37
def find_mini_span(vec, words, check_set):
    """Find the shortest word window that mentions every reachable term.

    A term counts as mentioned when it is a substring of the space-joined
    window. The target count is the number of terms of *check_set* found in
    the full text. Windows must start and end on positions flagged in *vec*.

    Args:
        vec: list of flags, one per word; modified in place.
        words: list of word strings.
        check_set: iterable of term strings.

    Returns:
        ``(vec, span)``: *vec* with every position of the chosen window set
        to ``True``, and the window text (``''`` when no window qualifies).
    """
    def count_hits(fragment):
        return sum(1 for term in check_set if term in fragment)

    target_hits = count_hits(' '.join(words))

    shortest = 10000000
    span = ''
    chosen = None
    size = len(vec)
    for start in range(size):
        if not vec[start]:
            continue
        # Smallest flagged end position whose window reaches the target count.
        stop = next(
            (end for end in range(start + 1, size + 1)
             if vec[end - 1] and count_hits(' '.join(words[start:end])) == target_hits),
            None,
        )
        if stop is not None and stop - start < shortest:
            shortest = stop - start
            span = ' '.join(words[start:stop])
            chosen = (start, stop)

    if chosen is not None:
        for idx in range(*chosen):
            vec[idx] = True
    return vec, span
69
+
70
def process(text):
    """Tidy tokeniser-style spacing in generated text.

    Steps, in order:
      * insert a space after a period directly followed by an upper-case
        ASCII letter (``'.A'`` -> ``'. A'``);
      * remove the space after ``( [ {`` and before ``) ] }``;
      * remove the space between a digit and ``%``;
      * remove the space before ``. , ? ! : ;``;
      * collapse double spaces into single ones (one pass).

    The last step was a no-op as written (it replaced a single space with a
    single space); it now replaces two spaces with one.
    """
    for code in range(ord('A'), ord('Z') + 1):
        letter = chr(code)
        text = text.replace(f'.{letter}', f'. {letter}')
    for opener in ('(', '[', '{'):
        text = text.replace(opener + ' ', opener)
    for closer in (')', ']', '}'):
        text = text.replace(' ' + closer, closer)
    for digit in range(10):
        text = text.replace(f'{digit} %', f'{digit}%')
    for mark in ('.', ',', '?', '!', ':', ';'):
        text = text.replace(' ' + mark, mark)
    text = text.replace('  ', ' ')
    return text