File size: 18,207 Bytes
23e06a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
import asyncio
import json
from typing import Iterable, Tuple

# https://github.com/jerryjliu/llama_index/issues/7244:
asyncio.set_event_loop(asyncio.new_event_loop())

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from st_aggrid import AgGrid
from st_aggrid.grid_options_builder import GridOptionsBuilder
from st_aggrid.shared import GridUpdateMode
from st_aggrid.shared import JsCode
import streamlit as st
from ux.add_logo import add_logo_and_style_overrides
from ux.styles import CATEGORY

from trulens_eval import Tru
from trulens_eval.app import Agent
from trulens_eval.app import ComponentView
from trulens_eval.app import instrumented_component_views
from trulens_eval.app import LLM
from trulens_eval.app import Other
from trulens_eval.app import Prompt
from trulens_eval.app import Tool
from trulens_eval.db import MULTI_CALL_NAME_DELIMITER
from trulens_eval.react_components.record_viewer import record_viewer
from trulens_eval.schema import Record
from trulens_eval.schema import Select
from trulens_eval.utils.json import jsonify_for_ui
from trulens_eval.utils.serial import Lens
from trulens_eval.ux.components import draw_agent_info
from trulens_eval.ux.components import draw_call
from trulens_eval.ux.components import draw_llm_info
from trulens_eval.ux.components import draw_metadata
from trulens_eval.ux.components import draw_prompt_info
from trulens_eval.ux.components import draw_tool_info
from trulens_eval.ux.components import render_selector_markdown
from trulens_eval.ux.components import write_or_json
from trulens_eval.ux.styles import cellstyle_jscode

# Page setup for the "Evaluations" dashboard page.
st.set_page_config(layout="wide", page_title="Evaluations")
st.title("Evaluations")

# Clear Streamlit's legacy cache on every run of this script.
st.runtime.legacy_caching.clear_cache()

add_logo_and_style_overrides()

# `lms` is the database handle exposed by the Tru workspace object.
tru = Tru()
lms = tru.db

df_results, feedback_cols = lms.get_records_and_feedback([])


def _feedback_direction_by_name(feedback_defs):
    """Map each feedback definition's name to its direction label.

    The name is the definition's "supplied_name" when that is present and
    non-empty, otherwise the name of its implementation. The label is
    "HIGHER_IS_BETTER" unless the definition sets "higher_is_better" to a
    falsy value, in which case it is "LOWER_IS_BETTER". If two definitions
    share a name, the later one wins.
    """
    directions = {}
    for _, definition in feedback_defs.iterrows():
        spec = definition.feedback_json
        name = spec.get("supplied_name", "") or spec["implementation"]["name"]
        if spec.get("higher_is_better", True):
            directions[name] = "HIGHER_IS_BETTER"
        else:
            directions[name] = "LOWER_IS_BETTER"
    return directions


# TODO: remove code redundancy / redundant database calls
feedback_directions = _feedback_direction_by_name(lms.get_feedback_defs())
default_direction = "HIGHER_IS_BETTER"


def render_component(query, component, header=True):
    """Draw one app component: optional header, class info, then details.

    Args:
        query: Path of the component within the wrapped app; passed to
            `Select.for_app` for the header and forwarded to the drawing
            routines.
        component: Component view to render. Its `cls` and (for the
            fallback branches) `json` attributes are read.
        header: When true, first write a "Component <selector>" heading.
    """
    # Accessor/path of the component within the wrapped app.
    if header:
        selector_md = render_selector_markdown(Select.for_app(query))
        st.markdown(f"##### Component {selector_md}")

    # Python class of the component, plus its base class when the two differ.
    component_cls = component.cls
    parent_cls = component_cls.base_class()
    class_label = f"__{repr(component_cls)}__"
    if str(parent_cls) != str(component_cls):
        class_label = f"{class_label} < __{repr(parent_cls)}__"
    st.write("Python class: " + class_label)

    # Per-component-type drawing routines, checked in this order; the first
    # matching type wins.
    typed_renderers = (
        (LLM, draw_llm_info),
        (Prompt, draw_prompt_info),
        (Agent, draw_agent_info),
        (Tool, draw_tool_info),
    )
    for component_type, draw in typed_renderers:
        if isinstance(component, component_type):
            draw(component=component, query=query)
            return

    # No dedicated renderer: dump the component json in an expander whose
    # title says whether the component is a known `Other` or fully unhandled.
    if isinstance(component, Other):
        expander_title = "Uncategorized Component Details:"
    else:
        expander_title = "Unhandled Component Details:"
    with st.expander(expander_title):
        st.json(jsonify_for_ui(component.json))


