|
import streamlit as st |
|
from streamlit_elements import elements, mui, editor, dashboard |
|
from stqdm import stqdm |
|
import textgrad as tg |
|
import os |
|
|
|
|
|
class MathSolution:
    """Streamlit view that iteratively refines a math solution with TextGrad.

    Typical use within a Streamlit script: construct the object, call
    ``load_layout()`` to draw the input widgets, then call
    ``show_results()`` each time one more optimization step should be run
    and displayed.

    State that must survive Streamlit reruns is kept in ``st.session_state``:

    * ``iteration``        -- number of optimization steps run so far.
    * ``results``          -- one dict per step (iteration, loss_value,
                              response, gradients).
    * ``current_solution`` -- the solution text the next step starts from.
    """

    def __init__(self, data) -> None:
        """Store the configuration and set up the TextGrad engine.

        Args:
            data: Mapping that must provide the keys
                ``"default_initial_solution"`` and
                ``"default_loss_system_prompt"`` (read in ``load_layout``).
        """
        self.data = data
        self.llm_engine = tg.get_engine("gpt-4o")
        # Console marker showing when the object is (re)constructed.
        print("="*50, "init", "="*50)
        self.loss_value = ""
        self.gradients = ""
        # Initialise the per-session counters only once per session.
        if 'iteration' not in st.session_state:
            st.session_state.iteration = 0
            st.session_state.results = []
        tg.set_backward_engine(self.llm_engine, override=True)

    def load_layout(self):
        """Draw the two input text areas and seed the working solution.

        Sets ``self.initial_solution`` and ``self.loss_system_prompt`` from
        the widgets. Until the first optimization step has been run,
        ``st.session_state.current_solution`` follows the "Initial solution"
        text area, so that edits made by the user are what actually gets
        optimized. Afterwards the stored (optimized) solution is left alone.
        """
        left_col, right_col = st.columns([1, 1])
        with left_col:
            self.initial_solution = st.text_area(
                "Initial solution:",
                self.data["default_initial_solution"],
                height=300,
            )
        with right_col:
            self.loss_system_prompt = st.text_area(
                "Loss system prompt:",
                self.data["default_loss_system_prompt"],
                height=300,
            )

        # Previously the working solution was seeded from the default text
        # and never refreshed, so the text area above had no effect.
        no_step_run_yet = (
            "iteration" not in st.session_state
            or st.session_state.iteration == 0
        )
        if "current_solution" not in st.session_state or no_step_run_yet:
            st.session_state.current_solution = self.initial_solution

    def _run(self):
        """Run a single TextGrad step on the current solution.

        Reads ``st.session_state.current_solution`` and
        ``self.loss_system_prompt`` (so ``load_layout`` must have been called
        first). Sets ``self.response``, ``self.loss_value``, ``self.graph``
        and ``self.gradients``, and writes the post-step solution back to
        ``st.session_state.current_solution``.
        """
        current_solution = st.session_state.current_solution

        # The solution is the only variable being optimized.
        self.response = tg.Variable(
            current_solution,
            requires_grad=True,
            role_description="solution to the math question",
        )

        # The loss is driven by the user-editable system prompt.
        loss_fn = tg.TextLoss(
            tg.Variable(
                self.loss_system_prompt,
                requires_grad=False,
                role_description="system prompt",
            )
        )
        optimizer = tg.TGD([self.response])

        loss = loss_fn(self.response)
        self.loss_value = loss.value
        self.graph = loss.generate_graph()

        loss.backward()
        # Captured before the optimizer step; they belong to the solution
        # that was evaluated above.
        self.gradients = self.response.gradients

        optimizer.step()
        # Persist the response value after the step so the next call
        # continues from it.
        st.session_state.current_solution = self.response.value

    def show_results(self):
        """Run one optimization step and render every recorded iteration.

        Appends the new step to ``st.session_state.results``, increments
        ``st.session_state.iteration``, then draws one tab per recorded step
        showing the solution, the loss text and the gradients.
        """
        self._run()
        st.session_state.iteration += 1
        st.session_state.results.append({
            'iteration': st.session_state.iteration,
            'loss_value': self.loss_value,
            'response': self.response.value,
            'gradients': self.gradients,
        })

        # Build the tabs from the stored results so labels and contents
        # cannot get out of step with each other.
        results = st.session_state.results
        tabs = st.tabs([f"Iteration {i+1}" for i in range(len(results))])

        for tab, result in zip(tabs, results):
            with tab:
                st.markdown(f"Current iteration: **{result['iteration']}**")
                st.markdown("## Current solution:")
                st.markdown(result['response'])
                col1, col2 = st.columns([1, 1])
                with col1:
                    st.markdown("## Loss value")
                    st.markdown(result['loss_value'])
                with col2:
                    st.markdown("## Code gradients")
                    for gradient in result['gradients']:
                        st.markdown("### Gradient")
                        st.markdown(gradient.value)