"""
Original work:
https://github.com/sangHa0411/CloneDetection/blob/main/utils/preprocessor.py
Copyright (c) 2022 Sangha Park(sangha110495), Young Jin Ahn(snoop2head)
All credits to the original authors.
"""
import re

import torch
from transformers import Pipeline


class FunctionPreprocessor:
    """Remove function definitions that are declared but never called."""

    def get_function(self, code):
        # Collect the names of all top-level function definitions.
        results = []
        fn_list = re.findall(r"\ndef [a-zA-Z0-9_]+\(", code)
        for fn in fn_list:
            results.append(fn[4:-1].strip())
        return results

    def determine_function(self, code, function_name):
        # A name whose only occurrence is its own "def" line is never called.
        num = len(re.findall("[^a-zA-Z]" + function_name + "[^a-zA-Z]", code))
        return num > 1

    def delete_function(self, code, name):
        # Cut from the "def" line up to the next line that starts at column 0.
        start_id, _ = re.search("def " + name, code).span()
        ptr = start_id
        while ptr < len(code) - 1:
            if code[ptr] == "\n" and re.search("[a-zA-Z]", code[ptr + 1]) is not None:
                break
            ptr += 1
        if ptr != len(code) - 1:
            end_id = ptr
            code = code[:start_id] + code[end_id:]
        return code

    def preprocess(self, code):
        code = "\n" + code
        fn_list = self.get_function(code)
        if len(fn_list) == 0:
            return code
        for fn in fn_list:
            if not self.determine_function(code, fn):
                code = self.delete_function(code, fn)
        return code


class AnnotationPreprocessor:
    """Strip comments, docstring-style blocks, and import lines from code."""

    def search(self, sen_list, string):
        # Index of the first line containing `string`, or -1 if absent.
        for i, sen in enumerate(sen_list):
            if string in sen:
                return i
        return -1

    def delete_annotation_block(self, code, string):
        # Drop the lines between a pair of delimiters (inclusive); if the
        # closing delimiter is missing, drop only the opening line.
        sens = code.split("\n")
        start_id = self.search(sens, string)
        end_id = self.search(sens[start_id + 1 :], string)
        if end_id != -1:
            end_id += start_id + 1
            code = sens[:start_id] + sens[end_id + 1 :]
        else:
            code = sens[:start_id] + sens[start_id + 1 :]
        return "\n".join(code)

    def delete_block(self, code, string):
        while string in code:
            code = self.delete_annotation_block(code, string)
        return code

    def delete_annotation(self, code):
        # Truncate each line at the first "#".
        sens_processed = []
        for sen in code.split("\n"):
            if "#" in sen:
                sen = sen[: sen.index("#")]
            sens_processed.append(sen)
        return "\n".join(sens_processed)

    def delete_import(self, code):
        # Drop every line that mentions "import".
        sens_processed = []
        for sen in code.split("\n"):
            if "import" not in sen:
                sens_processed.append(sen)
        return "\n".join(sens_processed)

    def preprocess(self, code):
        code = self.delete_block(code, '"""')
        code = self.delete_block(code, "'''")
        code = self.delete_annotation(code)
        code = self.delete_import(code)
        return re.sub(r"\s+", " ", code).strip()
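

# Example (sketch) of the annotation pass: AnnotationPreprocessor().preprocess(
#     "# add two numbers\nimport math\ndef add(a, b):\n    return a + b"
# ) returns "def add(a, b): return a + b" -- comments, import lines, and
# redundant whitespace are stripped before the code is tokenized.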


def preprocessor(code, instance):
    # Fall back to the original code if preprocessing strips everything away.
    processed_code = instance.preprocess(code)
    return processed_code if processed_code.strip() else code


def token_to_inputs(feature):
    # Wrap each tokenizer output in a batch dimension of size 1.
    inputs = {}
    for k, v in feature.items():
        inputs[k] = torch.tensor(v).unsqueeze(0)
    return inputs


class CloneDetectionPipeline(Pipeline):
    fn_preprocessor = FunctionPreprocessor()
    an_preprocessor = AnnotationPreprocessor()

    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs):
        code1 = inputs[0]
        code2 = inputs[1]
        if code1.strip() == "" or code2.strip() == "":
            # Short-circuit on empty input: two empty snippets are trivially
            # clones; an empty and a non-empty snippet are not.
            true_prob = float(code1.strip() == code2.strip())
            return {"skip": True, "output": {False: 1 - true_prob, True: true_prob}}

        code1 = preprocessor(
            preprocessor(code1, self.fn_preprocessor), self.an_preprocessor
        )
        code2 = preprocessor(
            preprocessor(code2, self.fn_preprocessor), self.an_preprocessor
        )

        # Encode both input orders so the final score is symmetric.
        feature1 = self.tokenizer(
            code1, code2, max_length=512, return_token_type_ids=False, truncation=True
        )
        feature2 = self.tokenizer(
            code2, code1, max_length=512, return_token_type_ids=False, truncation=True
        )

        return {
            "inputs1": token_to_inputs(feature1),
            "inputs2": token_to_inputs(feature2),
        }

    def _forward(self, model_inputs):
        if model_inputs.get("skip", False):
            return model_inputs
        inputs1 = model_inputs["inputs1"]
        inputs2 = model_inputs["inputs2"]
        # Average the logits of the two input orders.
        logits1 = self.model(**inputs1).logits[0]
        logits2 = self.model(**inputs2).logits[0]
        logits = (logits1 + logits2) / 2
        return {"logits": logits}

    def postprocess(self, model_outputs):
        if model_outputs.get("skip", False):
            return model_outputs["output"]
        probs = model_outputs["logits"].softmax(-1).tolist()
        return {False: probs[0], True: probs[1]}
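

# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumptions: the pipeline wraps a fine-tuned
# sequence-classification checkpoint whose label order is (not-clone, clone);
# the checkpoint name below is a placeholder, not part of this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    checkpoint = "your-username/clone-detection-model"  # hypothetical name
    pipe = CloneDetectionPipeline(
        model=AutoModelForSequenceClassification.from_pretrained(checkpoint),
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
    )

    snippet_a = "def add(a, b):\n    return a + b"
    snippet_b = "def plus(x, y):\n    return x + y"
    # A pair is passed as a tuple so the base Pipeline treats it as one input.
    print(pipe((snippet_a, snippet_b)))  # e.g. {False: 0.03, True: 0.97}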