Coyoteranger's picture
Update app.py
8436a41 verified
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load the pre-trained model and tokenizer from the fine-tuned checkpoint.
model_name = "flutter-code-generator/flutter_codegen_model/checkpoint-1500"


@st.cache_resource
def _load_model_and_tokenizer(name):
    """Load the tokenizer and causal-LM model stored at checkpoint ``name``.

    Streamlit re-executes this whole script on every widget interaction, so
    the result is cached with ``st.cache_resource``: the checkpoint is read
    once per server process instead of on every rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(name)
    # Tokenizing with padding enabled raises when no pad token is defined
    # (as with GPT-2-style tokenizers); reuse the EOS token for padding then.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    lm = AutoModelForCausalLM.from_pretrained(name)
    return tok, lm


tokenizer, model = _load_model_and_tokenizer(model_name)
# Function to clean up repetitive lines in generated code
def clean_code_response(response):
    """Collapse runs of identical consecutive lines in ``response``.

    Language models sometimes degenerate into emitting the same line over
    and over; only such back-to-back repeats are removed. Non-adjacent
    repeats are kept, because code legitimately repeats lines such as
    closing braces. Leading indentation is preserved, so the Dart output
    stays readable; only trailing whitespace is trimmed.

    Args:
        response: Raw decoded model output.

    Returns:
        The cleaned text, with lines joined by ``"\\n"`` (no trailing newline).
    """
    cleaned = []
    for raw_line in response.splitlines():
        line = raw_line.rstrip()
        # Skip a line only if it exactly repeats the line just kept.
        if cleaned and line == cleaned[-1]:
            continue
        cleaned.append(line)
    return "\n".join(cleaned)
# Function to generate Flutter code
def generate_flutter_code(
    prompt,
    temperature,
    top_p,
    max_length,
    num_return_sequences,
    repetition_penalty,
    top_k,
):
    """Sample one or more Flutter/Dart code completions for ``prompt``.

    Args:
        prompt: Natural-language request fed to the model.
        temperature: Sampling temperature (higher = more random).
        top_p: Nucleus-sampling cumulative-probability cutoff.
        max_length: Maximum total sequence length in tokens, as interpreted
            by ``model.generate`` (the prompt tokens count toward it).
        num_return_sequences: Number of independent samples to return.
        repetition_penalty: Penalty for re-using already generated tokens.
        top_k: Size of the top-k sampling pool (in transformers, 0 disables
            top-k filtering).

    Returns:
        A list of decoded strings, one per returned sequence, each passed
        through ``clean_code_response``.
    """
    # A single prompt never needs padding; requesting it would raise if the
    # tokenizer has no pad token defined.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
    ).to(model.device)
    # generate() needs a concrete pad id; fall back to EOS when none is set.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    outputs = model.generate(
        inputs["input_ids"],
        # Pass the mask explicitly so generate() does not have to guess which
        # positions are real tokens (ambiguous when pad == eos).
        attention_mask=inputs.get("attention_mask"),
        max_length=max_length,
        num_return_sequences=num_return_sequences,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        pad_token_id=pad_token_id,
    )
    return [
        clean_code_response(tokenizer.decode(sequence, skip_special_tokens=True))
        for sequence in outputs
    ]
# App title: the first element rendered in the main page area.
st.title("Flutter Code Generator")
# Default generation parameters, shared by the sidebar sliders and the
# "Reset to Defaults" handler.

# Sampling controls.
DEFAULT_TEMPERATURE: float = 0.7
DEFAULT_TOP_P: float = 0.9
DEFAULT_TOP_K: int = 50

# Output size and repetition controls.
DEFAULT_MAX_LENGTH: int = 512
DEFAULT_NUM_RETURN_SEQUENCES: int = 1
DEFAULT_REPETITION_PENALTY: float = 1.2
# Sidebar for generation settings.
#
# Every slider is bound to a session-state key so "Reset to Defaults" can
# actually change its value. Defaults are seeded into session state once,
# instead of being passed as value=..., because Streamlit warns when a keyed
# widget receives both a value argument and a Session State value.
_GENERATION_DEFAULTS = {
    "temperature": DEFAULT_TEMPERATURE,
    "top_p": DEFAULT_TOP_P,
    "max_length": DEFAULT_MAX_LENGTH,
    "num_return_sequences": DEFAULT_NUM_RETURN_SEQUENCES,
    "repetition_penalty": DEFAULT_REPETITION_PENALTY,
    "top_k": DEFAULT_TOP_K,
}
for _setting_key, _setting_default in _GENERATION_DEFAULTS.items():
    if _setting_key not in st.session_state:
        st.session_state[_setting_key] = _setting_default


def _reset_generation_settings():
    """Restore every generation slider to its default value.

    Registered as the reset button's ``on_click`` callback, so it runs at the
    start of the next rerun, before the sliders are created. Streamlit does
    not allow modifying a widget's key after that widget has been
    instantiated in the current run.
    """
    st.session_state.update(_GENERATION_DEFAULTS)


st.sidebar.title("Generation Settings")
temperature = st.sidebar.slider(
    "Temperature (randomness)",
    0.1, 1.0, step=0.1, key="temperature",
)
top_p = st.sidebar.slider(
    "Top-p (cumulative probability)",
    0.1, 1.0, step=0.1, key="top_p",
)
max_length = st.sidebar.slider(
    "Max Output Length (tokens)",
    128, 1024, step=64, key="max_length",
)
num_return_sequences = st.sidebar.slider(
    "Number of Outputs",
    1, 5, key="num_return_sequences",
)
repetition_penalty = st.sidebar.slider(
    "Repetition Penalty",
    1.0, 2.0, step=0.1, key="repetition_penalty",
)
top_k = st.sidebar.slider(
    "Top-k (limit sampling pool)",
    0, 100, key="top_k",
)
# Reset to defaults button
st.sidebar.button("Reset to Defaults", on_click=_reset_generation_settings)
# Input section: free-form prompt (capped at 200 characters) describing
# the Flutter UI to generate.
user_input = st.text_area(
    label="Enter your prompt (e.g., 'Create a responsive login screen'):",
    max_chars=200,
)
# Output section: when the button is pressed, generate code for the prompt
# and render each returned sequence as a Dart code block.
if st.button("Generate Code"):
    prompt = user_input.strip()
    if prompt:
        # Generation is a blocking call that can take a while; show progress
        # instead of an apparently frozen page.
        with st.spinner("Generating code..."):
            generated_code = generate_flutter_code(
                prompt,
                temperature,
                top_p,
                max_length,
                num_return_sequences,
                repetition_penalty,
                top_k,
            )
        for i, code in enumerate(generated_code, start=1):
            st.subheader(f"Output {i}")
            st.code(code, language="dart")
    else:
        st.error("Please enter a prompt before clicking 'Generate Code'.")