import os
import json
import logging
import gradio as gr
from backend.inference import section_infer, cwe_infer, PREDEF_MODEL_MAP, LOCAL_MODEL_PEFT_MAP, PREDEF_CWE_MODEL
APP_TITLE = "PATCHOULI"
# HTML heading rendered at the top of the page via gr.HTML.
# NOTE(review): the markup of this literal appears to have been lost (as it
# stood, the string opened on one line and closed on the next, which is a
# SyntaxError). This reconstruction keeps the visible words and wraps them in
# a minimal centred heading; restore the original styling if it is available.
STYLE_APP_TITLE = (
    '<h1 style="text-align: center;">'
    'PATCH '
    'Observing '
    'and '
    'Untangling '
    'Engine'
    '</h1>'
)
# Green shades for "non-vul-fixing" labels: 41 colours spanning confidence
# 0.00 (index 0, lightest) to 1.00 (index 40, strongest). The lightest shade
# is listed twice so the palette has 41 entries.
NONVUL_GRADIENT_COLORS = [
    "#d3f8d6",
    "#d3f8d6", "#d0f8d3", "#ccf7d0", "#c9f7cd", "#c6f6cb", "#c2f6c8", "#bff5c5", "#bcf5c2", "#b8f4bf", "#b5f4bc",
    "#b1f3ba", "#aef2b7", "#aaf2b4", "#a7f1b1", "#a3f1ae", "#9ff0ab", "#9cf0a9", "#98efa6", "#94efa3", "#90eea0",
    "#8ced9d", "#88ed9a", "#84ec98", "#80ec95", "#7ceb92", "#78ea8f", "#73ea8c", "#6fe989", "#6ae886", "#65e883",
    "#60e781", "#5ae67e", "#55e67b", "#4fe578", "#48e475", "#41e472", "#39e36f", "#30e26c", "#25e269", "#14e166",
]
# Red shades for "vul-fixing" labels, same 41-entry layout as above.
# Index 0 repeats the lightest red (mirroring the green palette) so that even
# very low-confidence vul-fixing sections are never drawn in a green shade.
VUL_GRADIENT_COLORS = [
    "#fdcfc9",
    "#fdcfc9", "#fdccc5", "#fcc9c2", "#fcc5bf", "#fcc2bb", "#fbbfb8", "#fbbcb4", "#fab9b1", "#fab5ad", "#f9b2aa",
    "#f8afa7", "#f8aca3", "#f7a8a0", "#f7a59c", "#f6a299", "#f59f96", "#f59c92", "#f4988f", "#f3958c", "#f29288",
    "#f18e85", "#f18b82", "#f0887f", "#ef847c", "#ee8178", "#ed7e75", "#ec7a72", "#eb776f", "#ea736c", "#e97068",
    "#e86c65", "#e76962", "#e6655f", "#e5615c", "#e45e59", "#e35a56", "#e25653", "#e05250", "#df4e4d", "#de4a4a",
]
# Root logger: INFO and above, formatted as "time - logger name - level - message".
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Only warnings and errors from the "httpx" logger get through
# (presumably to hide its per-request INFO chatter).
logging.getLogger("httpx").setLevel(logging.WARNING)
def generate_color_map(nonvul_colors=None, vul_colors=None, error_color="#d9d9d9"):
    """Build the label -> colour map used by the section HighlightedText.

    Keys have the form ``"<kind>: <confidence>"``, where the confidence is
    formatted to two decimals (``0.00`` .. ``1.00``) exactly like the labels
    built in ``display_result``, and ``kind`` is ``non-vul-fixing``,
    ``vul-fixing`` or ``error``.

    Args:
        nonvul_colors: Palette for non-vul-fixing labels, lightest first.
            Defaults to ``NONVUL_GRADIENT_COLORS``.
        vul_colors: Palette for vul-fixing labels, lightest first.
            Defaults to ``VUL_GRADIENT_COLORS``.
        error_color: Single colour used for every ``error`` label.

    Returns:
        dict mapping each of the 303 labels (101 confidences x 3 kinds) to a
        colour string.

    Raises:
        ValueError: if either palette is empty.
    """
    if nonvul_colors is None:
        nonvul_colors = NONVUL_GRADIENT_COLORS
    if vul_colors is None:
        vul_colors = VUL_GRADIENT_COLORS
    if not nonvul_colors or not vul_colors:
        raise ValueError("colour palettes must not be empty")

    def pick(palette, percent):
        # Spread 0..100 evenly over the palette; for a 41-colour palette this
        # is percent * 40 // 100, i.e. the same index as int(percent * 0.4).
        return palette[percent * (len(palette) - 1) // 100]

    color_map = {}
    for percent in range(101):
        conf = f"{percent / 100:0.2f}"
        color_map[f"non-vul-fixing: {conf}"] = pick(nonvul_colors, percent)
        color_map[f"vul-fixing: {conf}"] = pick(vul_colors, percent)
        # display_result labels predict == -1 results as "error: <conf>";
        # give those a neutral colour so every label it can emit is present.
        color_map[f"error: {conf}"] = error_color
    return color_map
def on_submit(diff_code, patch_message, cwe_model, section_model_type, progress = gr.Progress(track_tqdm=True), *model_config):
    """Classify a patch and, if any section is vul-fixing, guess its CWE.

    Wired to both "Run" buttons. ``progress`` is a Gradio progress tracker
    (identified by its ``gr.Progress`` default), and the clicked tab's
    remaining inputs arrive in ``model_config`` and are forwarded unchanged
    to ``section_infer``: (model, peft) for the local tab, (model, api_url,
    api_key) for the online tab.

    Returns:
        A 3-tuple (patch-level ``gr.Label``, per-file section results,
        CWE classification result), or three ``gr.skip()`` values when no
        diff was entered.

    Raises:
        gr.Error: if section or CWE inference raises.
    """
    # Nothing to analyse: leave every output untouched. A missing (None) or
    # whitespace-only diff is treated like an empty one.
    if not diff_code or not diff_code.strip():
        return gr.skip(), gr.skip(), gr.skip()

    logger = logging.getLogger(__name__)
    try:
        section_results = section_infer(diff_code, patch_message, section_model_type, *model_config)
    except Exception as e:
        # Show the message in the UI but keep the traceback in the server log.
        logger.exception("Section inference failed")
        raise gr.Error(f"Error: {e}") from e

    # section_results maps file name -> list of section dicts (as consumed by
    # display_result); the whole patch is vul-fixing if any section is.
    is_vul_fix = any(
        item["predict"] == 1
        for file_results in section_results.values()
        for item in file_results
    )
    patch_category_label = gr.Label(
        value="Vul-fixing patch" if is_vul_fix else "Non-vul-fixing patch",
        # Strongest shades of VUL_/NONVUL_GRADIENT_COLORS respectively.
        color="#de4a4a" if is_vul_fix else "#14e166",
    )

    if not cwe_model:
        cwe_cls_result = "No model selected"
    elif not is_vul_fix:
        cwe_cls_result = "No vulnerability found"
    else:
        try:
            cwe_cls_result = cwe_infer(diff_code, patch_message, cwe_model)
        except Exception as e:
            logger.exception("CWE inference failed")
            raise gr.Error(f"Error: {e}") from e
    return patch_category_label, section_results, cwe_cls_result
# Display names for the "predict" field of a section result.
PREDICT_LABELS = {-1: 'error', 0: 'non-vul-fixing', 1: 'vul-fixing'}
# Example inputs for the "Load examples" accordion, resolved next to this file
# so loading them does not depend on the current working directory.
EXAMPLES_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "backend", "examples.json")


