File size: 2,007 Bytes
5ef742a
 
 
 
efc64b8
5ef742a
 
 
 
 
 
 
 
 
efc64b8
5ef742a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efc64b8
5ef742a
 
 
 
 
 
 
 
 
 
 
 
 
efc64b8
 
 
 
 
 
 
5ef742a
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import difflib
import re
from utils import verify_diff, apply_diff_from_output

# Maps each brush name to the commit message that is placed after the
# <MSG> marker of the model prompt when that brush is selected.
commit_message_per_brush = dict(
    [
        ("Annotate Type", "annotate type to the variables."),
        ("Reformat", "Reformat the code using pep8"),
        ("Add Docstrings", "Add docstrings to all the functions"),
        ("Add Comments", "Add inline comments to all the functions"),
    ]
)


def load_model_and_tokenizer(model_name: str = "CarperAI/diff-codegen-350M-v2"):
    """Load the tokenizer and the causal language model for ``model_name``.

    Args:
        model_name: Identifier handed to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # Tuple elements are evaluated left to right, so the tokenizer is still
    # loaded before the model.
    return (
        AutoTokenizer.from_pretrained(model_name),
        AutoModelForCausalLM.from_pretrained(model_name),
    )

def make_prompt(code: str, task):
    """Build the model prompt for ``code`` and the selected brush ``task``.

    The prompt has the form ``<NME>main.py<BEF>{code}<MSG>{message}.`` where
    ``message`` is looked up in ``commit_message_per_brush``.

    Args:
        code: Source code to place after the ``<BEF>`` marker.
        task: Brush name; must be a key of ``commit_message_per_brush``.

    Returns:
        The assembled prompt string.

    Raises:
        KeyError: If ``task`` is not a key of ``commit_message_per_brush``.
    """
    # Some configured messages already end with a period; strip it first so
    # the prompt always carries exactly one terminating period instead of "..".
    message = commit_message_per_brush[task].rstrip(".")
    return f"<NME>main.py<BEF>{code}<MSG>{message}."


def generate_diff(
    code: str,
    max_new_tokens: int = 256,
    temperature: float = 0.8,
    top_p: float = 0.85,
):
    """Run the module-level model on a prompt and return the decoded text.

    Args:
        code: Prompt text to encode (callers in this file pass the output of
            ``make_prompt``).
        max_new_tokens: Upper bound on the number of tokens generated in
            addition to the prompt.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.

    Returns:
        The first generated sequence decoded with special tokens skipped.
    """
    input_ids = tokenizer.encode(code, return_tensors="pt")
    outputs = model.generate(
        input_ids,
        # Bound only the newly generated tokens. The previous ``max_length=64``
        # also counted the prompt, so a prompt of realistic size left little
        # or no room for the model to emit a diff.
        max_new_tokens=max_new_tokens,
        # ``temperature`` and ``top_p`` are only honoured when sampling is
        # enabled; without this flag generation falls back to greedy decoding.
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def postprocess_output(generated_output: str):
    """Pass the raw model output through ``verify_diff`` and return the result.

    Args:
        generated_output: Text produced by the model.

    Returns:
        Whatever ``verify_diff`` returns for ``generated_output``.
    """
    verified = verify_diff(generated_output)
    return verified

# --- Streamlit page -------------------------------------------------------
st.title("Code Brush")
st.write("A tool to brush up your code")

tokenizer, model = load_model_and_tokenizer()

with st.form("my_form"):
    text = st.text_area(
        "Enter your code here",
        height=150,
        value="def greet(input_name):\n    return f'Hello, {input_name}'",
    )
    # Offer exactly the brushes that have a commit message configured, so the
    # selection can never miss a key when the prompt is built.
    brush_type = st.selectbox("Brush Type", list(commit_message_per_brush))
    submit_button = st.form_submit_button("Submit")
    if submit_button:
        st.write("## Diff:")
        # Keep the model output under its own name: assigning it to
        # `generate_diff` would rebind the function to a string.
        raw_diff = generate_diff(make_prompt(text, brush_type))
        after_file = apply_diff_from_output(raw_diff)
        generate_diff_processed = postprocess_output(raw_diff)
        st.write(after_file)
        st.write(generate_diff_processed)