def render_record_metrics(app_df: pd.DataFrame, selected_rows: pd.DataFrame):
    """Render record level metrics (total tokens, cost, latency).

    Cost and latency are shown together with their difference from the mean
    over all rows of `app_df` that share the selected record's app id.

    Args:
        app_df: Records to average over; the "app_id", "total_cost" and
            "latency" columns are read.
        selected_rows: Frame whose row labelled 0 is the selected record;
            the "app_id", "total_tokens", "total_cost" and "latency"
            columns are read.
    """
    selected_app_id = selected_rows["app_id"][0]
    same_app_records = app_df[app_df["app_id"] == selected_app_id]

    tokens_column, cost_column, latency_column = st.columns(3)

    tokens_column.metric(
        label="Total tokens (#)", value=selected_rows["total_tokens"][0]
    )

    record_cost = selected_rows["total_cost"][0]
    cost_vs_mean = record_cost - same_app_records["total_cost"].mean()
    cost_column.metric(
        label="Total cost (USD)",
        value=record_cost,
        delta=f"{cost_vs_mean:.3g}",
        # "inverse" colouring is used for both deltas below.
        delta_color="inverse",
    )

    record_latency = selected_rows["latency"][0]
    latency_vs_mean = record_latency - same_app_records["latency"].mean()
    latency_column.metric(
        label="Latency (s)",
        value=record_latency,
        delta=f"{latency_vs_mean:.3g}s",
        delta_color="inverse",
    )


# Main page body: an application filter, a "Records" tab (grid of records plus
# a detail view for the selected one) and a "Feedback Functions" tab (one
# histogram per feedback column).
if df_results.empty:
    st.write("No records yet...")

