File size: 3,454 Bytes
8436a41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load pre-trained model and tokenizer from the checkpoint
model_name = "flutter-code-generator/flutter_codegen_model/checkpoint-1500"


@st.cache_resource
def _load_model_and_tokenizer(name):
    """Load the tokenizer and causal-LM model for checkpoint *name* once.

    Streamlit re-executes this script top to bottom on every widget
    interaction. ``st.cache_resource`` keeps a single loaded copy across
    reruns instead of reading the weights from disk each time (the original
    code also loaded the model twice per run).

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(name)
    mdl = AutoModelForCausalLM.from_pretrained(name)
    return tok, mdl


tokenizer, model = _load_model_and_tokenizer(model_name)

# Function to clean up repetitive lines in code
def clean_code_response(response):
    """Collapse runs of consecutive identical lines in generated code.

    The goal is to strip degenerate output where the model emits the same
    line over and over. Only *adjacent* repeats are collapsed. Lines that
    legitimately recur elsewhere (closing braces, blank separators, ``),`` in
    widget trees) are kept, so the code stays balanced. Leading indentation
    is preserved; trailing whitespace is removed from every line.

    Args:
        response: Decoded model output.

    Returns:
        The cleaned text with lines joined by newlines (no trailing newline).
    """
    cleaned = []
    for raw_line in response.splitlines():
        line = raw_line.rstrip()
        # Compare including indentation: "  }" followed by "}" closes two
        # different scopes and must not be merged.
        if cleaned and cleaned[-1] == line:
            continue
        cleaned.append(line)
    return "\n".join(cleaned)

# Function to generate Flutter code
def generate_flutter_code(prompt, temperature, top_p, max_length, num_return_sequences, repetition_penalty, top_k):
    """Sample one or more code completions for *prompt* from the loaded model.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        prompt: Prompt text to condition generation on.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling cumulative-probability cutoff.
        max_length: Maximum number of tokens to generate *after* the prompt.
            It is forwarded as ``max_new_tokens`` so a long prompt does not
            eat into the output budget shown in the UI.
        num_return_sequences: Number of independent samples to return.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.
        top_k: Top-k sampling pool size forwarded to ``model.generate``.

    Returns:
        A list of decoded strings, each passed through ``clean_code_response``.
    """
    # A single prompt never needs padding. Requesting it raises when the
    # tokenizer defines no pad token, which is common for causal-LM tokenizers.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
    )
    # Fall back to EOS when no pad token is defined, rather than passing None.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=max_length,
        num_return_sequences=num_return_sequences,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        pad_token_id=pad_token_id,
    )
    # NOTE(review): for a causal LM each output row starts with the prompt
    # tokens, so the decoded text includes the prompt; slice
    # outputs[:, inputs["input_ids"].shape[-1]:] if only the completion
    # should be shown.
    code = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
    return [clean_code_response(c) for c in code]

# App Title
st.title("Flutter Code Generator")

# Default parameter values
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_P = 0.9
DEFAULT_MAX_LENGTH = 512
DEFAULT_NUM_RETURN_SEQUENCES = 1
DEFAULT_REPETITION_PENALTY = 1.2
DEFAULT_TOP_K = 50

# Session-state key -> default value for each keyed slider below.
_SETTING_DEFAULTS = {
    "temperature": DEFAULT_TEMPERATURE,
    "top_p": DEFAULT_TOP_P,
    "max_length": DEFAULT_MAX_LENGTH,
    "num_return_sequences": DEFAULT_NUM_RETURN_SEQUENCES,
    "repetition_penalty": DEFAULT_REPETITION_PENALTY,
    "top_k": DEFAULT_TOP_K,
}

# Seed session state on the first run only. Each slider reads its initial
# value from st.session_state via its key; also passing value= would conflict
# with values written through the Session State API.
for _setting_key, _setting_default in _SETTING_DEFAULTS.items():
    if _setting_key not in st.session_state:
        st.session_state[_setting_key] = _setting_default


def _reset_generation_settings():
    """Restore every generation-setting slider to its default value.

    Runs as the reset button's ``on_click`` callback, which executes before
    the next script run. Streamlit only allows widget-backed session state to
    be overwritten at that point, not after the widgets are created.
    """
    st.session_state.update(_SETTING_DEFAULTS)


# Sidebar for settings
st.sidebar.title("Generation Settings")

temperature = st.sidebar.slider(
    "Temperature (randomness)",
    0.1, 1.0, step=0.1, key="temperature",
)

top_p = st.sidebar.slider(
    "Top-p (cumulative probability)",
    0.1, 1.0, step=0.1, key="top_p",
)

max_length = st.sidebar.slider(
    "Max Output Length (tokens)",
    128, 1024, step=64, key="max_length",
)

num_return_sequences = st.sidebar.slider(
    "Number of Outputs",
    1, 5, key="num_return_sequences",
)

repetition_penalty = st.sidebar.slider(
    "Repetition Penalty",
    1.0, 2.0, step=0.1, key="repetition_penalty",
)

top_k = st.sidebar.slider(
    "Top-k (limit sampling pool)",
    0, 100, key="top_k",
)

# Reset to defaults button
st.sidebar.button("Reset to Defaults", on_click=_reset_generation_settings)

# Input Section
user_input = st.text_area(
    "Enter your prompt (e.g., 'Create a responsive login screen'):",
    max_chars=200,
)

# Output Section
if st.button("Generate Code"):
    prompt = user_input.strip()
    if prompt:
        # Model generation can take a while; show progress instead of a
        # frozen page.
        with st.spinner("Generating code..."):
            generated_code = generate_flutter_code(
                prompt, temperature, top_p, max_length, num_return_sequences, repetition_penalty, top_k
            )
        for i, code in enumerate(generated_code, start=1):
            st.subheader(f"Output {i}")
            st.code(code, language="dart")
    else:
        st.error("Please enter a prompt before clicking 'Generate Code'.")