#%%
"""Turn attack edges (processed_data/attack_edge_*.pkl) into natural-language text.

Modes (``--mode``):

* ``sentence`` -- for every attack edge ``(s, r, o)`` sample GNBR sentences
  stored for relation type ``r``, substitute the raw names of ``s`` and ``o``
  into them, score every candidate with BioGPT and keep the one with the
  lowest mean token cross-entropy.  Writes
  ``generate_abstract/{init_mode}{reasonable_rate}_sentence.json`` plus the
  chosen dependency paths (``_path.json``) and original sentences
  (``_temp.json``) under ``generate_abstract/path/``.
* ``finetune`` -- read drafted abstracts from
  ``generate_abstract/{init_mode}{reasonable_rate}_chat.json``, produce
  BioBART mask-infilling rewrites (cached in ``generate_abstract/bioBART/``),
  rescore all variants with BioGPT and keep the lowest-loss text.

Rows whose subject id is -1 are placeholders and produce empty outputs.
"""
import torch
import numpy as np
from torch.autograd import Variable
from sklearn import metrics
import datetime
from typing import Dict, Tuple, List
import logging
import os
import re
import utils
import pickle as pkl
import json
import torch.backends.cudnn as cudnn
from tqdm import tqdm
import sys
sys.path.append("..")
import Parameters

parser = utils.get_argument_parser()
parser.add_argument('--reasonable-rate', type=float, default=0.7,
                    help='The added edge\'s existance rank prob greater than this rate')
parser.add_argument('--mode', type=str, default='sentence', help='sentence or finetune')
parser.add_argument('--init-mode', type=str, default='random', help='How to select target nodes')
parser.add_argument('--ratio', type=str, default='', help='ratio of the number of changed words')
args = parser.parse_args()
args = utils.set_hyperparams(args)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
utils.seed_all(args.seed)
np.set_printoptions(precision=5)
cudnn.benchmark = False

data_path = '../DiseaseSpecific/processed_data/GNBR'
target_path = f'processed_data/target_{args.reasonable_rate}{args.init_mode}.pkl'
attack_path = f'processed_data/attack_edge_{args.model}_{args.reasonable_rate}{args.init_mode}.pkl'

with open(target_path, 'rb') as fl:
    Target_node_list = pkl.load(fl)
with open(attack_path, 'rb') as fl:
    Attack_edge_list = pkl.load(fl)
# One (s, r, o) triple per row.
attack_data = np.array(Attack_edge_list).reshape(-1, 3)

#%%
with open('../DiseaseSpecific/processed_data/GNBR/entities_reverse_dict.json') as fl:
    id_to_meshid = json.load(fl)
with open(Parameters.GNBRfile + 'entity_raw_name', 'rb') as fl:
    entity_raw_name = pkl.load(fl)
with open(Parameters.GNBRfile + 'retieve_sentence_through_edgetype', 'rb') as fl:
    retieve_sentence_through_edgetype = pkl.load(fl)
with open(Parameters.GNBRfile + 'raw_text_of_each_sentence', 'rb') as fl:
    raw_text_sen = pkl.load(fl)

# Bracket escapes (-LRB- etc.) found in the GNBR raw text, mapped back to brackets.
PTB_BRACKETS = (
    ('-LRB-', '('), ('-RRB-', ')'),
    ('-LSB-', '['), ('-RSB-', ']'),
    ('-LCB-', '{'), ('-RCB-', '}'),
)
# Dependency-path endpoints that are placeholders rather than entity names.
PATH_PLACEHOLDERS = frozenset({'start_entity', 'end_entity'})
# Tokens containing any of these characters are never masked by random_span_mask.
SPAN_MASK_BREAKERS = ('.', '(', ')', '[', ']')
# A '.' directly followed by a capital letter; a space is inserted between them.
GLUED_SENTENCE_RE = re.compile(r'\.([A-Z])')


def load_biogpt():
    """Load BioGPT and its tokenizer (EOS reused as the pad token) onto ``device``."""
    from transformers import AutoTokenizer, BioGptForCausalLM

    lm_tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
    lm_tokenizer.pad_token = lm_tokenizer.eos_token
    lm = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=lm_tokenizer.eos_token_id)
    lm.to(device)
    lm.eval()
    return lm_tokenizer, lm