else:
    apps = list(df_results.app_id.unique())
    # Preselect the app stored in the session state, if any; otherwise all.
    if "app" in st.session_state:
        app = st.session_state.app
    else:
        app = apps

    st.experimental_set_query_params(app=app)

    options = st.multiselect("Filter Applications", apps, default=app)

    if not options:
        # An empty selection means "no filter".
        st.header("All Applications")
        app_df = df_results
    else:
        if len(options) == 1:
            st.header(options[0])
        else:
            st.header("Multiple Applications Selected")
        app_df = df_results[df_results.app_id.isin(options)]

    tab1, tab2 = st.tabs(["Records", "Feedback Functions"])

    with tab1:
        # Work on a copy: the columns assigned below would otherwise be
        # written into `df_results` itself (no filter) or into a filtered
        # slice of it (which pandas flags with SettingWithCopyWarning).
        evaluations_df = app_df.copy()

        # By default the cells in the df are unicode-escaped, so we have to
        # reverse it. Encoding with latin-1 + backslashreplace (instead of
        # utf-8) keeps characters that are already non-ASCII intact; a utf-8
        # round trip through "unicode-escape" would turn them into mojibake.
        unescape = np.vectorize(
            lambda text: text.encode("latin-1", "backslashreplace").
            decode("unicode-escape")
        )
        evaluations_df["input"] = unescape(evaluations_df["input"].to_numpy())
        evaluations_df["output"] = unescape(
            evaluations_df["output"].to_numpy()
        )

        gb = GridOptionsBuilder.from_dataframe(evaluations_df)

        gb.configure_column("type", header_name="App Type")
        gb.configure_column("record_json", header_name="Record JSON", hide=True)
        gb.configure_column("app_json", header_name="App JSON", hide=True)
        gb.configure_column("cost_json", header_name="Cost JSON", hide=True)
        gb.configure_column("perf_json", header_name="Perf. JSON", hide=True)

        gb.configure_column("record_id", header_name="Record ID", hide=True)
        gb.configure_column("app_id", header_name="App ID")

        gb.configure_column("feedback_id", header_name="Feedback ID", hide=True)
        gb.configure_column("input", header_name="User Input")
        gb.configure_column(
            "output",
            header_name="Response",
        )
        gb.configure_column("total_tokens", header_name="Total Tokens (#)")
        gb.configure_column("total_cost", header_name="Total Cost (USD)")
        gb.configure_column("latency", header_name="Latency (Seconds)")
        gb.configure_column("tags", header_name="Tags")
        gb.configure_column("ts", header_name="Time Stamp", sort="desc")

        # Every column not listed here is configured as a feedback column.
        non_feedback_cols = [
            "app_id",
            "type",
            "ts",
            "total_tokens",
            "total_cost",
            "record_json",
            "latency",
            "record_id",
            "cost_json",
            "app_json",
            "input",
            "output",
            "perf_json",
        ]

        for feedback_col in evaluations_df.columns.drop(non_feedback_cols):
            # The per-feedback "<name>_calls" columns stay hidden in the grid.
            hide_col = feedback_col.endswith("_calls")
            if "distance" in feedback_col:
                gb.configure_column(feedback_col, hide=hide_col)
            else:
                # cell highlight depending on feedback direction
                cellstyle = JsCode(
                    cellstyle_jscode[feedback_directions.get(
                        feedback_col, default_direction
                    )]
                )

                gb.configure_column(
                    feedback_col, cellStyle=cellstyle, hide=hide_col
                )

        gb.configure_pagination()
        gb.configure_side_bar()
        gb.configure_selection(selection_mode="single", use_checkbox=False)
        # gb.configure_default_column(groupable=True, value=True, enableRowGroup=True, aggFunc="sum", editable=True)
        gridOptions = gb.build()
        data = AgGrid(
            evaluations_df,
            gridOptions=gridOptions,
            update_mode=GridUpdateMode.SELECTION_CHANGED,
            allow_unsafe_jscode=True,
        )

        selected_rows = data["selected_rows"]
        selected_rows = pd.DataFrame(selected_rows)

        if len(selected_rows) == 0:
            st.write("Hint: select a row to display details of a record")

        else:
            # Start the record specific section
            st.divider()

            # Breadcrumbs
            st.caption(
                f"{selected_rows['app_id'][0]} / {selected_rows['record_id'][0]}"
            )
            st.header(f"{selected_rows['record_id'][0]}")

            render_record_metrics(app_df, selected_rows)

            st.markdown("")

            prompt = selected_rows["input"][0]
            response = selected_rows["output"][0]
            details = selected_rows["app_json"][0]

            app_json = json.loads(
                details
            )  # apps may not be deserializable, don't try to, keep it json.

            # Selection mode is "single", so the first row is the selection.
            row = selected_rows.iloc[0]

            # Display input/response side by side. In each column, we put them in tabs mainly for
            # formatting/styling purposes.
            input_col, response_col = st.columns(2)

            (input_tab,) = input_col.tabs(["Input"])
            with input_tab:
                with st.expander(
                        f"Input {render_selector_markdown(Select.RecordInput)}",
                        expanded=True):
                    write_or_json(st, obj=prompt)

            (response_tab,) = response_col.tabs(["Response"])
            with response_tab:
                with st.expander(
                        f"Response {render_selector_markdown(Select.RecordOutput)}",
                        expanded=True):
                    write_or_json(st, obj=response)

            feedback_tab, metadata_tab = st.tabs(["Feedback", "Metadata"])

            with metadata_tab:
                metadata = app_json.get("metadata")
                if metadata:
                    with st.expander("Metadata"):
                        st.markdown(draw_metadata(metadata))
                else:
                    st.write("No metadata found")

            with feedback_tab:
                if len(feedback_cols) == 0:
                    st.write("No feedback details")

                for fcol in feedback_cols:
                    feedback_name = fcol
                    feedback_result = row[fcol]

                    # Columns whose name contains the multi-call delimiter
                    # share the "<base name>_calls" column of their base
                    # name (the part before the first delimiter).
                    if MULTI_CALL_NAME_DELIMITER in fcol:
                        fcol = fcol.split(MULTI_CALL_NAME_DELIMITER)[0]
                    feedback_calls = row[f"{fcol}_calls"]

                    def display_feedback_call(call):
                        """Tabulate the calls of one feedback function.

                        `call` is a sequence of mappings with "args", "ret"
                        and "meta" entries, or None/empty when there is
                        nothing to show. Reads `feedback_name` and `fcol` of
                        the current loop iteration; it is invoked right
                        below, within that same iteration.
                        """

                        def highlight(s):
                            # Row-wise styler: one colour for the whole row.
                            if "distance" in feedback_name:
                                return [
                                    f"background-color: {CATEGORY.UNKNOWN.color}"
                                ] * len(s)
                            cat = CATEGORY.of_score(
                                s.result,
                                higher_is_better=feedback_directions.get(
                                    fcol, default_direction
                                ) == default_direction
                            )
                            return [f"background-color: {cat.color}"] * len(s)

                        if call is not None and len(call) > 0:
                            df = pd.DataFrame.from_records(
                                [entry["args"] for entry in call]
                            )
                            # A missing return value is shown as -1.
                            df["result"] = pd.DataFrame(
                                [
                                    float(entry["ret"])
                                    if entry["ret"] is not None else -1
                                    for entry in call
                                ]
                            )
                            df["meta"] = pd.Series(
                                [entry["meta"] for entry in call]
                            )
                            # Expand each meta mapping into its own columns.
                            df = df.join(df.meta.apply(lambda m: pd.Series(m))
                                        ).drop(columns="meta")

                            st.dataframe(
                                df.style.apply(highlight, axis=1).format(
                                    "{:.2}", subset=["result"]
                                )
                            )

                        else:
                            st.text("No feedback details.")

                    with st.expander(f"{feedback_name} = {feedback_result}",
                                     expanded=True):
                        display_feedback_call(feedback_calls)

            record_str = selected_rows["record_json"][0]
            record_json = json.loads(record_str)
            record = Record.model_validate(record_json)

            classes: Iterable[Tuple[Lens, ComponentView]
                             ] = list(instrumented_component_views(app_json))
            classes_map = dict(classes)

            st.markdown("")
            st.subheader("Timeline")
            val = record_viewer(record_json, app_json)
            st.markdown("")

            match_query = None

            # Assumes record_json['perf']['start_time'] is always present
            if val != "":
                # The viewer's value is compared against the ISO-formatted
                # start time of each call; the first equal one is selected.
                match = None
                for call in record.calls:
                    if call.perf.start_time.isoformat() == val:
                        match = call
                        break

                if match:
                    length = len(match.stack)
                    app_call = match.stack[length - 1]

                    match_query = match.top().path

                    st.subheader(
                        f"{app_call.method.obj.cls.name} {render_selector_markdown(Select.for_app(match_query))}"
                    )

                    draw_call(match)

                    view = classes_map.get(match_query)
                    if view is not None:
                        render_component(
                            query=match_query, component=view, header=False
                        )
                    else:
                        st.write(
                            f"Call by `{match_query}` was not associated with any instrumented"
                            " component."
                        )
                        # Look up whether there was any data at that path even if not an instrumented component:

                        try:
                            app_component_json = list(
                                match_query.get(app_json)
                            )[0]
                            if app_component_json is not None:
                                with st.expander(
                                        "Uninstrumented app component details."
                                ):
                                    st.json(app_component_json)
                        except Exception:
                            # Best effort: any failure of the lookup is
                            # reported to the user instead of raised.
                            st.write(
                                f"Recorded invocation by component `{match_query}` but cannot find this component in the app json."
                            )

                else:
                    st.text("No match found")
            else:
                st.subheader(f"App {render_selector_markdown(Select.App)}")
                with st.expander("App Details:"):
                    st.json(jsonify_for_ui(app_json))

            if match_query is not None:
                # Placeholder for the heading; it is only filled in once at
                # least one subcomponent has actually been rendered below.
                container = st.empty()

                has_subcomponents = False
                for query, component in classes:
                    if not match_query.is_immediate_prefix_of(query):
                        continue

                    if len(query.path) == 0:
                        # Skip App, will still list App.app under "app".
                        continue

                    has_subcomponents = True
                    render_component(query, component)

                if has_subcomponents:
                    container.markdown("#### Subcomponents:")

            st.header("More options:")

            if st.button("Display full app json"):
                st.write(jsonify_for_ui(app_json))

            if st.button("Display full record json"):
                st.write(jsonify_for_ui(record_json))

    with tab2:
        feedback = feedback_cols
        cols = 4
        # Ceiling division, so that no empty trailing row of columns is
        # created when the number of feedback columns is a multiple of `cols`.
        rows = (len(feedback) + cols - 1) // cols
        bins = [0, 0.2, 0.4, 0.6, 0.8, 1.0]  # Quintile buckets

        for row_num in range(rows):
            with st.container():
                columns = st.columns(cols)
                for col_num in range(cols):
                    with columns[col_num]:
                        ind = row_num * cols + col_num
                        if ind < len(feedback):
                            # Generate histogram
                            fig, ax = plt.subplots()
                            ax.hist(
                                app_df[feedback[ind]],
                                bins=bins,
                                edgecolor="black",
                                color="#2D736D"
                            )
                            ax.set_xlabel("Feedback Value")
                            ax.set_ylabel("Frequency")
                            ax.set_title(feedback[ind], loc="center")
                            st.pyplot(fig)
                            # Release the figure; otherwise every rerun of
                            # the page leaves one more open figure behind in
                            # pyplot's registry.
                            plt.close(fig)