def _load_examples(path=EXAMPLES_PATH):
    """Return the list of examples parsed from the JSON file at *path*.

    Each entry is sent to the [diff, message] outputs of a "Load example"
    button, so it is presumably a 2-item list -- TODO confirm against
    backend/examples.json.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _example_loader(example):
    """Return a no-argument callback yielding *example*.

    A factory binds *example* now, avoiding the late-binding pitfall of
    lambdas created in a loop.
    """
    return lambda: example


def _peft_choices(base_model):
    """Return the PEFT adapters listed for *base_model*.

    The model dropdown accepts custom values, so a model missing from
    LOCAL_MODEL_PEFT_MAP yields an empty list instead of a KeyError.
    """
    return LOCAL_MODEL_PEFT_MAP.get(base_model, [])


with gr.Blocks(title = APP_TITLE, fill_width=True) as demo:
    # Per-file section predictions written by on_submit; drives display_result.
    section_results_state = gr.State({})
    # CWE classification result written by on_submit; drives vul_type_label.
    cls_results_state = gr.State({})
    title = gr.HTML(STYLE_APP_TITLE)
    with gr.Row() as main_block:
        # Left column: patch input and model selection.
        with gr.Column(scale=1) as input_block:
            diff_codebox = gr.Code(label="Input git diff here", max_lines=10)
            with gr.Accordion("Patch message (optional)", open=False):
                message_textbox = gr.Textbox(label="Patch message", placeholder="Enter patch message here", container=False, lines=2, max_lines=5)
            cwe_model_selector = gr.Dropdown(PREDEF_CWE_MODEL, label="Select vulnerability type classifier", allow_custom_value=True)
            with gr.Tabs(selected=0) as model_type_tabs:
                # One tab per model type; the selected tab's label is what
                # on_submit passes to section_infer as section_model_type.
                MODEL_TYPE_NAMES = list(PREDEF_MODEL_MAP.keys())
                with gr.Tab(MODEL_TYPE_NAMES[0]) as local_llm_tab:
                    local_model_selector = gr.Dropdown(PREDEF_MODEL_MAP[MODEL_TYPE_NAMES[0]], label="Select model", allow_custom_value=True)
                    local_peft_selector = gr.Dropdown(_peft_choices(local_model_selector.value), label="Select PEFT model (optional)", allow_custom_value=True)
                    local_submit_btn = gr.Button("Run", variant="primary")
                with gr.Tab(MODEL_TYPE_NAMES[1]) as online_llm_tab:
                    online_model_selector = gr.Dropdown(PREDEF_MODEL_MAP[MODEL_TYPE_NAMES[1]], label="Select model", allow_custom_value=True)
                    online_api_url_textbox = gr.Textbox(label="API URL")
                    # NOTE(review): pre-filling from ONLINE_API_KEY puts the
                    # server's key into the component's initial value, which
                    # Gradio sends to every browser loading the page
                    # (type="password" only masks it on screen). Confirm this
                    # is intended; otherwise leave it empty and fall back to
                    # the env var server-side.
                    online_api_key_textbox = gr.Textbox(label="API Key", placeholder="We won't store your API key", value=os.getenv("ONLINE_API_KEY"), type="password")
                    online_submit_btn = gr.Button("Run", variant="primary")
            # Starts as the first (initially selected) tab's label; kept in
            # sync by update_model_type_state.
            section_model_type = gr.State(MODEL_TYPE_NAMES[0])
            with gr.Accordion("Load examples", open=False):
                examples = _load_examples()
                # One button per example in the JSON file.
                for example_no, example in enumerate(examples, start=1):
                    gr.Button(f"Load example {example_no}", size='sm').click(
                        _example_loader(example),
                        outputs=[diff_codebox, message_textbox],
                    )
        # Middle column: per-file, per-section highlighted predictions.
        with gr.Column(scale=2) as section_result_block:
            @gr.render(inputs=section_results_state, triggers=[section_results_state.change, demo.load])
            def display_result(section_results):
                """Render one HighlightedText tab per file in *section_results*."""
                if not section_results:
                    with gr.Tab("File tabs"):
                        gr.Markdown("No results")
                    return
                # Built once per render rather than once per file.
                full_color_map = generate_color_map()
                for file_name, file_results in section_results.items():
                    with gr.Tab(file_name):
                        highlighted_results = []
                        this_color_map = {}
                        for item in file_results:
                            # Unknown predict values are shown as errors
                            # instead of aborting the whole render.
                            kind = PREDICT_LABELS.get(item['predict'], 'error')
                            text_label = f"{kind}: {item['conf']:0.2f}"
                            # Labels missing from the palette (e.g. "error: ..."
                            # or a confidence outside 0..1) are left to
                            # Gradio's default colouring instead of raising
                            # KeyError.
                            if text_label in full_color_map:
                                this_color_map[text_label] = full_color_map[text_label]
                            highlighted_results.append((item["section"], text_label))
                        gr.HighlightedText(
                            highlighted_results,
                            label="Results",
                            color_map=this_color_map,
                        )
        # Right column: patch-level verdict and vulnerability type.
        with gr.Column(scale=1) as result_block:
            patch_category_label = gr.Label(value = "No results", label = "Result of the whole patch")
            def update_vul_type_label(cls_results):
                """Wrap the CWE classification result in a Label update."""
                return gr.Label(cls_results)
            # Callable value: recomputed from cls_results_state when it changes.
            vul_type_label = gr.Label(update_vul_type_label, label = "Possible fixed vulnerability type", inputs = [cls_results_state])

    def update_model_type_state(evt: gr.SelectData):
        """Store the label of the newly selected model-type tab."""
        return evt.value
    model_type_tabs.select(update_model_type_state, outputs = [section_model_type])

    def update_support_peft(base_model):
        """Refresh the PEFT dropdown for *base_model*, selecting its first adapter if any."""
        choices = _peft_choices(base_model)
        return gr.Dropdown(choices, value = choices[0] if choices else None)
    local_model_selector.change(update_support_peft, inputs=[local_model_selector], outputs = [local_peft_selector])

    # Both tabs share on_submit; the trailing inputs become its *model_config.
    local_submit_btn.click(fn = on_submit,
                           inputs = [diff_codebox, message_textbox, cwe_model_selector, section_model_type, local_model_selector, local_peft_selector],
                           outputs = [patch_category_label, section_results_state, cls_results_state])
    online_submit_btn.click(fn = on_submit,
                            inputs = [diff_codebox, message_textbox, cwe_model_selector, section_model_type, online_model_selector, online_api_url_textbox, online_api_key_textbox],
                            outputs = [patch_category_label, section_results_state, cls_results_state])
demo.launch()