"""Gradio front-end for PATCHOULI (PATCH Observing and Untangling Engine).

The user pastes a git diff (plus an optional patch message). The app then runs
``backend.inference.section_infer`` over it and shows:

* per-file section predictions, highlighted by predicted category and confidence;
* an overall vul-fixing / non-vul-fixing verdict for the whole patch;
* an optional vulnerability-type classification via ``cwe_infer``.
"""
import os
import json
import logging

import gradio as gr

from backend.inference import (
    section_infer,
    cwe_infer,
    PREDEF_MODEL_MAP,
    LOCAL_MODEL_PEFT_MAP,
    PREDEF_CWE_MODEL,
)

APP_TITLE = "PATCHOULI"

# NOTE(review): the markup that originally surrounded these title fragments
# appears to have been lost -- the visible source had bare line breaks inside
# single-quoted string literals, which does not parse. Only the visible text
# is kept; restore the intended HTML styling if it is available elsewhere.
STYLE_APP_TITLE = (
    "\n"
    "PATCH "
    "Observing "
    "and "
    "Untangling "
    "Engine"
    "\n"
)

# from 0.00 to 1.00, 41 colors
NONVUL_GRADIENT_COLORS = [
    "#d3f8d6", "#d3f8d6", "#d0f8d3", "#ccf7d0", "#c9f7cd", "#c6f6cb", "#c2f6c8",
    "#bff5c5", "#bcf5c2", "#b8f4bf", "#b5f4bc", "#b1f3ba", "#aef2b7", "#aaf2b4",
    "#a7f1b1", "#a3f1ae", "#9ff0ab", "#9cf0a9", "#98efa6", "#94efa3", "#90eea0",
    "#8ced9d", "#88ed9a", "#84ec98", "#80ec95", "#7ceb92", "#78ea8f", "#73ea8c",
    "#6fe989", "#6ae886", "#65e883", "#60e781", "#5ae67e", "#55e67b", "#4fe578",
    "#48e475", "#41e472", "#39e36f", "#30e26c", "#25e269", "#14e166",
]
# from 0.00 to 1.00, 41 colors
# NOTE(review): the first entry equals NONVUL_GRADIENT_COLORS[0] (light green)
# while the rest are reds -- presumably a deliberate "neutral at ~0 confidence"
# color; confirm it is not a copy-paste slip.
VUL_GRADIENT_COLORS = [
    "#d3f8d6", "#fdcfc9", "#fdccc5", "#fcc9c2", "#fcc5bf", "#fcc2bb", "#fbbfb8",
    "#fbbcb4", "#fab9b1", "#fab5ad", "#f9b2aa", "#f8afa7", "#f8aca3", "#f7a8a0",
    "#f7a59c", "#f6a299", "#f59f96", "#f59c92", "#f4988f", "#f3958c", "#f29288",
    "#f18e85", "#f18b82", "#f0887f", "#ef847c", "#ee8178", "#ed7e75", "#ec7a72",
    "#eb776f", "#ea736c", "#e97068", "#e86c65", "#e76962", "#e6655f", "#e5615c",
    "#e45e59", "#e35a56", "#e25653", "#e05250", "#df4e4d", "#de4a4a",
]

# Neutral color for highlight labels that have no gradient entry, e.g. rows
# whose prediction is -1 ("error").
ERROR_COLOR = "#d9d9d9"

# Maps a section's ``predict`` value to the category prefix of its label.
PREDICT_LABELS = {-1: "error", 0: "non-vul-fixing", 1: "vul-fixing"}

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)


def generate_color_map():
    """Build the label -> color map used by the highlighted section results.

    Keys look like ``"non-vul-fixing: 0.37"`` or ``"vul-fixing: 0.37"``. The
    confidence runs from ``0.00`` to ``1.00`` in steps of 0.01, formatted with
    the same ``0.2f`` spec used when results are rendered. Each of the 101
    steps maps onto one of the 41 gradient colors.

    Returns:
        dict[str, str]: label text -> hex color string.
    """
    color_map = {}
    for step in range(101):
        conf_text = f"{step / 100:0.2f}"
        # Exact integer form of int(step * 0.4): maps 0..100 onto indices 0..40.
        color_idx = step * 2 // 5
        color_map[f"non-vul-fixing: {conf_text}"] = NONVUL_GRADIENT_COLORS[color_idx]
        color_map[f"vul-fixing: {conf_text}"] = VUL_GRADIENT_COLORS[color_idx]
    return color_map