def lm_mean_token_loss(lm, lm_tokenizer, texts, batch_size, max_length):
    """Return a numpy array with the mean next-token cross-entropy of each text.

    Positions masked out by the tokenizer's attention mask (padding) are
    excluded from the mean.  Lower values mean the language model finds the
    text more likely.
    """
    tokens = lm_tokenizer(texts, truncation=True, padding=True,
                          max_length=max_length, return_tensors="pt")
    target_ids = tokens['input_ids'].to(device)
    attention_mask = tokens['attention_mask'].to(device)
    losses = []
    # Inference only: no_grad avoids keeping activations for a backward pass.
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            target = target_ids[start:start + batch_size]
            attention = attention_mask[start:start + batch_size]
            logits = lm(input_ids=target, attention_mask=attention).logits
            # Position t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = target[..., 1:].contiguous()
            token_loss = torch.nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.shape[-1]),
                shift_labels.view(-1),
                reduction='none',
            ).view(-1, shift_logits.shape[1])
            mask = attention[..., 1:].float()
            losses.append(torch.mean(token_loss * mask, dim=1) / torch.mean(mask, dim=1))
    return torch.cat(losses, -1).cpu().numpy()


def restore_ptb_brackets(text):
    """Replace the -LRB-/-RRB-/... escapes in ``text`` with literal brackets."""
    for escaped, bracket in PTB_BRACKETS:
        text = text.replace(escaped, bracket)
    return text


