|
import html
import os

import streamlit as st
from streamlit_elements import elements, mui, editor, dashboard
from stqdm import stqdm
import textgrad as tg
|
|
|
class CodeEditor:
    """Streamlit widget that iteratively optimizes a code solution with TextGrad."""

    def __init__(self, data, engine_name: str = "gpt-4o") -> None:
        """Initialize session state and the TextGrad forward/backward engines.

        Args:
            data: Mapping providing at least "default_initial_solution"
                (plus the prompt defaults read later by load_layout).
            engine_name: Name of the LLM engine used for both the loss
                (forward) and gradient (backward) passes. Defaults to
                "gpt-4o", matching the previous hard-coded behavior.
        """
        self.data = data

        # Seed the editable "original" code exactly once per Streamlit session.
        if 'original_code_content' not in st.session_state:
            st.session_state.original_code_content = self.data["default_initial_solution"]

        self.llm_engine = tg.get_engine(engine_name)

        # Per-run display fields; populated by _run().
        self.loss_value = ""
        self.code_gradients = ""

        # Optimization progress must survive Streamlit reruns, so it lives
        # in session_state rather than on the instance.
        if 'iteration' not in st.session_state:
            st.session_state.iteration = 0
        if 'results' not in st.session_state:
            st.session_state.results = []

        # Same engine drives backpropagation of textual gradients.
        tg.set_backward_engine(self.llm_engine, override=True)
|
|
|
|
|
def load_layout(self): |
|
|
|
if 'problem' not in st.session_state: |
|
st.session_state.problem = self.data["default_problem_description"] |
|
if 'loss_system_prompt' not in st.session_state: |
|
st.session_state.loss_system_prompt = self.data["default_loss_system_prompt"] |
|
if 'instruction' not in st.session_state: |
|
st.session_state.instruction = self.data["instruction"] |
|
|
|
col1, col2 = st.columns([1, 1]) |
|
with col1: |
|
st.session_state.problem = st.text_area("Problem description:", st.session_state.problem, height=300) |
|
with col2: |
|
st.session_state.loss_system_prompt = st.text_area("Loss system prompt:", st.session_state.loss_system_prompt, height=150) |
|
st.session_state.instruction = st.text_area("Instruction for formatted LLM call:", st.session_state.instruction, height=100) |
|
|
|
|
|
if 'code_content' not in st.session_state: |
|
st.session_state.code_content = self.data["default_initial_solution"] |
|
|
|
def update_code_content(value): |
|
if st.session_state.iteration == 0: |
|
st.session_state.code_content = value |
|
|
|
|
|
col1, col2 = st.columns(2) |
|
with col1: |
|
with elements("monaco_editors_widget_original"): |
|
st.markdown(f"**Initial solution:**") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
code = st.text_area("Edit your code here:", st.session_state.original_code_content, height=300) |
|
|
|
if code is not None and st.session_state.original_code_content != code: |
|
update_code_content(code) |
|
|
|
|
|
|
|
|
|
|
|
def _run(self): |
|
|
|
solution = st.session_state.code_content |
|
code = tg.Variable(value=solution, |
|
requires_grad=True, |
|
role_description="code instance to optimize") |
|
|
|
|
|
problem = tg.Variable(st.session_state.problem, |
|
requires_grad=False, |
|
role_description="the coding problem") |
|
|
|
|
|
optimizer = tg.TGD(parameters=[code]) |
|
|
|
|
|
instruction = st.session_state.instruction |
|
llm_engine = self.llm_engine |
|
loss_system_prompt = st.session_state.loss_system_prompt |
|
loss_system_prompt = tg.Variable(loss_system_prompt, requires_grad=False, role_description="system prompt to the loss function") |
|
|
|
format_string = "{instruction}\nProblem: {{problem}}\nCurrent Code: {{code}}" |
|
format_string = format_string.format(instruction=st.session_state.instruction) |
|
|
|
fields = {"problem": None, "code": None} |
|
formatted_llm_call = tg.autograd.FormattedLLMCall(engine=self.llm_engine, |
|
format_string=format_string, |
|
fields=fields, |
|
system_prompt=loss_system_prompt) |
|
|
|
def loss_fn(problem: tg.Variable, code: tg.Variable) -> tg.Variable: |
|
inputs = {"problem": problem, "code": code} |
|
|
|
return formatted_llm_call(inputs=inputs, |
|
response_role_description=f"evaluation of the {code.get_role_description()}") |
|
loss = loss_fn(problem, code) |
|
self.loss_value = loss.value |
|
self.graph = loss.generate_graph() |
|
|
|
loss.backward() |
|
self.gradients = code.gradients |
|
optimizer.step() |
|
|
|
st.session_state.code_content = code.value |
|
|
|
def show_results(self): |
|
self._run() |
|
st.session_state.iteration += 1 |
|
st.session_state.results.append({ |
|
'iteration': st.session_state.iteration, |
|
'loss_value': self.loss_value, |
|
'gradients': self.gradients, |
|
'code_content': st.session_state.code_content, |
|
}) |
|
|
|
tabs = st.tabs([f"Iteration {i+1}" for i in range(st.session_state.iteration)]) |
|
|
|
|
|
st.markdown(""" |
|
<link rel="stylesheet" |
|
href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.5.0/styles/default.min.css"> |
|
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.5.0/highlight.min.js"></script> |
|
<script>hljs.highlightAll();</script> |
|
""", unsafe_allow_html=True) |
|
|
|
for i, tab in enumerate(tabs): |
|
with tab: |
|
result = st.session_state.results[i] |
|
st.markdown(f"Current iteration: **{result['iteration']}**") |
|
|
|
st.markdown("### Current solution") |
|
st.markdown(f""" |
|
<pre><code class="language-python">{result["code_content"]}</code></pre> |
|
""", unsafe_allow_html=True) |
|
|
|
|
|
col1, col2 = st.columns([1, 1]) |
|
with col1: |
|
st.markdown("### Loss value") |
|
st.markdown("**Loss value is based on previous code.**") |
|
st.markdown(result['loss_value']) |
|
with col2: |
|
st.markdown("### Code gradients") |
|
for j, g in enumerate(result['gradients']): |
|
|
|
st.markdown(g.value) |