import os
import argparse
import logging

import jsonlines
from tqdm import tqdm
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
    matthews_corrcoef,
)

from backend.section_infer_helper.local_llm_helper import local_llm_helper
from backend.section_infer_helper.online_llm_helper import online_llm_helper

INCLUDE_MSG = "no"
BATCH_SIZE = 4

# Allow overriding the default via an environment variable.
INCLUDE_MSG = os.environ.get("INCLUDE_MSG", INCLUDE_MSG)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)


def main(args):
    # Pick the inference backend based on the --type argument.
    if args.type == "local":
        helper = local_llm_helper
        helper.load_model(args.model, args.peft)
    elif args.type == "online":
        helper = online_llm_helper
        helper.load_model(args.model, args.url, args.key)

    labels = []
    predicts = []
    input_prompts = []
    output_text = []
    output_probs = []
    inputs = []

    with jsonlines.open(args.data, "r") as reader:
        test_data = list(reader)

    # Resume support: reload any results already written to the output file so
    # finished (commit_id, file_name) pairs are not re-inferred.
    finished_item = []
    if os.path.exists(args.output):
        with jsonlines.open(args.output, "r") as reader:
            for i, item in enumerate(reader):
                finished_item.append((item["commit_id"], item["file_name"]))
                test_data[i] = item
                for section in item["sections"]:
                    labels.append(section["related"])
                    predicts.append(section["predict"])
                    input_prompts.append(section["input_prompt"])
                    output_text.append(section["output_text"])
                    output_probs.append(section["conf"])

    # Build the inference inputs for all unfinished items.
    for item in tqdm(test_data):
        file_name = item["file_name"]
        patch = item["patch"]
        if (item["commit_id"], item["file_name"]) in finished_item:
            logging.info(f"Skip {item['commit_id']}, {item['file_name']}")
            continue
        commit_message = item["commit_message"] if INCLUDE_MSG == "yes" else ""
        for section in item["sections"]:
            section_content = section["section"]
            inputs.append(helper.InputData(file_name, patch, section_content, commit_message))
            labels.append(section["related"])

    # Dataset-specific sanity check: the test set contains 4088 sections in total.
    assert len(labels) == 4088, f"Get {len(labels)} labels"

    try:
        this_input_prompts, this_output_text, this_output_probs = helper.do_infer(inputs, BATCH_SIZE)
    except Exception as e:
        logging.error(f"Inference failed: {e}")
        this_input_prompts, this_output_text, this_output_probs = [], [], []

    input_prompts.extend(this_input_prompts)
    output_text.extend(this_output_text)
    output_probs.extend(this_output_probs)

    # Derive predictions only from the newly generated outputs; predictions for
    # resumed items were already loaded above.
    for result in this_output_text:
        predicts.append("yes" in result.lower())

    # accuracy = accuracy_score(labels, predicts)
    # precision = precision_score(labels, predicts)
    # recall = recall_score(labels, predicts)
    # f1 = f1_score(labels, predicts)
    # mcc = matthews_corrcoef(labels, predicts)
    # tn, fp, fn, tp = confusion_matrix(labels, predicts).ravel()
    # fpr = fp / (fp + tn + 1e-8)
    # print("=" * 20)
    # print(f"Accuracy: {accuracy}")
    # print(f"Precision: {precision}")
    # print(f"Recall: {recall}")
    # print(f"F1: {f1}")
    # print(f"MCC: {mcc}")
    # print(f"FPR: {fpr}")
    # print("=" * 20)

    # Write results back, one JSON line per (commit, file) item. The per-section
    # lists are consumed in order, so the output order must match the input order.
    with jsonlines.open(args.output, "w") as writer:
        for item in test_data:
            if len(output_text) < len(item["sections"]):
                logging.info("Not enough output")
                break
            for section in item["sections"]:
                section["input_prompt"] = input_prompts.pop(0)
                section["output_text"] = output_text.pop(0)
                section["predict"] = bool(predicts.pop(0))
                section["conf"] = output_probs.pop(0)
            writer.write(item)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--data", type=str, required=True, help="Path to the data file")
    parser.add_argument("-t", "--type", type=str, required=True, help="Type of the model", choices=["local", "online"])
    parser.add_argument("-m", "--model", type=str, required=True, help="Model name or path")
    parser.add_argument("-p", "--peft", type=str, help="Path to the PEFT file")
parser.add_argument("-u", "--url", type=str, help="URL of the model") parser.add_argument("-k", "--key", type=str, help="API key") parser.add_argument("-o", "--output", type=str, required=True, help="Path to the output file") args = parser.parse_args() main(args)