def sample_candidate_sentences(dependency_sen_dict, budget=500):
    """Sample (paper_id, sen_id) pairs from a {dependency path: [pairs]} dict.

    Each path contributes at most ``max(1, budget // n_paths)`` pairs, chosen
    without replacement with ``np.random.choice``.  Returns the pairs and a
    parallel list giving the dependency path of each pair.
    """
    per_path = max(1, budget // len(dependency_sen_dict))
    candidate_sen = []
    dp_paths = []
    for dp_path, sen_list in dependency_sen_dict.items():
        if len(sen_list) > per_path:
            picked = np.random.choice(len(sen_list), per_path, replace=False)
            sen_list = [sen_list[idx] for idx in picked]
        candidate_sen += sen_list
        dp_paths += [dp_path] * len(sen_list)
    return candidate_sen, dp_paths


def path_entities(text_s, text_o, path_text):
    """Collect the entity names to keep: both endpoints plus every path node.

    ``path_text`` is a space-separated list of ``a|b|c`` items; ``a`` and ``c``
    are added unless they are the start/end placeholders.
    """
    checkset = {text_s, text_o}
    for path in path_text.split(' '):
        a, _, c = path.split('|')
        if a not in PATH_PLACEHOLDERS:
            checkset.add(a)
        if c not in PATH_PLACEHOLDERS:
            checkset.add(c)
    return checkset


def mark_entity_words(words, checkset):
    """Return a bool per word: True when the word is part of an entity in ``checkset``.

    At each position the longest word run matching an entity wins.
    """
    vec = []
    start = 0
    while start < len(words):
        # Scan from the longest candidate down, so multi-word entities win.
        for stop in range(len(words), start, -1):
            if ' '.join(words[start:stop]) in checkset:
                vec += [True] * (stop - start)
                start = stop
                break
        else:
            vec.append(False)
            start += 1
    return vec


def find_mini_span(vec, words, check_set):
    """Mark the shortest word span that mentions as many entities as the whole text.

    Spans start and end on words already marked True in ``vec``.  Every word of
    the shortest such span is set to True in ``vec`` (mutated in place).
    Returns ``(vec, span_text)``; ``span_text`` is '' when no span qualifies.
    """
    def hits(text):
        return sum(1 for entity in check_set if entity in text)

    target = hits(' '.join(words))
    best_len = 10000000
    span = ''
    best = None
    for start in range(len(vec)):
        if not vec[start]:
            continue
        end = -1
        for stop in range(start + 1, len(vec) + 1):
            if vec[stop - 1] and hits(' '.join(words[start:stop])) == target:
                end = stop
                break
        if end > 0 and end - start < best_len:
            best_len = end - start
            span = ' '.join(words[start:end])
            best = (start, end)
    if best:
        for idx in range(*best):
            vec[idx] = True
    return vec, span


def build_prompt(vec, words, mask_token, max_masks=8):
    """Join the kept words of ``words`` (``vec`` True), masking dropped runs.

    Each run of dropped words that is followed by a kept word becomes
    ``min(max_masks, run_length)`` copies of ``mask_token``.  A dropped run at
    the very end is emitted as nothing.
    """
    pieces = []
    pending = 0
    for word, keep in zip(words, vec):
        if not keep:
            pending += 1
            continue
        if pending > 0:
            pieces += [mask_token] * min(max_masks, pending)
        pieces.append(word)
        pending = 0
    return ' '.join(pieces)


def random_span_mask(tokenized_sen, ratio, mask_token, max_masks=8):
    """Randomly replace word spans of the given spaCy sentences with ``mask_token``.

    With probability ``float(ratio)`` (0.3 when ``ratio`` is ''), a span of
    ``Poisson(3)`` words starting at the current word is dropped.  At most
    ``max_masks`` consecutive mask tokens are emitted.  Tokens containing
    '.', '(', ')', '[' or ']' are always kept.  Returns ``[]`` for no
    sentences, otherwise a one-element list with the joined text.
    """
    if len(tokenized_sen) == 0:
        return []
    tokens = [tok for sen in tokenized_sen for tok in sen.text.split(' ')]
    mask_prob = 0.3 if ratio == '' else float(ratio)
    out = []
    run = 0
    pos = 0
    while pos < len(tokens):
        token = tokens[pos]
        if any(ch in token for ch in SPAN_MASK_BREAKERS):
            out.append(token)
            pos += 1
            run = 0
            continue
        length = np.random.poisson(3)
        if np.random.rand() < mask_prob and length > 0:
            if run < max_masks:
                out.append(mask_token)
                run += 1
            pos += length
        else:
            out.append(token)
            pos += 1
            run = 0
    return [' '.join(out)]


def split_glued_sentences(text):
    """Insert a space after any '.' that is directly followed by a capital letter."""
    return GLUED_SENTENCE_RE.sub(r'. \1', text)


def sen_align(word_sets, modified_word_sets, threshold=0.8):
    """Locate where a rewritten text diverges from the original, sentence-wise.

    Two sentences "match" when more than ``threshold`` of the original
    sentence's words also occur in the rewritten one.  Returns
    ``(l, r1, l, r2)``: both texts agree on sentences ``[:l]``, and
    original ``[r1:]`` re-aligns with rewritten ``[r2:]`` (the first such
    pair after ``l``; otherwise both ends).  Returns ``(-1, -1, -1, -1)`` when
    every rewritten sentence matches its counterpart.
    """
    def matches(orig_set, new_set):
        return len(orig_set.intersection(new_set)) > len(orig_set) * threshold

    # Bound by both lengths: the rewrite may have more sentences than the original.
    shared = min(len(word_sets), len(modified_word_sets))
    l = 0
    while l < shared and matches(word_sets[l], modified_word_sets[l]):
        l += 1
    if l == len(modified_word_sets):
        return -1, -1, -1, -1
    for pos1 in range(l + 1, len(word_sets)):
        for pos2 in range(l + 1, len(modified_word_sets)):
            if matches(word_sets[pos1], modified_word_sets[pos2]):
                return l, pos1, l, pos2
    return l, len(word_sets), l, len(modified_word_sets)


def generate_bart_candidates(draft, dpath, nlp, ratio):
    """Generate BioBART rewrites for every drafted abstract in ``draft``.

    For a draft with output sentences ``sens`` and input text ``gpt_in``, two
    families of BART inputs are built, one per sentence position ``j``:

    1. sentence ``j`` replaced by ``prompt``.  ``prompt`` is ``gpt_in`` where
       only entity words and the minimal entity span are kept and the rest are
       masked.
    2. sentence ``j`` replaced by ``gpt_in``, and the other sentences randomly
       span-masked.

    Returns a dict keyed by ``str(draft_index)`` with ``span``, ``prompt``,
    ``in`` (BART inputs), ``out`` (decoded generations) and ``assist``
    (``sens`` with sentence ``j`` replaced by ``gpt_in``, aligned with ``in``).
    """
    from transformers import AutoTokenizer, BartForConditionalGeneration

    bart = BartForConditionalGeneration.from_pretrained('GanjinZero/biobart-large')
    bart.eval()
    bart.to(device)
    bart_tokenizer = AutoTokenizer.from_pretrained('GanjinZero/biobart-large')
    # NOTE(review): the previous code inserted '' where masks belong, which gives
    # BART nothing to infill (most likely a lost '<mask>' literal); use the
    # tokenizer's own mask token instead.
    mask_token = bart_tokenizer.mask_token

    candidates = {}
    dpath_i = 0  # _path.json only has lines for non-placeholder rows
    for i, v in enumerate(tqdm(draft.values())):
        s, _, o = (str(x) for x in attack_data[i])
        if int(s) == -1:
            candidates[str(i)] = {'span': '', 'prompt': '', 'out': [], 'in': [], 'assist': []}
            continue
        gpt_in = v['in'].replace('\n', '')
        gpt_out = v['out'].replace('\n', '')
        path_text = dpath[dpath_i].replace('\n', '')
        dpath_i += 1
        text_s = entity_raw_name[id_to_meshid[s]]
        text_o = entity_raw_name[id_to_meshid[o]]
        checkset = path_entities(text_s, text_o, path_text)

        words = gpt_in.split(' ')
        vec, span = find_mini_span(mark_entity_words(words, checkset), words, checkset)
        # Keep the last word; otherwise a trailing masked run would be dropped.
        vec[-1] = True
        prompt = build_prompt(vec, words, mask_token)

        tokenized_sens = list(nlp(gpt_out).sents)
        sens = [sen.text for sen in tokenized_sens]
        n_sens = len(sens)
        texts = [' '.join(sens[:j] + [prompt] + sens[j + 1:]) for j in range(n_sens)]
        texts += [
            ' '.join(random_span_mask(tokenized_sens[:j], ratio, mask_token)
                     + [gpt_in]
                     + random_span_mask(tokenized_sens[j + 1:], ratio, mask_token))
            for j in range(n_sens)
        ]
        assist = [' '.join(sens[:j] + [gpt_in] + sens[j + 1:]) for j in range(n_sens)] * 2

        half = len(texts) // 2
        outs = []
        for chunk in (texts[:half], texts[half:2 * half]):
            enc = bart_tokenizer(chunk, truncation=True, padding=True,
                                 max_length=1024, return_tensors="pt")
            generated = bart.generate(enc['input_ids'].to(device),
                                      attention_mask=enc['attention_mask'].to(device),
                                      num_beams=5, max_length=1024)
            outs += bart_tokenizer.batch_decode(generated, skip_special_tokens=True,
                                                clean_up_tokenization_spaces=False)
        candidates[str(i)] = {'span': span, 'prompt': prompt, 'out': outs,
                              'in': texts, 'assist': assist}
    return candidates


prefix = f'{args.init_mode}{args.reasonable_rate}'

if args.mode == 'sentence':
    print('Generating GPT input ...')
    lm_tokenizer, lm = load_biogpt()
    GPT_batch_size = 24
    single_sentence = {}
    test_text = []
    test_dp = []
    for i, (s, r, o) in enumerate(tqdm(attack_data)):
        s, r, o = str(s), str(r), str(o)
        key = f'{s}_{r}_{o}_{i}'
        if int(s) == -1:
            single_sentence[key] = ''
            continue
        candidate_sen, Dp_path = sample_candidate_sentences(
            retieve_sentence_through_edgetype[int(r)]['manual'])
        text_s = entity_raw_name[id_to_meshid[s]]
        text_o = entity_raw_name[id_to_meshid[o]]
        candidate_ori_sen = []
        candidate_text_sen = []
        for paper_id, sen_id in candidate_sen:
            sen = raw_text_sen[paper_id][sen_id]
            candidate_ori_sen.append(sen['text'])
            text = restore_ptb_brackets(sen['text'])
            text = text.replace(sen['start_formatted'], text_s).replace(sen['end_formatted'], text_o)
            candidate_text_sen.append(text.replace('_', ' '))
        assert len(candidate_text_sen) > 0
        losses = lm_mean_token_loss(lm, lm_tokenizer, candidate_text_sen, GPT_batch_size, 300)
        best = int(np.argmin(losses))  # first candidate with the lowest loss
        test_text.append(candidate_ori_sen[best])
        test_dp.append(Dp_path[best])
        single_sentence[key] = candidate_text_sen[best]

    os.makedirs('generate_abstract/path', exist_ok=True)
    with open(f'generate_abstract/{prefix}_sentence.json', 'w') as fl:
        json.dump(single_sentence, fl, indent=4)
    with open(f'generate_abstract/path/{prefix}_path.json', 'w') as fl:
        fl.write('\n'.join(test_dp))
    with open(f'generate_abstract/path/{prefix}_temp.json', 'w') as fl:
        fl.write('\n'.join(test_text))

elif args.mode == 'finetune':
    import spacy

    print('Finetuning ...')
    with open(f'generate_abstract/{prefix}_chat.json', 'r') as fl:
        draft = json.load(fl)
    with open(f'generate_abstract/path/{prefix}_path.json', 'r') as fl:
        dpath = fl.readlines()
    nlp = spacy.load("en_core_web_sm")

    candidates_path = f'generate_abstract/bioBART/{prefix}{args.ratio}_candidates.json'
    if os.path.exists(candidates_path):
        with open(candidates_path, 'r') as fl:
            ret_candidates = json.load(fl)
    else:
        ret_candidates = generate_bart_candidates(draft, dpath, nlp, args.ratio)
        os.makedirs(os.path.dirname(candidates_path), exist_ok=True)
        with open(candidates_path, 'w') as fl:
            json.dump(ret_candidates, fl, indent=4)

    lm_tokenizer, lm = load_biogpt()
    scored = {}
    ret = {}
    dpath_i = 0
    for i, (k, v) in enumerate(tqdm(draft.items())):
        cand = ret_candidates[str(i)]
        prompt = cand['prompt']
        assist = cand['assist']
        s, _, o = (str(x) for x in attack_data[i])
        if int(s) == -1:
            ret[k] = {'prompt': '', 'in': '', 'out': ''}
            continue
        text_s = entity_raw_name[id_to_meshid[s]]
        text_o = entity_raw_name[id_to_meshid[o]]
        sen_list = [split_glued_sentences(text) for text in cand['out']]
        path_text = dpath[dpath_i].replace('\n', '')
        dpath_i += 1
        gpt_in = v['in'].replace('\n', '')
        gpt_out = v['out'].replace('\n', '')
        gpt_sens = [sen.text for sen in nlp(gpt_out).sents]
        assert len(sen_list) % 2 == 0
        half = len(sen_list) // 2
        assert len(gpt_sens) == half
        word_sets = [set(sen.split(' ')) for sen in gpt_sens]

        # For the prompt-infilled outputs, splice only the rewritten window
        # back into the original GPT sentences.
        replace_sen_list = []
        boundary = []
        for cand_text in sen_list[:half]:
            sens = [sen.text for sen in nlp(cand_text).sents]
            modified_word_sets = [set(sen.split(' ')) for sen in sens]
            l1, r1, l2, r2 = sen_align(word_sets, modified_word_sets)
            boundary.append((l1, r1, l2, r2))
            if l1 == -1:
                replace_sen_list.append(cand_text)
            else:
                replace_sen_list.append(' '.join(gpt_sens[:l1] + [' '.join(sens[l2:r2])] + gpt_sens[r1:]))

        # Layout: [0, old_L) BART candidates, old_L the GPT output,
        # old_L + 1 + j the assist text paired with candidate j.
        sen_list = replace_sen_list + sen_list[half:]
        old_L = len(sen_list)
        sen_list = sen_list + [gpt_out] + assist
        real_log_Loss = lm_mean_token_loss(lm, lm_tokenizer, sen_list, 5, 1024)
        p = int(np.argmin(real_log_Loss[:old_L]))

        content = [[text, str(loss)] for text, loss in zip(sen_list, real_log_Loss)]
        scored[k] = {'path': path_text, 'prompt': prompt, 'in': gpt_in, 's': text_s,
                     'o': text_o, 'out': content, 'bound': boundary}
        # Fall back to the assist text when the best BART candidate is worse than
        # the GPT output and its own assist text is better.
        if real_log_Loss[p] > real_log_Loss[old_L] and real_log_Loss[p] > real_log_Loss[p + 1 + old_L]:
            p = p + 1 + old_L
        ret[k] = {'prompt': prompt, 'in': gpt_in, 'out': sen_list[p]}

    os.makedirs('generate_abstract/bioBART', exist_ok=True)
    with open(f'generate_abstract/{prefix}{args.ratio}_bioBART_finetune.json', 'w') as fl:
        json.dump(ret, fl, indent=4)
    with open(f'generate_abstract/bioBART/{prefix}{args.ratio}_scored.json', 'w') as fl:
        json.dump(scored, fl, indent=4)
else:
    raise Exception('Wrong mode !!')