File size: 4,078 Bytes
32486dc
 
 
 
 
 
 
 
 
 
7d02ab7
 
 
 
 
 
 
 
 
32486dc
 
 
 
 
 
 
 
 
 
7d02ab7
 
 
 
 
 
32486dc
 
7d02ab7
32486dc
7d02ab7
 
 
 
 
 
32486dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d02ab7
 
32486dc
 
7d02ab7
32486dc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import streamlit as st
from streamlit_elements import elements, mui, editor, dashboard
from stqdm import stqdm
import textgrad as tg
import os


class MathSolution:
    """Streamlit view that iteratively refines a math solution with TextGrad.

    Each call to ``show_results`` runs one TextGrad optimization step on the
    current solution and renders every iteration recorded so far in its own
    tab. Values that must survive Streamlit script reruns (solutions, prompts,
    iteration counter, per-iteration results) are kept in ``st.session_state``.

    Args:
        data: Mapping of defaults. Keys read here: ``"default_initial_solution"``,
            ``"default_loss_system_prompt"`` and ``"instruction"``.
    """

    def __init__(self, data) -> None:
        self.data = data
        # Seed session state only the first time; later reruns keep whatever
        # the user or the optimizer has written since.
        if 'default_initial_solution' not in st.session_state:
            st.session_state.default_initial_solution = self.data["default_initial_solution"]
        if 'loss_system_prompt' not in st.session_state:
            st.session_state.loss_system_prompt = self.data["default_loss_system_prompt"]
        if 'instruction' not in st.session_state:
            st.session_state.instruction = self.data["instruction"]
        if 'current_solution' not in st.session_state:
            st.session_state.current_solution = self.data["default_initial_solution"]

        self.llm_engine = tg.get_engine("gpt-4o")
        print("="*50, "init", "="*50)
        self.loss_value = ""
        self.gradients = ""
        # Fallback so _run() works even if load_layout() has not been called
        # on this instance; load_layout() overwrites it with the widget value.
        self.loss_system_prompt = st.session_state.loss_system_prompt
        if 'iteration' not in st.session_state:
            st.session_state.iteration = 0
        # Must persist together with 'iteration'. Resetting this list on every
        # construction while the counter survives left fewer stored results
        # than iterations, and show_results() then indexed past the end.
        if 'results' not in st.session_state:
            st.session_state.results = []
        tg.set_backward_engine(self.llm_engine, override=True)

    def load_layout(self):
        """Render the "Initial solution" and "Loss system prompt" text areas.

        Before any optimization step has run (``iteration == 0``), the text in
        the initial-solution box becomes ``current_solution``, i.e. the input
        of the next step. The loss-prompt text is stored on
        ``self.loss_system_prompt`` for ``_run``.
        """
        col1, col2 = st.columns([1, 1])
        with col1:
            solution = st.text_area("Initial solution:", st.session_state.default_initial_solution, height=300)
            # Compare with current_solution (not the default) so reverting the
            # box back to the default text is picked up too. Once iterations
            # have run, current_solution holds optimizer output and the box no
            # longer drives it.
            if (solution is not None
                    and st.session_state.iteration == 0
                    and st.session_state.current_solution != solution):
                st.session_state.current_solution = solution
        with col2:
            self.loss_system_prompt = st.text_area("Loss system prompt:", st.session_state.loss_system_prompt, height=300)

    def _run(self):
        """Run one TextGrad step on ``st.session_state.current_solution``.

        The step has four stages:
        - Evaluate the text loss defined by ``self.loss_system_prompt``.
        - Backpropagate textual feedback into the solution variable.
        - Apply one ``tg.TGD`` update.
        - Write the updated solution back to session state.

        ``self.loss_value`` therefore describes the solution *before* the
        update.
        """
        current_solution = st.session_state.current_solution

        self.response = tg.Variable(current_solution,
                                    requires_grad=True,
                                    role_description="solution to the math question")

        loss_fn = tg.TextLoss(tg.Variable(self.loss_system_prompt,
                                          requires_grad=False,
                                          role_description="system prompt"))
        optimizer = tg.TGD([self.response])

        loss = loss_fn(self.response)
        self.loss_value = loss.value
        self.graph = loss.generate_graph()

        loss.backward()
        self.gradients = self.response.gradients

        optimizer.step()  # apply the textual-gradient update to the solution
        st.session_state.current_solution = self.response.value

    def show_results(self):
        """Run one optimization step, record it, and render all iterations.

        Each recorded iteration gets its own tab. A tab shows the solution,
        the loss text, and every gradient's text.
        """
        self._run()
        st.session_state.iteration += 1
        st.session_state.results.append({
            'iteration': st.session_state.iteration,
            'loss_value': self.loss_value,
            'response': self.response.value,
            'gradients': self.gradients
        })

        results = st.session_state.results
        # One tab per stored result, so tabs and results can never disagree.
        tabs = st.tabs([f"Iteration {result['iteration']}" for result in results])

        for tab, result in zip(tabs, results):
            with tab:
                st.markdown(f"Current iteration: **{result['iteration']}**")
                st.markdown("## Current solution:")
                st.markdown(result['response'])
                col1, col2 = st.columns([1, 1])
                with col1:
                    st.markdown("### Loss value")
                    st.markdown("**Loss value is based on previous code.**")
                    st.markdown(result['loss_value'])
                with col2:
                    st.markdown("### Code gradients")
                    for gradient in result['gradients']:
                        st.markdown("### Gradient")
                        st.markdown(gradient.value)