import gradio as gr
from transformers import pipeline
from utils import *
from datasets import load_dataset
import json

# Classification pipeline shared by every prediction request; it is built once
# at import time and handed to `average_text` from `greet`.
pipe = pipeline(model="raminass/94-23", top_k=17, padding=True, truncation=True)

# NOTE(review): `pd` is not imported explicitly in this file; it is presumably
# re-exported by `from utils import *` -- confirm.
# The dataset is bound to `opinions_ds` rather than `all` so the builtin `all`
# is not shadowed for the rest of the module.
opinions_ds = load_dataset("raminass/full_opinions_1994_2020")
df = pd.DataFrame(opinions_ds["train"])

# Dropdown entries: (case name shown to the user, [opinion text, year filed]).
# Only per curiam opinions whose text is longer than 1000 characters are kept.
# Iterating the three needed columns directly avoids building a Series per row
# the way `iterrows()` does.
per_curiam = df[df.category == "per_curiam"]
choices = [
    (str(case_name), [text, year_filed])
    for case_name, text, year_filed in zip(
        per_curiam["case_name"], per_curiam["text"], per_curiam["year_filed"]
    )
    if len(text) > 1000
]

# JSON object keys are always strings; convert them to ints so the mapping can
# be indexed with the numeric value of the year slider.
with open("j_year.json", "r", encoding="utf-8") as j:
    judges_by_year = {int(k): v for k, v in json.load(j).items()}


# https://www.gradio.app/guides/controlling-layout
def greet(opinion, judges_l):
    """Predict the author of an opinion among the selected justices.

    Args:
        opinion: Opinion text taken from the "Opinion" textbox.
        judges_l: Justices selected in the "Select Justices" checkbox group;
            passed through unchanged to ``average_text``.

    Returns:
        The first element of the value returned by ``average_text``, which is
        displayed in the "Predicted author of opinion" label.

    Raises:
        gr.Error: If no opinion text was provided.
    """
    # The Clear button resets the textbox to None, and a user can also press
    # Predict without entering anything. Surface a readable message instead of
    # handing None/blank text to the text helpers.
    if not opinion or not opinion.strip():
        raise gr.Error("Please paste or select an opinion before pressing Predict.")

    # Strip citations, split the text into chunks, and keep the "text" column
    # of the chunked result as a plain list for the pipeline.
    chunks = chunk_data(remove_citations(opinion))["text"].to_list()
    result = average_text(chunks, pipe, judges_l)

    return result[0]


def set_input(drop):
    """Copy the selected dropdown entry into the opinion box and year slider.

    Args:
        drop: Value of the selected dropdown choice. Each choice built at
            module load carries ``[opinion text, year filed]``.

    Returns:
        A 3-tuple: the opinion text, the year filed, and a ``gr.Slider``
        update that marks the slider as visible.
    """
    opinion_text = drop[0]
    year_filed = drop[1]
    return opinion_text, year_filed, gr.Slider(visible=True)


def update_year(year):
    """Rebuild the justice checkbox group for the given year.

    Args:
        year: Key into ``judges_by_year`` (the current slider value).

    Returns:
        A ``gr.CheckboxGroup`` whose choices are that year's justices, all of
        them pre-selected.

    Raises:
        KeyError: If ``year`` has no entry in ``judges_by_year``.
    """
    sitting_justices = judges_by_year[year]
    return gr.CheckboxGroup(
        sitting_justices,
        value=sitting_justices,
        label="Select Justices",
    )


# Two-column layout: inputs (dropdown, year, justices, opinion text) on the
# left; action buttons and the prediction label on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            drop = gr.Dropdown(
                choices=sorted(choices),
                label="List of Per Curiam Opinions",
                info="Select a per curiam opinion from the dropdown menu and press the Predict Button",
            )
            year = gr.Slider(
                1994,
                2023,
                step=1,
                label="Year",
                info="Select the year of the opinion if you manually pass the opinion below",
            )
            # NOTE(review): no initial value is passed to the slider, so
            # `year.value` is expected to be its minimum (1994) -- confirm
            # against the installed Gradio version.
            exc_judg = gr.CheckboxGroup(
                judges_by_year[year.value],
                value=judges_by_year[year.value],
                label="Select Justices",
                info="Select justices to consider in prediction",
            )

            opinion = gr.Textbox(
                label="Opinion", info="Paste opinion text here and press the Predict Button"
            )
        with gr.Column():
            with gr.Row():
                clear_btn = gr.Button("Clear")
                greet_btn = gr.Button("Predict")
            # `gr.Label` replaces the deprecated `gr.outputs.Label`; the
            # keyword arguments are unchanged.
            op_level = gr.Label(
                num_top_classes=9, label="Predicted author of opinion"
            )

    # Keep the justice list in sync with the year, both when the user lets go
    # of the slider and when the value changes (e.g. after a dropdown pick).
    year.release(
        update_year,
        inputs=[year],
        outputs=[exc_judg],
    )
    year.change(
        update_year,
        inputs=[year],
        outputs=[exc_judg],
    )
    # `set_input` returns three values; `year` is listed twice because it
    # receives both the filed year and the visibility update.
    drop.select(set_input, inputs=drop, outputs=[opinion, year, year])

    greet_btn.click(
        fn=greet,
        inputs=[opinion, exc_judg],
        outputs=[op_level],
    )

    # Reset the opinion text, year (back to 1994), dropdown selection and
    # prediction label. Values are positional and must match `outputs`.
    clear_btn.click(
        fn=lambda: [None, 1994, gr.Slider(visible=True), None, None],
        outputs=[opinion, year, year, drop, op_level],
    )


# Start the Gradio server only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    demo.launch()