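"""Run per-section LLM inference over a JSONL dataset of commit patches.

For each patch section the selected helper (local or online) produces an
output text, which is converted into a boolean prediction ("yes" in the
output). Results are written back out as JSONL; if the output file already
exists, previously processed (commit_id, file_name) pairs are reloaded and
skipped so an interrupted run can be resumed.
"""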
import os
import jsonlines
import argparse
from tqdm import tqdm
import logging
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, matthews_corrcoef

from backend.section_infer_helper.local_llm_helper import local_llm_helper
from backend.section_infer_helper.online_llm_helper import online_llm_helper


INCLUDE_MSG = "no"   # "yes" to include the commit message in the prompt
BATCH_SIZE = 4

# Allow INCLUDE_MSG to be overridden via an environment variable.
INCLUDE_MSG = os.environ.get("INCLUDE_MSG", INCLUDE_MSG)


logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')


def main(args):
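    # Select the inference backend and load the model (weights plus optional
    # PEFT adapter for local, endpoint URL plus API key for online).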
    if args.type == "local":
        helper = local_llm_helper
        helper.load_model(args.model, args.peft)
    elif args.type == "online":
        helper = online_llm_helper
        helper.load_model(args.model, args.url, args.key)

    # Accumulators covering both resumed and newly inferred sections.
    labels = []
    predicts = []
    input_prompts = []
    output_text = []
    output_probs = []
    
    inputs = []
    with jsonlines.open(args.data, "r") as reader:
        test_data = list(reader)
    
    # Resume support: reload any results already present in the output file
    # and remember which (commit_id, file_name) pairs are finished.
    finished_item = []
    if os.path.exists(args.output):
        with jsonlines.open(args.output, "r") as reader:
            for i, item in enumerate(reader):
                finished_item.append((item["commit_id"], item["file_name"]))
                test_data[i] = item
                for section in item["sections"]:
                    labels.append(section["related"])
                    predicts.append(section["predict"])
                    input_prompts.append(section["input_prompt"])
                    output_text.append(section["output_text"])
                    output_probs.append(section["conf"])
            
    
    # Build inference inputs for every section of each unfinished item.
    for item in test_data:
        file_name = item["file_name"]
        patch = item["patch"]
        if (item["commit_id"],item["file_name"]) in finished_item:
            print(f"Skip {item['commit_id']}, {item['file_name']}")
            continue
        commit_message = item["commit_message"] if INCLUDE_MSG == "yes" else ""
        for section in item["sections"]:
            section_content = section["section"]
            inputs.append(helper.InputData(file_name, patch, section_content, commit_message))
            labels.append(section["related"])
    
    # Dataset-specific sanity check: the test split is expected to contain 4,088 labelled sections.
    assert len(labels) == 4088, f"Got {len(labels)} labels"
    
    # Run batched inference. Initialise the result lists first so that a
    # failure still falls through and the already-finished items get rewritten.
    this_input_prompts, this_output_text, this_output_probs = [], [], []
    try:
        this_input_prompts, this_output_text, this_output_probs = helper.do_infer(inputs, BATCH_SIZE)
    except Exception as e:
        print(f"Error: {e}")

    input_prompts.extend(this_input_prompts)
    output_text.extend(this_output_text)
    output_probs.extend(this_output_probs)
        
    # Convert only the newly generated outputs into boolean predictions;
    # predictions for resumed items were already appended above.
    for result in this_output_text:
        predicts.append("yes" in result.lower())

    # accuracy = accuracy_score(labels, predicts)
    # precision = precision_score(labels, predicts)
    # recall = recall_score(labels, predicts)
    # f1 = f1_score(labels, predicts)
    # mcc = matthews_corrcoef(labels, predicts)
    # tn, fp, fn, tp = confusion_matrix(labels, predicts).ravel()
    # fpr = fp / (fp + tn + 1e-8)
    # print("=" * 20)
    # print(f"Accuracy: {accuracy}")
    # print(f"Precision: {precision}")
    # print(f"Recall: {recall}")
    # print(f"F1: {f1}")
    # print(f"MCC: {mcc}")
    # print(f"FPR: {fpr}")
    # print("=" * 20)
    
    # Rewrite the output file from scratch: resumed items re-emit their loaded
    # results, new items consume the freshly generated ones in order.
    with jsonlines.open(args.output, "w") as writer:
        for item in test_data:
            if len(output_text) < len(item["sections"]):
                logging.warning("Not enough outputs for the remaining items; stopping early")
                break
            for section in item["sections"]:
                section["input_prompt"] = input_prompts.pop(0)
                section["output_text"] = output_text.pop(0)
                section["predict"] = bool(predicts.pop(0))
                section["conf"] = output_probs.pop(0)
            writer.write(item)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--data", type=str, required=True, help="Path to the data file")
    parser.add_argument("-t", "--type", type=str, required=True, help="Type of the model", choices=["local", "online"])
    parser.add_argument("-m", "--model", type=str, required=True)
    parser.add_argument("-p", "--peft", type=str, help="Path to the PEFT file")
    parser.add_argument("-u", "--url", type=str, help="URL of the model")
    parser.add_argument("-k", "--key", type=str, help="API key")
    parser.add_argument("-o", "--output", type=str, required=True, help="Path to the output file")
    args = parser.parse_args()
    main(args)
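
# Example invocation (script name, paths, and model identifier below are
# illustrative only, not taken from the repository):
#   INCLUDE_MSG=yes python this_script.py \
#       -d data/test.jsonl \
#       -t local \
#       -m path/to/base_model \
#       -p path/to/peft_adapter \
#       -o results/predictions.jsonl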