Spaces:
Build error
Build error
import streamlit as st | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Hugging Face checkpoint (model + tokenizer) served by this app.
model_name = "flutter-code-generator/flutter_codegen_model/checkpoint-1500"


@st.cache_resource
def _load_model_and_tokenizer(name):
    """Load the tokenizer and causal-LM weights for checkpoint *name* once.

    Streamlit re-executes the whole script on every widget interaction, so
    ``st.cache_resource`` keeps one loaded copy per process instead of
    re-reading the weights on each rerun. (The original code also loaded the
    model twice in a row.)

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(name)
    # The app tokenizes with ``padding=True``, which raises when no pad token
    # is defined (typical for GPT-style tokenizers); reuse EOS in that case.
    if loaded_tokenizer.pad_token is None:
        loaded_tokenizer.pad_token = loaded_tokenizer.eos_token
    loaded_model = AutoModelForCausalLM.from_pretrained(name)
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_model_and_tokenizer(model_name)
def clean_code_response(response):
    """Collapse runs of identical consecutive lines in generated code.

    Sampled generation can get stuck emitting the same line repeatedly; this
    removes such runs while leaving the rest of the code intact. Lines are
    compared after stripping *trailing* whitespace only, so leading
    indentation is preserved. Lines that legitimately recur at different
    places in Dart code (closing ``}``/``),`` lines, blank separators) are kept
    as long as they are not adjacent duplicates.

    Args:
        response: Raw decoded model output.

    Returns:
        The cleaned text, with lines joined by newlines.
    """
    cleaned_lines = []
    previous_line = None
    for raw_line in response.splitlines():
        line = raw_line.rstrip()
        if line == previous_line:
            # Same as the line just kept: part of a degenerate repetition run.
            continue
        cleaned_lines.append(line)
        previous_line = line
    return "\n".join(cleaned_lines)
def generate_flutter_code(
    prompt, temperature, top_p, max_length, num_return_sequences, repetition_penalty, top_k
):
    """Sample Flutter/Dart code for *prompt* with the module-level model.

    Args:
        prompt: Natural-language description of the code to generate.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling cumulative-probability cutoff.
        max_length: Passed as ``generate(max_length=...)``, i.e. a cap on the
            total sequence length in tokens. NOTE(review): this includes the
            prompt tokens even though the sidebar labels it "Max Output
            Length"; switching to ``max_new_tokens`` would need a guard
            against exceeding the model's context window.
        num_return_sequences: How many samples to draw.
        repetition_penalty: Repetition penalty passed to ``model.generate``.
        top_k: Top-k sampling pool size passed to ``model.generate``.

    Returns:
        One cleaned string (via ``clean_code_response``) per generated
        sequence. NOTE(review): the full output sequence is decoded, so the
        text presumably still begins with the prompt; confirm.
    """
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(model.device)  # keep inputs on the same device as the weights

    # If the tokenizer has no pad token, fall back to EOS explicitly rather
    # than passing None into generate().
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    outputs = model.generate(
        input_ids=inputs["input_ids"],
        # Passing the mask explicitly avoids generate() having to infer it.
        attention_mask=inputs.get("attention_mask"),
        max_length=max_length,
        num_return_sequences=num_return_sequences,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        pad_token_id=pad_token_id,
    )
    decoded = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
    return [clean_code_response(text) for text in decoded]
# Page heading rendered at the top of the app.
st.title(body="Flutter Code Generator")
# Default generation settings: the sidebar sliders and the
# "Reset to Defaults" button both refer to these values.
DEFAULT_TEMPERATURE: float = 0.7
DEFAULT_TOP_P: float = 0.9
DEFAULT_MAX_LENGTH: int = 512  # tokens
DEFAULT_NUM_RETURN_SEQUENCES: int = 1
DEFAULT_REPETITION_PENALTY: float = 1.2
DEFAULT_TOP_K: int = 50
# Sidebar for settings
st.sidebar.title("Generation Settings")

# Session-state key -> default value for each generation-setting slider.
# Every slider below is bound to the key of the same name, so writing these
# keys into session state is what actually moves the sliders.
_SETTING_DEFAULTS = {
    "temperature": DEFAULT_TEMPERATURE,
    "top_p": DEFAULT_TOP_P,
    "max_length": DEFAULT_MAX_LENGTH,
    "num_return_sequences": DEFAULT_NUM_RETURN_SEQUENCES,
    "repetition_penalty": DEFAULT_REPETITION_PENALTY,
    "top_k": DEFAULT_TOP_K,
}

# Seed session state on the first run only. The sliders take their initial
# value from session state (no ``value=`` argument), which avoids Streamlit's
# warning about a widget having both a default and a Session State value.
for _setting_key, _setting_default in _SETTING_DEFAULTS.items():
    if _setting_key not in st.session_state:
        st.session_state[_setting_key] = _setting_default


def _reset_settings_to_defaults():
    """Restore every generation-setting slider to its default value.

    Registered as the reset button's ``on_click`` callback. Streamlit runs
    callbacks before the script reruns, i.e. before the keyed sliders are
    re-created. Overwriting a keyed widget's state after it has been created
    in the same run is not allowed.
    """
    st.session_state.update(_SETTING_DEFAULTS)


temperature = st.sidebar.slider(
    "Temperature (randomness)",
    0.1, 1.0, step=0.1, key="temperature",
)
top_p = st.sidebar.slider(
    "Top-p (cumulative probability)",
    0.1, 1.0, step=0.1, key="top_p",
)
max_length = st.sidebar.slider(
    "Max Output Length (tokens)",
    128, 1024, step=64, key="max_length",
)
num_return_sequences = st.sidebar.slider(
    "Number of Outputs",
    1, 5, key="num_return_sequences",
)
repetition_penalty = st.sidebar.slider(
    "Repetition Penalty",
    1.0, 2.0, step=0.1, key="repetition_penalty",
)
top_k = st.sidebar.slider(
    "Top-k (limit sampling pool)",
    0, 100, key="top_k",
)

# Reset to defaults button (the work happens in the on_click callback).
st.sidebar.button("Reset to Defaults", on_click=_reset_settings_to_defaults)
# Free-text prompt describing the code to generate (limited to 200 chars).
user_input = st.text_area(
    label="Enter your prompt (e.g., 'Create a responsive login screen'):",
    max_chars=200,
)
# Output Section: on click, generate and render one code block per sample.
if st.button("Generate Code"):
    prompt = user_input.strip()
    if not prompt:
        st.error("Please enter a prompt before clicking 'Generate Code'.")
    else:
        # generate_flutter_code blocks until sampling finishes; show a spinner
        # so the page does not look frozen in the meantime.
        with st.spinner("Generating code..."):
            generated_code = generate_flutter_code(
                prompt, temperature, top_p, max_length, num_return_sequences, repetition_penalty, top_k
            )
        for i, code in enumerate(generated_code, start=1):
            st.subheader(f"Output {i}")
            st.code(code, language="dart")