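"""Run yes/no LLM inference over the sections of commit patches.

Reads a JSON Lines file of commits whose patches are split into sections,
asks a local or online LLM helper whether each section is related, and
writes the annotated items back out. If the output file already exists,
finished items are reloaded so that inference can resume where it stopped.
"""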
import os
import argparse
import logging

import jsonlines
from tqdm import tqdm
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, confusion_matrix, matthews_corrcoef)

from backend.section_infer_helper.local_llm_helper import local_llm_helper
from backend.section_infer_helper.online_llm_helper import online_llm_helper
INCLUDE_MSG = "no" | |
BATCH_SIZE = 4 | |
# overwrite by environment variables | |
INCLUDE_MSG = os.environ.get("INCLUDE_MSG", INCLUDE_MSG) | |
logging.basicConfig(level=logging.INFO, | |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') | |
def main(args):
    if args.type == "local":
        helper = local_llm_helper
        helper.load_model(args.model, args.peft)
    elif args.type == "online":
        helper = online_llm_helper
        helper.load_model(args.model, args.url, args.key)

    labels = []
    predicts = []
    input_prompts = []
    output_text = []
    output_probs = []
    inputs = []

    with jsonlines.open(args.data, "r") as reader:
        test_data = list(reader)
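    # Illustrative shape of one input record, inferred from the field
    # accesses below (all values are made up):
    # {
    #     "commit_id": "abc123",
    #     "file_name": "src/foo.c",
    #     "patch": "@@ -1,4 +1,6 @@ ...",
    #     "commit_message": "Fix buffer overflow in foo()",
    #     "sections": [{"section": "...", "related": true}, ...]
    # }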
    # Resume support: if the output file already exists, reload the finished
    # items so their sections are not re-inferred. This assumes the output
    # file preserves the order of the input data.
    finished_item = []
    if os.path.exists(args.output):
        with jsonlines.open(args.output, "r") as reader:
            for i, item in enumerate(reader):
                finished_item.append((item["commit_id"], item["file_name"]))
                test_data[i] = item
                for section in item["sections"]:
                    labels.append(section["related"])
                    predicts.append(section["predict"])
                    input_prompts.append(section["input_prompt"])
                    output_text.append(section["output_text"])
                    output_probs.append(section["conf"])
    for item in test_data:
        file_name = item["file_name"]
        patch = item["patch"]
        if (item["commit_id"], item["file_name"]) in finished_item:
            print(f"Skip {item['commit_id']}, {item['file_name']}")
            continue
        commit_message = item["commit_message"] if INCLUDE_MSG == "yes" else ""
        for section in item["sections"]:
            section_content = section["section"]
            inputs.append(helper.InputData(file_name, patch, section_content, commit_message))
            labels.append(section["related"])

    # Sanity check specific to this evaluation set, which contains 4088 sections.
    assert len(labels) == 4088, f"Got {len(labels)} labels"
    # Pre-initialize so that a failed inference still lets the already-loaded
    # results be written out below.
    this_input_prompts, this_output_text, this_output_probs = [], [], []
    try:
        this_input_prompts, this_output_text, this_output_probs = helper.do_infer(inputs, BATCH_SIZE)
    except Exception as e:
        logging.error(f"Inference failed: {e}")
    input_prompts.extend(this_input_prompts)
    output_text.extend(this_output_text)
    output_probs.extend(this_output_probs)

    # A section is predicted as related if the model's answer contains "yes".
    # Iterate only over the new outputs: predictions for resumed items were
    # already appended during the reload above.
    for result in this_output_text:
        predicts.append("yes" in result.lower())
    # accuracy = accuracy_score(labels, predicts)
    # precision = precision_score(labels, predicts)
    # recall = recall_score(labels, predicts)
    # f1 = f1_score(labels, predicts)
    # mcc = matthews_corrcoef(labels, predicts)
    # tn, fp, fn, tp = confusion_matrix(labels, predicts).ravel()  # sklearn order: tn, fp, fn, tp
    # fpr = fp / (fp + tn + 1e-8)
    # print("=" * 20)
    # print(f"Accuracy: {accuracy}")
    # print(f"Precision: {precision}")
    # print(f"Recall: {recall}")
    # print(f"F1: {f1}")
    # print(f"MCC: {mcc}")
    # print(f"FPR: {fpr}")
    # print("=" * 20)
    with jsonlines.open(args.output, "w") as writer:
        for item in test_data:
            if len(output_text) < len(item["sections"]):
                logging.info("Not enough output")
                break
            for section in item["sections"]:
                section["input_prompt"] = input_prompts.pop(0)
                section["output_text"] = output_text.pop(0)
                section["predict"] = bool(predicts.pop(0))
                section["conf"] = output_probs.pop(0)
            writer.write(item)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--data", type=str, required=True, help="Path to the data file")
    parser.add_argument("-t", "--type", type=str, required=True, help="Type of the model", choices=["local", "online"])
    parser.add_argument("-m", "--model", type=str, required=True, help="Model name or path")
    parser.add_argument("-p", "--peft", type=str, help="Path to the PEFT adapter (local models only)")
    parser.add_argument("-u", "--url", type=str, help="URL of the model endpoint (online models only)")
    parser.add_argument("-k", "--key", type=str, help="API key (online models only)")
    parser.add_argument("-o", "--output", type=str, required=True, help="Path to the output file")
    args = parser.parse_args()
    main(args)
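# Example invocations (script name, paths, and model names are illustrative):
#   python section_infer.py -d data/test.jsonl -t local -m ./models/base -p ./peft_adapter -o out.jsonl
#   INCLUDE_MSG=yes python section_infer.py -d data/test.jsonl -t online -m some-chat-model \
#       -u https://api.example.com/v1 -k $API_KEY -o out.jsonl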