# patchouli / dataset_eval.py
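"""Batch-evaluate section-level relevance predictions over a JSONL dataset.

For every section of every (commit, file) patch in the input data, the script
queries a local or online LLM helper, then records the prompt, raw output,
boolean prediction ("yes" in the output), and confidence back into the output
JSONL file. Runs are resumable: items already present in the output file are
reloaded and skipped during inference.
"""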
import os
import jsonlines
import argparse
from tqdm import tqdm
import logging
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, matthews_corrcoef
from backend.section_infer_helper.local_llm_helper import local_llm_helper
from backend.section_infer_helper.online_llm_helper import online_llm_helper
INCLUDE_MSG = "no"
BATCH_SIZE = 4
# May be overridden by the INCLUDE_MSG environment variable ("yes" includes the
# commit message in the prompt).
INCLUDE_MSG = os.environ.get("INCLUDE_MSG", INCLUDE_MSG)

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
def main(args):
    if args.type == "local":
        helper = local_llm_helper
        helper.load_model(args.model, args.peft)
    elif args.type == "online":
        helper = online_llm_helper
        helper.load_model(args.model, args.url, args.key)

    labels = []
    predicts = []
    input_prompts = []
    output_text = []
    output_probs = []
    inputs = []

    with jsonlines.open(args.data, "r") as reader:
        test_data = list(reader)
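
    # Resume support: reload items already written to the output file so their
    # labels, predictions, prompts, and confidences are kept and the items are
    # skipped during inference below.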
    finished_item = []
    if os.path.exists(args.output):
        with jsonlines.open(args.output, "r") as reader:
            for i, item in enumerate(reader):
                finished_item.append((item["commit_id"], item["file_name"]))
                test_data[i] = item
                for section in item["sections"]:
                    labels.append(section["related"])
                    predicts.append(section["predict"])
                    input_prompts.append(section["input_prompt"])
                    output_text.append(section["output_text"])
                    output_probs.append(section["conf"])
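
    # Build inference inputs for every section of items not yet processed.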
    for item in test_data:
        file_name = item["file_name"]
        patch = item["patch"]
        if (item["commit_id"], item["file_name"]) in finished_item:
            print(f"Skip {item['commit_id']}, {item['file_name']}")
            continue
        commit_message = item["commit_message"] if INCLUDE_MSG == "yes" else ""
        for section in item["sections"]:
            section_content = section["section"]
            inputs.append(helper.InputData(file_name, patch, section_content, commit_message))
            labels.append(section["related"])

    # Dataset-specific sanity check on the expected number of labelled sections.
    assert len(labels) == 4088, f"Got {len(labels)} labels"
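
    # Run batched inference over all newly collected section inputs.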
    try:
        this_input_prompts, this_output_text, this_output_probs = helper.do_infer(inputs, BATCH_SIZE)
    except Exception as e:
        # Without inference results there is nothing left to do; surface the error.
        print(f"Error: {e}")
        raise
    input_prompts.extend(this_input_prompts)
    output_text.extend(this_output_text)
    output_probs.extend(this_output_probs)
    # Only the newly generated outputs need predictions; resumed items already
    # had theirs appended above.
    for result in this_output_text:
        predicts.append("yes" in result.lower())
    # accuracy = accuracy_score(labels, predicts)
    # precision = precision_score(labels, predicts)
    # recall = recall_score(labels, predicts)
    # f1 = f1_score(labels, predicts)
    # mcc = matthews_corrcoef(labels, predicts)
    # tn, fp, fn, tp = confusion_matrix(labels, predicts).ravel()
    # fpr = fp / (fp + tn + 1e-8)
    # print("=" * 20)
    # print(f"Accuracy: {accuracy}")
    # print(f"Precision: {precision}")
    # print(f"Recall: {recall}")
    # print(f"F1: {f1}")
    # print(f"MCC: {mcc}")
    # print(f"FPR: {fpr}")
    # print("=" * 20)
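
    # Write every item back out, with each section enriched with its prompt,
    # raw output, boolean prediction, and confidence.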
    with jsonlines.open(args.output, "w") as writer:
        for item in test_data:
            if len(output_text) < len(item["sections"]):
                logging.warning("Not enough outputs left to cover all sections; stopping early")
                break
            for section in item["sections"]:
                section["input_prompt"] = input_prompts.pop(0)
                section["output_text"] = output_text.pop(0)
                section["predict"] = bool(predicts.pop(0))
                section["conf"] = output_probs.pop(0)
            writer.write(item)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--data", type=str, required=True, help="Path to the data file")
    parser.add_argument("-t", "--type", type=str, required=True, help="Type of the model", choices=["local", "online"])
    parser.add_argument("-m", "--model", type=str, required=True, help="Model name or path")
    parser.add_argument("-p", "--peft", type=str, help="Path to the PEFT adapter (local models)")
    parser.add_argument("-u", "--url", type=str, help="Base URL of the online model API")
    parser.add_argument("-k", "--key", type=str, help="API key for the online model")
    parser.add_argument("-o", "--output", type=str, required=True, help="Path to the output file")
    args = parser.parse_args()
    main(args)
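
# Example invocations (model identifiers and paths below are illustrative, not
# taken from this repository):
#   INCLUDE_MSG=yes python dataset_eval.py -t local -m ./models/base -p ./peft_adapter \
#       -d test.jsonl -o results.jsonl
#   python dataset_eval.py -t online -m some-chat-model -u https://api.example.com/v1 \
#       -k $API_KEY -d test.jsonl -o results.jsonl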