Coyoteranger commited on
Commit
8436a41
·
verified ·
1 Parent(s): 2c85281

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -115
app.py CHANGED
@@ -1,115 +1,115 @@
1
- import streamlit as st
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
-
4
- # Load pre-trained model and tokenizer from the checkpoint
5
- model_name = "./flutter_codegen_model/checkpoint-1500"
6
- tokenizer = AutoTokenizer.from_pretrained(model_name)
7
- tokenizer.pad_token = tokenizer.eos_token
8
- model = AutoModelForCausalLM.from_pretrained(model_name)
9
-
10
- # Function to clean up repetitive lines in code
11
- def clean_code_response(response):
12
- lines = response.splitlines()
13
- unique_lines = []
14
- for line in lines:
15
- if line.strip() not in unique_lines: # Avoid duplicates
16
- unique_lines.append(line.strip())
17
- return "\n".join(unique_lines)
18
-
19
- # Function to generate Flutter code
20
- def generate_flutter_code(prompt, temperature, top_p, max_length, num_return_sequences, repetition_penalty, top_k):
21
- inputs = tokenizer(
22
- prompt,
23
- return_tensors="pt",
24
- padding=True,
25
- truncation=True,
26
- )
27
- outputs = model.generate(
28
- inputs["input_ids"],
29
- max_length=max_length,
30
- num_return_sequences=num_return_sequences,
31
- temperature=temperature,
32
- top_p=top_p,
33
- top_k=top_k,
34
- repetition_penalty=repetition_penalty,
35
- do_sample=True,
36
- pad_token_id=tokenizer.pad_token_id,
37
- )
38
- code = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
39
- return [clean_code_response(c) for c in code]
40
-
41
- # App Title
42
- st.title("Flutter Code Generator")
43
-
44
- # Default parameter values
45
- DEFAULT_TEMPERATURE = 0.7
46
- DEFAULT_TOP_P = 0.9
47
- DEFAULT_MAX_LENGTH = 512
48
- DEFAULT_NUM_RETURN_SEQUENCES = 1
49
- DEFAULT_REPETITION_PENALTY = 1.2
50
- DEFAULT_TOP_K = 50
51
-
52
- # Sidebar for settings
53
- st.sidebar.title("Generation Settings")
54
-
55
- temperature = st.sidebar.slider(
56
- "Temperature (randomness)",
57
- 0.1, 1.0, DEFAULT_TEMPERATURE, step=0.1,
58
- )
59
-
60
- top_p = st.sidebar.slider(
61
- "Top-p (cumulative probability)",
62
- 0.1, 1.0, DEFAULT_TOP_P, step=0.1,
63
- )
64
-
65
- max_length = st.sidebar.slider(
66
- "Max Output Length (tokens)",
67
- 128, 1024, DEFAULT_MAX_LENGTH, step=64,
68
- )
69
-
70
- num_return_sequences = st.sidebar.slider(
71
- "Number of Outputs",
72
- 1, 5, DEFAULT_NUM_RETURN_SEQUENCES,
73
- )
74
-
75
- repetition_penalty = st.sidebar.slider(
76
- "Repetition Penalty",
77
- 1.0, 2.0, DEFAULT_REPETITION_PENALTY, step=0.1,
78
- )
79
-
80
- top_k = st.sidebar.slider(
81
- "Top-k (limit sampling pool)",
82
- 0, 100, DEFAULT_TOP_K,
83
- )
84
-
85
- # Reset to defaults button
86
- if st.sidebar.button("Reset to Defaults"):
87
- st.session_state.update(
88
- {
89
- "temperature": DEFAULT_TEMPERATURE,
90
- "top_p": DEFAULT_TOP_P,
91
- "max_length": DEFAULT_MAX_LENGTH,
92
- "num_return_sequences": DEFAULT_NUM_RETURN_SEQUENCES,
93
- "repetition_penalty": DEFAULT_REPETITION_PENALTY,
94
- "top_k": DEFAULT_TOP_K,
95
- }
96
- )
97
-
98
- # Input Section
99
- user_input = st.text_area(
100
- "Enter your prompt (e.g., 'Create a responsive login screen'):",
101
- max_chars=200,
102
- )
103
-
104
- # Output Section
105
- if st.button("Generate Code"):
106
- if user_input.strip():
107
- prompt = f"{user_input.strip()}"
108
- generated_code = generate_flutter_code(
109
- prompt, temperature, top_p, max_length, num_return_sequences, repetition_penalty, top_k
110
- )
111
- for i, code in enumerate(generated_code, start=1):
112
- st.subheader(f"Output {i}")
113
- st.code(code, language="dart")
114
- else:
115
- st.error("Please enter a prompt before clicking 'Generate Code'.")
 
