# demo / examples / example_math_scripts.py
# huangzhii
# bug fixed
# 7d02ab7
# raw
# history blame
# 4.08 kB
import streamlit as st
from streamlit_elements import elements, mui, editor, dashboard
from stqdm import stqdm
import textgrad as tg
import os
class MathSolution:
    """Streamlit demo page that iteratively refines a math solution with TextGrad.

    The initial solution, loss system prompt and instruction are copied from
    ``data`` into ``st.session_state`` the first time they are missing, so
    later Streamlit reruns keep the user's progress.

    Args:
        data: mapping that must provide the keys ``"default_initial_solution"``,
            ``"default_loss_system_prompt"`` and ``"instruction"``.
    """

    def __init__(self, data) -> None:
        self.data = data
        # Seed session state only when a key is absent so values that already
        # exist from an earlier rerun are not overwritten.
        if 'default_initial_solution' not in st.session_state:
            st.session_state.default_initial_solution = self.data["default_initial_solution"]
        if 'loss_system_prompt' not in st.session_state:
            st.session_state.loss_system_prompt = self.data["default_loss_system_prompt"]
        if 'instruction' not in st.session_state:
            st.session_state.instruction = self.data["instruction"]
        if 'current_solution' not in st.session_state:
            st.session_state.current_solution = self.data["default_initial_solution"]

        self.llm_engine = tg.get_engine("gpt-4o")
        print("="*50, "init", "="*50)
        self.loss_value = ""
        self.gradients = ""
        if 'iteration' not in st.session_state:
            st.session_state.iteration = 0
        # Guarded separately: show_results() appends to this list regardless of
        # how `iteration` came to be in session state.
        if 'results' not in st.session_state:
            st.session_state.results = []
        # Register the engine TextGrad uses for the backward pass
        # (loss.backward() in _run).
        tg.set_backward_engine(self.llm_engine, override=True)

    def load_layout(self):
        """Render the editable initial solution and loss system prompt side by side.

        While no optimisation step has run yet (iteration 0), the text in the
        "Initial solution" box is copied into ``st.session_state.current_solution``.
        After the first step the optimised solution is kept instead. The loss
        system prompt text is stored on ``self.loss_system_prompt`` for ``_run``.
        """
        def update_solution_content(value):
            # Only the pre-optimisation solution is user-editable.
            if st.session_state.iteration == 0:
                st.session_state.current_solution = value

        col1, col2 = st.columns([1, 1])
        with col1:
            solution = st.text_area("Initial solution:", st.session_state.default_initial_solution, height=300)
            # Sync on every rerun rather than only when the text differs from
            # the default: otherwise editing the text and then restoring the
            # default would leave the stale edit in current_solution.
            if solution is not None:
                update_solution_content(solution)
        with col2:
            self.loss_system_prompt = st.text_area("Loss system prompt:", st.session_state.loss_system_prompt, height=300)

    def _run(self):
        """Perform one TextGrad optimisation step on the current solution.

        Builds a ``tg.TextLoss`` from the loss system prompt, evaluates the
        current solution, backpropagates, applies one ``tg.TGD`` step and writes
        the updated text back to ``st.session_state.current_solution``.

        Side effects: sets ``self.response``, ``self.loss_value``,
        ``self.graph`` and ``self.gradients``.
        """
        # load_layout() normally supplies the prompt from the text area; fall
        # back to the stored prompt so this also works if the layout was not
        # rendered first.
        loss_system_prompt = getattr(self, "loss_system_prompt", st.session_state.loss_system_prompt)
        current_solution = st.session_state.current_solution

        self.response = tg.Variable(current_solution,
                                    requires_grad=True,
                                    role_description="solution to the math question")
        loss_fn = tg.TextLoss(tg.Variable(loss_system_prompt,
                                          requires_grad=False,
                                          role_description="system prompt"))
        optimizer = tg.TGD([self.response])

        loss = loss_fn(self.response)
        self.loss_value = loss.value
        self.graph = loss.generate_graph()

        loss.backward()
        # Record the feedback for this step; optimizer.step() then rewrites
        # self.response.value with the updated solution.
        self.gradients = self.response.gradients
        optimizer.step()
        st.session_state.current_solution = self.response.value

    def show_results(self):
        """Run one optimisation step, record it, and render every iteration as a tab."""
        self._run()
        st.session_state.iteration += 1
        st.session_state.results.append({
            'iteration': st.session_state.iteration,
            'loss_value': self.loss_value,
            'response': self.response.value,
            'gradients': self.gradients
        })

        results = st.session_state.results
        # One tab per recorded result, so tabs and results can never disagree.
        tabs = st.tabs([f"Iteration {i+1}" for i in range(len(results))])
        for tab, result in zip(tabs, results):
            with tab:
                st.markdown(f"Current iteration: **{result['iteration']}**")
                st.markdown("## Current solution:")
                st.markdown(result['response'])
                col1, col2 = st.columns([1, 1])
                with col1:
                    st.markdown("### Loss value")
                    st.markdown("**Loss value is based on previous code.**")
                    st.markdown(result['loss_value'])
                with col2:
                    st.markdown("### Code gradients")
                    for gradient in result['gradients']:
                        st.markdown("### Gradient")
                        st.markdown(gradient.value)