def _peft_choices(base_model):
    """Return the PEFT adapters listed for *base_model*, or [] if none are known.

    The base-model dropdown allows custom values, which may be missing from
    LOCAL_MODEL_PEFT_MAP.
    """
    return LOCAL_MODEL_PEFT_MAP.get(base_model) or []


def _example_loader(example):
    """Return a zero-argument callback that yields *example* (bound eagerly)."""
    return lambda: example


def on_submit(diff_code, patch_message, cwe_model, section_model_type, progress=gr.Progress(track_tqdm=True), *model_config):
    """Classify a diff section by section, then optionally classify its CWE.

    Both "Run" buttons are wired to this function. Gradio injects a progress
    tracker into ``progress`` because its default is a ``gr.Progress``. The
    remaining button inputs therefore land in ``model_config`` and are
    forwarded unchanged to ``section_infer``:
    (model, PEFT adapter) for the local tab, or (model, API URL, API key) for
    the online tab.

    Args:
        diff_code: git diff text from the code box (may be empty or None).
        patch_message: optional patch/commit message.
        cwe_model: vulnerability-type classifier name; falsy means "none".
        section_model_type: label of the selected model-type tab.
        progress: progress tracker supplied by Gradio.
        *model_config: tab-specific model settings passed to section_infer.

    Returns:
        A 3-tuple ``(gr.Label, section_results, cwe_result)``. All three are
        ``gr.skip()`` when no diff was entered.

    Raises:
        gr.Error: if section inference fails.
    """
    # Nothing to analyse: leave every output untouched. The truthiness check
    # also covers a None value from an empty code box.
    if not diff_code or not diff_code.strip():
        return gr.skip(), gr.skip(), gr.skip()

    try:
        section_results = section_infer(diff_code, patch_message, section_model_type, *model_config)
    except Exception as e:
        # UI boundary: keep the traceback in the server log and show a short
        # message in the browser.
        logger.exception("Section inference failed")
        raise gr.Error(f"Error: {e}") from e

    # The whole patch counts as vul-fixing if any section is predicted as 1.
    is_vul_fixing = any(
        item["predict"] == 1
        for file_results in section_results.values()
        for item in file_results
    )
    patch_category_label = gr.Label(
        value="Vul-fixing patch" if is_vul_fixing else "Non-vul-fixing patch",
        color="#de4a4a" if is_vul_fixing else "#14e166",
    )

    if not cwe_model:
        cwe_cls_result = "No model selected"
    elif not is_vul_fixing:
        cwe_cls_result = "No vulnerability found"
    else:
        cwe_cls_result = cwe_infer(diff_code, patch_message, cwe_model)
    return patch_category_label, section_results, cwe_cls_result


