acecalisto3 commited on
Commit
6ec6288
·
verified ·
1 Parent(s): 4db57ea

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +271 -194
agent.py CHANGED
@@ -1,207 +1,284 @@
1
- from typing import List, Dict, Optional
2
- from custom_types import (
3
- Code,
4
- Prompt,
5
- AppType,
6
- File,
7
- Space,
8
- Tutorial,
9
- App,
10
- WebApp,
11
- GradioApp,
12
- StreamlitApp,
13
- ReactApp,
14
- Code,
15
- )
16
- from prompts import (
17
- createLlamaPrompt,
18
- createSpace,
19
- isPythonOrGradioAppPrompt,
20
- isReactAppPrompt,
21
- isStreamlitAppPrompt,
22
- getWebApp,
23
- getGradioApp,
24
- getReactApp,
25
- getStreamlitApp,
26
- parseTutorial,
27
- generateFiles,
28
  )
29
- from huggingface_hub import InferenceClient
30
 
31
- class Agent:
32
- def __init__(self, prompts: Dict[str, any]):
33
- self.prompts = prompts
34
- self.client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
35
 
36
- def process(self, user_input: str) -> str:
37
- """ Processes the user's input and generates code. """
38
- # Parse the user's input
39
- app_type, app_name, app_description, app_features, app_dependencies, app_space, app_tutorial = self.parse_input(user_input)
40
 