1
+ import streamlit as st
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+
4
+ # Load pre-trained model and tokenizer from the checkpoint
5
+ model_name = "flutter-code-generator/flutter_codegen_model/checkpoint-1500"
6
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
7
+ model = AutoModelForCausalLM.from_pretrained(model_name)
8
+ model = AutoModelForCausalLM.from_pretrained(model_name)
9
+
10
+ # Function to clean up repetitive lines in code
11
+ def clean_code_response(response):
12
+ lines = response.splitlines()
13
+ unique_lines = []
14
+ for line in lines:
15
+ if line.strip() not in unique_lines: # Avoid duplicates
16
+ unique_lines.append(line.strip())
17
+ return "\n".join(unique_lines)
18
+
19
+ # Function to generate Flutter code
20
+ def generate_flutter_code(prompt, temperature, top_p, max_length, num_return_sequences, repetition_penalty, top_k):
21
+ inputs = tokenizer(
22
+ prompt,
23
+ return_tensors="pt",
24
+ padding=True,
25
+ truncation=True,
26
+ )
27
+ outputs = model.generate(
28
+ inputs["input_ids"],
29
+ max_length=max_length,
30
+ num_return_sequences=num_return_sequences,
31
+ temperature=temperature,
32
+ top_p=top_p,
33
+ top_k=top_k,
34
+ repetition_penalty=repetition_penalty,
35
+ do_sample=True,
36
+ pad_token_id=tokenizer.pad_token_id,
37
+ )
38
+ code = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
39
+ return [clean_code_response(c) for c in code]
40
+
41
+ # App Title
42
+ st.title("Flutter Code Generator")
43
+
44
+ # Default parameter values
45
+ DEFAULT_TEMPERATURE = 0.7
46
+ DEFAULT_TOP_P = 0.9
47
+ DEFAULT_MAX_LENGTH = 512
48
+ DEFAULT_NUM_RETURN_SEQUENCES = 1
49
+ DEFAULT_REPETITION_PENALTY = 1.2
50
+ DEFAULT_TOP_K = 50
51
+
52
+ # Sidebar for settings
53
+ st.sidebar.title("Generation Settings")
54
+
55
+ temperature = st.sidebar.slider(
56
+ "Temperature (randomness)",
57
+ 0.1, 1.0, DEFAULT_TEMPERATURE, step=0.1,
58
+ )
59
+
60
+ top_p = st.sidebar.slider(
61
+ "Top-p (cumulative probability)",
62
+ 0.1, 1.0, DEFAULT_TOP_P, step=0.1,
63
+ )
64
+
65
+ max_length = st.sidebar.slider(
66
+ "Max Output Length (tokens)",
67
+ 128, 1024, DEFAULT_MAX_LENGTH, step=64,
68
+ )
69
+
70
+ num_return_sequences = st.sidebar.slider(
71
+ "Number of Outputs",
72
+ 1, 5, DEFAULT_NUM_RETURN_SEQUENCES,
73
+ )
74
+
75
+ repetition_penalty = st.sidebar.slider(
76
+ "Repetition Penalty",
77
+ 1.0, 2.0, DEFAULT_REPETITION_PENALTY, step=0.1,
78
+ )
79
+
80
+ top_k = st.sidebar.slider(
81
+ "Top-k (limit sampling pool)",
82
+ 0, 100, DEFAULT_TOP_K,
83
+ )
84
+
85
+ # Reset to defaults button
86
+ if st.sidebar.button("Reset to Defaults"):
87
+ st.session_state.update(
88
+ {
89
+ "temperature": DEFAULT_TEMPERATURE,
90
+ "top_p": DEFAULT_TOP_P,
91
+ "max_length": DEFAULT_MAX_LENGTH,
92
+ "num_return_sequences": DEFAULT_NUM_RETURN_SEQUENCES,
93
+ "repetition_penalty": DEFAULT_REPETITION_PENALTY,
94
+ "top_k": DEFAULT_TOP_K,
95
+ }
96
+ )
97
+
98
+ # Input Section
99
+ user_input = st.text_area(
100
+ "Enter your prompt (e.g., 'Create a responsive login screen'):",
101
+ max_chars=200,
102
+ )
103
+
104
+ # Output Section
105
+ if st.button("Generate Code"):
106
+ if user_input.strip():
107
+ prompt = f"{user_input.strip()}"
108
+ generated_code = generate_flutter_code(
109
+ prompt, temperature, top_p, max_length, num_return_sequences, repetition_penalty, top_k
110
+ )
111
+ for i, code in enumerate(generated_code, start=1):
112
+ st.subheader(f"Output {i}")
113
+ st.code(code, language="dart")
114
+ else:
115
+ st.error("Please enter a prompt before clicking 'Generate Code'.")