with gr.Blocks(title=APP_TITLE, fill_width=True) as demo:
    # Per-file section predictions produced by on_submit; drives display_result.
    section_results_state = gr.State({})
    # Vulnerability-type classification produced by on_submit.
    cls_results_state = gr.State({})
    title = gr.HTML(STYLE_APP_TITLE)

    with gr.Row() as main_block:
        with gr.Column(scale=1) as input_block:
            diff_codebox = gr.Code(label="Input git diff here", max_lines=10)
            with gr.Accordion("Patch message (optional)", open=False):
                message_textbox = gr.Textbox(
                    label="Patch message",
                    placeholder="Enter patch message here",
                    container=False,
                    lines=2,
                    max_lines=5,
                )
            cwe_model_selector = gr.Dropdown(
                PREDEF_CWE_MODEL,
                label="Select vulnerability type classifier",
                allow_custom_value=True,
            )

            with gr.Tabs(selected=0) as model_type_tabs:
                MODEL_TYPE_NAMES = list(PREDEF_MODEL_MAP.keys())
                with gr.Tab(MODEL_TYPE_NAMES[0]) as local_llm_tab:
                    local_model_selector = gr.Dropdown(
                        PREDEF_MODEL_MAP[MODEL_TYPE_NAMES[0]],
                        label="Select model",
                        allow_custom_value=True,
                    )
                    local_peft_selector = gr.Dropdown(
                        _peft_choices(local_model_selector.value),
                        label="Select PEFT model (optional)",
                        allow_custom_value=True,
                    )
                    local_submit_btn = gr.Button("Run", variant="primary")
                with gr.Tab(MODEL_TYPE_NAMES[1]) as online_llm_tab:
                    online_model_selector = gr.Dropdown(
                        PREDEF_MODEL_MAP[MODEL_TYPE_NAMES[1]],
                        label="Select model",
                        allow_custom_value=True,
                    )
                    online_api_url_textbox = gr.Textbox(label="API URL")
                    # NOTE(review): pre-filling from ONLINE_API_KEY sends the
                    # server's key to every visitor's browser (type="password"
                    # only masks it on screen). Confirm this is intended for
                    # the deployment.
                    online_api_key_textbox = gr.Textbox(
                        label="API Key",
                        placeholder="We won't store your API key",
                        value=os.getenv("ONLINE_API_KEY"),
                        type="password",
                    )
                    online_submit_btn = gr.Button("Run", variant="primary")
            # Label of the active model-type tab, passed to on_submit.
            section_model_type = gr.State(model_type_tabs.children[0].label)

            with gr.Accordion("Load examples", open=False):
                # NOTE: this relative path resolves against the process's
                # working directory, not this file's location.
                with open("./backend/examples.json", "r", encoding="utf-8") as f:
                    examples = json.load(f)
                # Up to three example buttons. Each callback binds its own
                # example eagerly (a bare lambda in a loop would late-bind).
                for idx, example in enumerate(examples[:3], start=1):
                    gr.Button(f"Load example {idx}", size='sm').click(
                        _example_loader(example),
                        outputs=[diff_codebox, message_textbox],
                    )

        with gr.Column(scale=2) as section_result_block:
            @gr.render(inputs=section_results_state, triggers=[section_results_state.change, demo.load])
            def display_result(section_results):
                """Render one tab per file with its sections highlighted by prediction."""
                if not section_results:
                    with gr.Tab("File tabs"):
                        gr.Markdown("No results")
                    return

                full_color_map = generate_color_map()
                for file_name, file_results in section_results.items():
                    with gr.Tab(file_name):
                        highlighted_results = []
                        this_color_map = {}
                        for item in file_results:
                            category = PREDICT_LABELS.get(item["predict"], "error")
                            text_label = f"{category}: {item['conf']:0.2f}"
                            # "error" labels (and any out-of-range confidence)
                            # have no gradient entry. Use a neutral color
                            # instead of raising KeyError.
                            this_color_map[text_label] = full_color_map.get(text_label, ERROR_COLOR)
                            highlighted_results.append((item["section"], text_label))
                        gr.HighlightedText(
                            highlighted_results,
                            label="Results",
                            color_map=this_color_map,
                        )

        with gr.Column(scale=1) as result_block:
            patch_category_label = gr.Label(value="No results", label="Result of the whole patch")

            def update_vul_type_label(cls_results):
                """Show the current vulnerability-type result in a Label."""
                return gr.Label(cls_results)

            # The value is a function of cls_results_state, so Gradio
            # recomputes it when that state changes.
            vul_type_label = gr.Label(
                update_vul_type_label,
                label="Possible fixed vulnerability type",
                inputs=[cls_results_state],
            )

    def update_model_type_state(evt: gr.SelectData):
        """Record the selected model-type tab (its label) for on_submit."""
        return evt.value

    model_type_tabs.select(update_model_type_state, outputs=[section_model_type])

    def update_support_peft(base_model):
        """Refresh the PEFT dropdown for a newly chosen local base model.

        A custom base model with no known adapters yields an empty choice list
        and no selection, instead of raising KeyError/IndexError.
        """
        peft_choices = _peft_choices(base_model)
        return gr.Dropdown(peft_choices, value=peft_choices[0] if peft_choices else None)

    local_model_selector.change(update_support_peft, inputs=[local_model_selector], outputs=[local_peft_selector])

    local_submit_btn.click(
        fn=on_submit,
        inputs=[diff_codebox, message_textbox, cwe_model_selector, section_model_type,
                local_model_selector, local_peft_selector],
        outputs=[patch_category_label, section_results_state, cls_results_state],
    )
    online_submit_btn.click(
        fn=on_submit,
        inputs=[diff_codebox, message_textbox, cwe_model_selector, section_model_type,
                online_model_selector, online_api_url_textbox, online_api_key_textbox],
        outputs=[patch_category_label, section_results_state, cls_results_state],
    )

demo.launch()