41
- # Generate a prompt for the Llama model
42
- prompt = self.prompts["createLlamaPrompt"](
43
- app_type, app_name, app_description, app_features, app_dependencies, app_space, app_tutorial
44
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- # Generate code using the Llama model
47
- code = self.generate_code(prompt)
48
 
49
- # Generate files for the application
50
- files = self.prompts["generateFiles"](
51
- app_type, app_name, app_description, app_features, app_dependencies, app_space, app_tutorial
52
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- # Return the generated code and files
55
- return f"Code: {code}\nFiles: {files}"
56
-
57
- def parse_input(self, user_input: str) -> tuple:
58
- """ Parses the user's input and extracts the relevant information. """
59
- # Extract the app type
60
- app_type = self.extract_app_type(user_input)
61
- # Extract the app name
62
- app_name = self.extract_app_name(user_input)
63
- # Extract the app description
64
- app_description = self.extract_app_description(user_input)
65
- # Extract the app features
66
- app_features = self.extract_app_features(user_input)
67
- # Extract the app dependencies
68
- app_dependencies = self.extract_app_dependencies(user_input)
69
- # Extract the app space
70
- app_space = self.extract_app_space(user_input)
71
- # Extract the app tutorial
72
- app_tutorial = self.extract_app_tutorial(user_input)
73
-
74
- return app_type, app_name, app_description, app_features, app_dependencies, app_space, app_tutorial
75
-
76
- def extract_app_type(self, user_input: str) -> AppType:
77
- """ Extracts the app type from the user's input. """
78
- # Check if the user specified a specific app type
79
- if "web app" in user_input:
80
- return AppType.WEB_APP
81
- elif "gradio app" in user_input:
82
- return AppType.GRADIO_APP
83
- elif "streamlit app" in user_input:
84
- return AppType.STREAMLIT_APP
85
- elif "react app" in user_input:
86
- return AppType.REACT_APP
87
- # Otherwise, assume the user wants a web app
88
- return AppType.WEB_APP
89
-
90
- def extract_app_name(self, user_input: str) -> str:
91
- """ Extracts the app name from the user's input. """
92
- # Find the substring "app name is:"
93
- start_index = user_input.find("app name is:") + len("app name is:")
94
- # Find the end of the app name
95
- end_index = user_input.find(".", start_index)
96
- # Extract the app name
97
- app_name = user_input[start_index:end_index].strip()
98
- return app_name
99
-
100
- def extract_app_description(self, user_input: str) -> str:
101
- """ Extracts the app description from the user's input. """
102
- # Find the substring "app description is:"
103
- start_index = user_input.find("app description is:") + len("app description is:")
104
- # Find the end of the app description
105
- end_index = user_input.find(".", start_index)
106
- # Extract the app description
107
- app_description = user_input[start_index:end_index].strip()
108
- return app_description
109
-
110
- def extract_app_features(self, user_input: str) -> List[str]:
111
- """ Extracts the app features from the user's input. """
112
- # Find the substring "app features are:"
113
- start_index = user_input.find("app features are:") + len("app features are:")
114
- # Find the end of the app features
115
- end_index = user_input.find(".", start_index)
116
- # Extract the app features
117
- app_features_str = user_input[start_index:end_index].strip()
118
- # Split the app features string into a list
119
- app_features = app_features_str.split(", ")
120
- return app_features
121
-
122
- def extract_app_dependencies(self, user_input: str) -> List[str]:
123
- """ Extracts the app dependencies from the user's input. """
124
- # Find the substring "app dependencies are:"
125
- start_index = user_input.find("app dependencies are:") + len("app dependencies are:")
126
- # Find the end of the app dependencies
127
- end_index = user_input.find(".", start_index)
128
- # Extract the app dependencies
129
- app_dependencies_str = user_input[start_index:end_index].strip()
130
- # Split the app dependencies string into a list
131
- app_dependencies = app_dependencies_str.split(", ")
132
- return app_dependencies
133
-
134
- def extract_app_space(self, user_input: str) -> Optional[Space]:
135
- """ Extracts the app space from the user's input. """
136
- # Find the substring "app space is:"
137
- start_index = user_input.find("app space is:") + len("app space is:")
138
- # Find the end of the app space
139
- end_index = user_input.find(".", start_index)
140
- # Extract the app space
141
- app_space_str = user_input[start_index:end_index].strip()
142
- # Create a Space object
143
- app_space = Space(space=app_space_str)
144
- return app_space
145
-
146
- def extract_app_tutorial(self, user_input: str) -> Optional[Tutorial]:
147
- """ Extracts the app tutorial from the user's input. """
148
- # Find the substring "app tutorial is:"
149
- start_index = user_input.find("app tutorial is:") + len("app tutorial is:")
150
- # Find the end of the app tutorial
151
- end_index = user_input.find(".", start_index)
152
- # Extract the app tutorial
153
- app_tutorial_str = user_input[start_index:end_index].strip()
154
- # Create a Tutorial object
155
- app_tutorial = Tutorial(tutorial=app_tutorial_str)
156
- return app_tutorial
157
-
158
- def generate_code(self, prompt: Prompt) -> Code:
159
- """ Generates code using the Llama model. """
160
- # Send the prompt to the Llama model
161
- response = self.client(prompt.prompt)
162
- # Extract the generated code
163
- code = response["generated_text"]
164
- code = code.replace("```", "")
165
- code = code.replace("```", "")
166
- # Create a Code object
167
- code = Code(code=code, language="python")
168
- return code
169
-
170
- def generate_files(self, app_type: AppType, app_name: str, app_description: str, app_features: List[str], app_dependencies: List[str], app_space: Optional[Space] = None, app_tutorial: Optional[Tutorial] = None) -> List[File]:
171
- """ Generates files for the application. """
172
- # Generate files based on the app type
173
- files = self.prompts["generateFiles"](
174
- app_type, app_name, app_description, app_features, app_dependencies, app_space, app_tutorial
175
  )
176
- return files
177
-
178
- def main():
179
- """ Main function for the application. """
180
- # Create an agent
181
- agent = Agent(
182
- prompts={
183
- "createLlamaPrompt": createLlamaPrompt,
184
- "createSpace": createSpace,
185
- "isPythonOrGradioAppPrompt": isPythonOrGradioAppPrompt,
186
- "isReactAppPrompt": isReactAppPrompt,
187
- "isStreamlitAppPrompt": isStreamlitAppPrompt,
188
- "getWebApp": getWebApp,
189
- "getGradioApp": getGradioApp,
190
- "getReactApp": getReactApp,
191
- "getStreamlitApp": getStreamlitApp,
192
- "parseTutorial": parseTutorial,
193
- "generateFiles": generateFiles,
194
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
- # Get user input
198
- user_input = input("Enter your request: ")
199
 
200
- # Process the user's input
201
- response = agent.process(user_input)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
- # Print the response
204
- print(response)
 
 
 
 
 
 
205
 
206
- if __name__ == "__main__":
207
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+
4
+ import openai
5
+
6
+ from agent.prompts import (
7
+ ACTION_PROMPT,
8
+ ADD_PROMPT,
9
+ COMPRESS_HISTORY_PROMPT,
10
+ LOG_PROMPT,
11
+ LOG_RESPONSE,
12
+ MODIFY_PROMPT,
13
+ PREFIX,
14
+ READ_PROMPT,
15
+ TASK_PROMPT,
16
+ UNDERSTAND_TEST_RESULTS_PROMPT,
 
 
 
 
 
 
 
 
 
 
 
17
  )
18
+ from agent.utils import parse_action, parse_file_content, read_python_module_structure
19
 
20
+ VERBOSE = False
21
+ MAX_HISTORY = 100
22
+ MODEL = "gpt-3.5-turbo" # "gpt-4"
 
23
 
 
 
 
 
24
 
25
+ def run_gpt(
26
+ prompt_template,
27
+ stop_tokens,
28
+ max_tokens,
29
+ module_summary,
30
+ purpose,
31
+ **prompt_kwargs,
32
+ ):
33
+ content = PREFIX.format(
34
+ module_summary=module_summary,
35
+ purpose=purpose,
36
+ ) + prompt_template.format(**prompt_kwargs)
37
+ if VERBOSE:
38
+ print(LOG_PROMPT.format(content))
39
+ resp = openai.ChatCompletion.create(
40
+ model=MODEL,
41
+ messages=[
42
+ {"role": "system", "content": content},
43
+ ],
44
+ temperature=0.0,
45
+ max_tokens=max_tokens,
46
+ stop=stop_tokens if stop_tokens else None,
47
+ )["choices"][0]["message"]["content"]
48
+ if VERBOSE:
49
+ print(LOG_RESPONSE.format(resp))
50
+ return resp
51
 
 
 
52
 
53
+ def compress_history(purpose, task, history, directory):
54
+ module_summary, _, _ = read_python_module_structure(directory)
55
+ resp = run_gpt(
56
+ COMPRESS_HISTORY_PROMPT,
57
+ stop_tokens=["observation:", "task:", "action:", "thought:"],
58
+ max_tokens=512,
59
+ module_summary=module_summary,
60
+ purpose=purpose,
61
+ task=task,
62
+ history=history,
63
+ )
64
+ history = "observation: {}\n".format(resp)
65
+ return history
66
+
67
+
68
+ def call_main(purpose, task, history, directory, action_input):
69
+ module_summary, _, _ = read_python_module_structure(directory)
70
+ resp = run_gpt(
71
+ ACTION_PROMPT,
72
+ stop_tokens=["observation:", "task:"],
73
+ max_tokens=256,
74
+ module_summary=module_summary,
75
+ purpose=purpose,
76
+ task=task,
77
+ history=history,
78
+ )
79
+ lines = resp.strip().strip("\n").split("\n")
80
+ for line in lines:
81
+ if line == "":
82
+ continue
83
+ if line.startswith("thought: "):
84
+ history += "{}\n".format(line)
85
+ elif line.startswith("action: "):
86
+ action_name, action_input = parse_action(line)
87
+ history += "{}\n".format(line)
88
+ return action_name, action_input, history, task
89
+ else:
90
+ assert False, "unknown action: {}".format(line)
91
+ return "MAIN", None, history, task
92
 
93
+
94
+ def call_test(purpose, task, history, directory, action_input):
95
+ result = subprocess.run(
96
+ ["python", "-m", "pytest", "--collect-only", directory],
97
+ capture_output=True,
98
+ text=True,
99
+ )
100
+ if result.returncode != 0:
101
+ history += "observation: there are no tests! Test should be written in a test folder under {}\n".format(
102
+ directory
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  )
104
+ return "MAIN", None, history, task
105
+ result = subprocess.run(
106
+ ["python", "-m", "pytest", directory], capture_output=True, text=True
107
+ )
108
+ if result.returncode == 0:
109
+ history += "observation: tests pass\n"
110
+ return "MAIN", None, history, task
111
+ module_summary, content, _ = read_python_module_structure(directory)
112
+ resp = run_gpt(
113
+ UNDERSTAND_TEST_RESULTS_PROMPT,
114
+ stop_tokens=[],
115
+ max_tokens=256,
116
+ module_summary=module_summary,
117
+ purpose=purpose,
118
+ task=task,
119
+ history=history,
120
+ stdout=result.stdout[:5000], # limit amount of text
121
+ stderr=result.stderr[:5000], # limit amount of text
122
+ )
123
+ history += "observation: tests failed: {}\n".format(resp)
124
+ return "MAIN", None, history, task
125
+
126
+
127
+ def call_set_task(purpose, task, history, directory, action_input):
128
+ module_summary, content, _ = read_python_module_structure(directory)
129
+ task = run_gpt(
130
+ TASK_PROMPT,
131
+ stop_tokens=[],
132
+ max_tokens=64,
133
+ module_summary=module_summary,
134
+ purpose=purpose,
135
+ task=task,
136
+ history=history,
137
+ ).strip("\n")
138
+ history += "observation: task has been updated to: {}\n".format(task)
139
+ return "MAIN", None, history, task
140
+
141
+
142
+ def call_read(purpose, task, history, directory, action_input):
143
+ if not os.path.exists(action_input):
144
+ history += "observation: file does not exist\n"
145
+ return "MAIN", None, history, task
146
+ module_summary, content, _ = read_python_module_structure(directory)
147
+ f_content = (
148
+ content[action_input] if content[action_input] else "< document is empty >"
149
  )
150
+ resp = run_gpt(
151
+ READ_PROMPT,
152
+ stop_tokens=[],
153
+ max_tokens=256,
154
+ module_summary=module_summary,
155
+ purpose=purpose,
156
+ task=task,
157
+ history=history,
158
+ file_path=action_input,
159
+ file_contents=f_content,
160
+ ).strip("\n")
161
+ history += "observation: {}\n".format(resp)
162
+ return "MAIN", None, history, task
163
 
 
 
164
 
165
+ def call_modify(purpose, task, history, directory, action_input):
166
+ if not os.path.exists(action_input):
167
+ history += "observation: file does not exist\n"
168
+ return "MAIN", None, history, task
169
+ (
170
+ module_summary,
171
+ content,
172
+ _,
173
+ ) = read_python_module_structure(directory)
174
+ f_content = (
175
+ content[action_input] if content[action_input] else "< document is empty >"
176
+ )
177
+ resp = run_gpt(
178
+ MODIFY_PROMPT,
179
+ stop_tokens=["action:", "thought:", "observation:"],
180
+ max_tokens=2048,
181
+ module_summary=module_summary,
182
+ purpose=purpose,
183
+ task=task,
184
+ history=history,
185
+ file_path=action_input,
186
+ file_contents=f_content,
187
+ )
188
+ new_contents, description = parse_file_content(resp)
189
+ if new_contents is None:
190
+ history += "observation: failed to modify file\n"
191
+ return "MAIN", None, history, task
192
+
193
+ with open(action_input, "w") as f:
194
+ f.write(new_contents)
195
+
196
+ history += "observation: file successfully modified\n"
197
+ history += "observation: {}\n".format(description)
198
+ return "MAIN", None, history, task
199
+
200
+
201
+ def call_add(purpose, task, history, directory, action_input):
202
+ d = os.path.dirname(action_input)
203
+ if not d.startswith(directory):
204
+ history += "observation: files must be under directory {}\n".format(directory)
205
+ elif not action_input.endswith(".py"):
206
+ history += "observation: can only write .py files\n"
207
+ else:
208
+ if d and not os.path.exists(d):
209
+ os.makedirs(d)
210
+ if not os.path.exists(action_input):
211
+ module_summary, _, _ = read_python_module_structure(directory)
212
+ resp = run_gpt(
213
+ ADD_PROMPT,
214
+ stop_tokens=["action:", "thought:", "observation:"],
215
+ max_tokens=2048,
216
+ module_summary=module_summary,
217
+ purpose=purpose,
218
+ task=task,
219
+ history=history,
220
+ file_path=action_input,
221
+ )
222
+ new_contents, description = parse_file_content(resp)
223
+ if new_contents is None:
224
+ history += "observation: failed to write file\n"
225
+ return "MAIN", None, history, task
226
+
227
+ with open(action_input, "w") as f:
228
+ f.write(new_contents)
229
+
230
+ history += "observation: file successfully written\n"
231
+ history += "obsertation: {}\n".format(description)
232
+ else:
233
+ history += "observation: file already exists\n"
234
+ return "MAIN", None, history, task
235
+
236
 
237
+ NAME_TO_FUNC = {
238
+ "MAIN": call_main,
239
+ "UPDATE-TASK": call_set_task,
240
+ "MODIFY-FILE": call_modify,
241
+ "READ-FILE": call_read,
242
+ "ADD-FILE": call_add,
243
+ "TEST": call_test,
244
+ }
245
 
246
+
247
+ def run_action(purpose, task, history, directory, action_name, action_input):
248
+ if action_name == "COMPLETE":
249
+ exit(0)
250
+
251
+ # compress the history when it is long
252
+ if len(history.split("\n")) > MAX_HISTORY:
253
+ if VERBOSE:
254
+ print("COMPRESSING HISTORY")
255
+ history = compress_history(purpose, task, history, directory)
256
+
257
+ assert action_name in NAME_TO_FUNC
258
+
259
+ print("RUN: ", action_name, action_input)
260
+ return NAME_TO_FUNC[action_name](purpose, task, history, directory, action_input)
261
+
262
+
263
+ def run(purpose, directory, task=None):
264
+ history = ""
265
+ action_name = "UPDATE-TASK" if task is None else "MAIN"
266
+ action_input = None
267
+ while True:
268
+ print("")
269
+ print("")
270
+ print("---")
271
+ print("purpose:", purpose)
272
+ print("task:", task)
273
+ print("---")
274
+ print(history)
275
+ print("---")
276
+
277
+ action_name, action_input, history, task = run_action(
278
+ purpose,
279
+ task,
280
+ history,
281
+ directory,
282
+ action_name,
283
+ action_input,
284
+ )