demo / examples /code_editor_scripts.py
huangzhii
bug fixedd
7d02ab7
raw
history blame
7.09 kB
import html
import os

import streamlit as st
import textgrad as tg
from stqdm import stqdm
from streamlit_elements import elements, mui, editor, dashboard
class CodeEditor:
    """Streamlit page that iteratively optimises a code solution with TextGrad.

    ``data`` is expected to provide the keys ``"default_initial_solution"``,
    ``"default_problem_description"``, ``"default_loss_system_prompt"`` and
    ``"instruction"``; they seed the matching ``st.session_state`` entries only
    when those entries are not already present. All per-user state (current
    code, iteration count, recorded results) is kept in ``st.session_state``.
    """

    def __init__(self, data, engine_name="gpt-4o") -> None:
        """Store *data*, seed session state and configure the LLM engine.

        Args:
            data: Mapping of default texts (see the class docstring).
            engine_name: Name passed to ``tg.get_engine``; the resulting engine
                is used both for the loss call and as the backward engine.
        """
        self.data = data
        # Initialize only if not already set so the original content is retained
        # across Streamlit reruns.
        if 'original_code_content' not in st.session_state:
            st.session_state.original_code_content = self.data["default_initial_solution"]
        self.llm_engine = tg.get_engine(engine_name)
        print("="*50, "init", "="*50)
        # Outputs of the most recent _run(); overwritten on every step.
        self.loss_value = ""
        self.gradients = []
        self.graph = None
        # NOTE(review): not read anywhere in this class; kept in case external code uses it.
        self.code_gradients = ""
        if 'iteration' not in st.session_state:
            st.session_state.iteration = 0
        if 'results' not in st.session_state:
            st.session_state.results = []
        # Feedback (backward) calls use the same engine as the forward loss call.
        tg.set_backward_engine(self.llm_engine, override=True)

    def load_layout(self):
        """Render the problem/prompt inputs and the initial-solution editor.

        Missing session-state fields are seeded from ``self.data``; text-area
        values are written back into ``st.session_state`` on every rerun.
        """
        if 'problem' not in st.session_state:
            st.session_state.problem = self.data["default_problem_description"]
        if 'loss_system_prompt' not in st.session_state:
            st.session_state.loss_system_prompt = self.data["default_loss_system_prompt"]
        if 'instruction' not in st.session_state:
            st.session_state.instruction = self.data["instruction"]

        col1, col2 = st.columns([1, 1])
        with col1:
            st.session_state.problem = st.text_area("Problem description:", st.session_state.problem, height=300)
        with col2:
            st.session_state.loss_system_prompt = st.text_area("Loss system prompt:", st.session_state.loss_system_prompt, height=150)
            st.session_state.instruction = st.text_area("Instruction for formatted LLM call:", st.session_state.instruction, height=100)

        # The working copy of the code that _run() optimises.
        if 'code_content' not in st.session_state:
            st.session_state.code_content = self.data["default_initial_solution"]

        def update_code_content(value):
            # User edits only seed the optimisation: once an iteration has run,
            # code_content holds the optimiser's output and must not be replaced.
            if st.session_state.iteration == 0:
                st.session_state.code_content = value

        # NOTE: col2 is intentionally unused; the editor occupies the left half.
        col1, col2 = st.columns(2)
        with col1:
            with elements("monaco_editors_widget_original"):
                st.markdown("**Initial solution:**")
                code = st.text_area("Edit your code here:", st.session_state.original_code_content, height=300)
                # Compare with the working copy (not the pristine original) so that
                # reverting the text back to the original also resets code_content.
                if code is not None and st.session_state.code_content != code:
                    update_code_content(code)

    @staticmethod
    def _build_format_string(instruction):
        """Return the FormattedLLMCall template for *instruction*.

        The result is ``instruction`` followed by ``Problem: {problem}`` and
        ``Current Code: {code}`` placeholders. Because the template is later
        filled by str.format-style substitution of ``problem``/``code``, any
        literal braces in the free-text instruction are doubled so they survive
        that substitution instead of being treated as replacement fields.
        """
        escaped = instruction.replace("{", "{{").replace("}", "}}")
        return f"{escaped}\nProblem: {{problem}}\nCurrent Code: {{code}}"

    def _run(self):
        """Perform one TextGrad optimisation step on the current code.

        Evaluates ``st.session_state.code_content`` against the problem via an
        LLM "loss" call, runs the backward pass, applies one TGD step and writes
        the updated code back to ``st.session_state.code_content``. Also sets
        ``self.loss_value``, ``self.graph`` and ``self.gradients``.
        """
        # The code is the variable we optimise, so it requires gradients.
        code = tg.Variable(value=st.session_state.code_content,
                           requires_grad=True,
                           role_description="code instance to optimize")
        # The problem is fixed context and is not optimised.
        problem = tg.Variable(st.session_state.problem,
                              requires_grad=False,
                              role_description="the coding problem")
        # Let TGD know to update the code.
        optimizer = tg.TGD(parameters=[code])
        loss_system_prompt = tg.Variable(st.session_state.loss_system_prompt,
                                         requires_grad=False,
                                         role_description="system prompt to the loss function")
        formatted_llm_call = tg.autograd.FormattedLLMCall(
            engine=self.llm_engine,
            format_string=self._build_format_string(st.session_state.instruction),
            fields={"problem": None, "code": None},
            system_prompt=loss_system_prompt,
        )

        def loss_fn(problem: tg.Variable, code: tg.Variable) -> tg.Variable:
            inputs = {"problem": problem, "code": code}
            return formatted_llm_call(inputs=inputs,
                                      response_role_description=f"evaluation of the {code.get_role_description()}")

        loss = loss_fn(problem, code)
        self.loss_value = loss.value
        self.graph = loss.generate_graph()
        loss.backward()
        # Snapshot so the stored history cannot change if the variable's
        # gradient collection is mutated later.
        self.gradients = list(code.gradients)
        optimizer.step()  # Update the code using the gradients.
        st.session_state.code_content = code.value

    @staticmethod
    def _code_block_html(code):
        """Return *code* wrapped in an HTML ``<pre><code>`` block.

        The code is HTML-escaped first: it is LLM-generated and routinely
        contains ``<``, ``>`` or ``&`` (e.g. ``if a < b``), which would otherwise
        be parsed as markup under ``unsafe_allow_html`` and corrupt the display.
        """
        return f'<pre><code class="language-python">{html.escape(code)}</code></pre>'

    def show_results(self):
        """Run one optimisation step, record it, and render every iteration as a tab."""
        self._run()
        st.session_state.iteration += 1
        st.session_state.results.append({
            'iteration': st.session_state.iteration,
            'loss_value': self.loss_value,
            'gradients': self.gradients,
            'code_content': st.session_state.code_content,
        })
        tabs = st.tabs([f"Iteration {i+1}" for i in range(st.session_state.iteration)])
        # Include Highlight.js library and a theme CSS.
        # NOTE(review): Streamlit markdown may not execute <script> tags, in which
        # case hljs never runs; st.code() would highlight natively — confirm.
        st.markdown("""
        <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.5.0/styles/default.min.css">
        <script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.5.0/highlight.min.js"></script>
        <script>hljs.highlightAll();</script>
        """, unsafe_allow_html=True)
        for tab, result in zip(tabs, st.session_state.results):
            with tab:
                st.markdown(f"Current iteration: **{result['iteration']}**")
                st.markdown("### Current solution")
                st.markdown(self._code_block_html(result["code_content"]), unsafe_allow_html=True)
                col1, col2 = st.columns([1, 1])
                with col1:
                    st.markdown("### Loss value")
                    st.markdown("**Loss value is based on previous code.**")
                    st.markdown(result['loss_value'])
                with col2:
                    st.markdown("### Code gradients")
                    for gradient in result['gradients']:
                        st.markdown(gradient.value)