import gradio as gr
from transformers import pipeline
from utils import *
from datasets import load_dataset

# Author-attribution text classifier; top_k=13 keeps up to 13 labels per prediction.
pipe = pipeline(model="raminass/scotus-v10", top_k=13, padding=True, truncation=True)

# Full opinion corpus. Named ``dataset`` (not ``all``) so the builtin ``all()``
# is not shadowed for the rest of the module.
dataset = load_dataset("raminass/full_opinions_1994_2020")
df = dataset["train"].to_pandas()

# Dropdown entries: (case name, [opinion text, year filed]) for every
# per-curiam opinion longer than 1000 characters.
per_curiam_df = df[(df["category"] == "per_curiam") & (df["text"].str.len() > 1000)]
choices = [
    (str(case_name), [text, year_filed])
    for case_name, text, year_filed in zip(
        per_curiam_df["case_name"], per_curiam_df["text"], per_curiam_df["year_filed"]
    )
]

# Distinct authors (excluding "per_curiam") for each filing year.
unique_judges_by_year = (
    df[df.author_name != "per_curiam"].groupby("year_filed")["author_name"].unique()
)
additional_judges = ["Justice Breyer", "Justice Kennedy"]
# Extend 1994 with extra justices. dict.fromkeys drops duplicates (in case a
# name is already present) while keeping first-seen order. The entry is kept
# as a plain list because it is handed directly to gr.CheckboxGroup.
unique_judges_by_year[1994] = list(
    dict.fromkeys(list(unique_judges_by_year[1994]) + additional_judges)
)


# https://www.gradio.app/guides/controlling-layout
def greet(opinion, judges_l):
    """Predict the author of ``opinion``.

    Citations are stripped with ``remove_citations``, the text is split with
    ``chunk_data``, and the chunk texts, the ``pipe`` classifier and
    ``judges_l`` are handed to ``average_text`` (all helpers come from
    ``utils``; presumably ``judges_l`` limits the candidate authors).

    Args:
        opinion: Opinion text from the textbox. It is ``None`` after the Clear
            button resets the form.
        judges_l: Judge names selected in the checkbox group.

    Returns:
        The first element of ``average_text``'s result, shown in the Label
        output. Returns ``None`` when there is no text to classify, which
        leaves the prediction label empty.
    """
    # Nothing to classify: avoid running the citation/chunking helpers on
    # None or whitespace-only input.
    if opinion is None or not opinion.strip():
        return None

    chunks = chunk_data(remove_citations(opinion))["text"].to_list()
    result = average_text(chunks, pipe, judges_l)

    return result[0]


def set_input(drop):
    """Fill the form from the selected dropdown entry.

    ``drop`` is the chosen entry's value, an ``[opinion text, year filed]``
    pair. Returns the text for the opinion box, the year for the slider, and
    a Slider update that keeps the slider visible.
    """
    text_value = drop[0]
    year_value = drop[1]
    keep_visible = gr.Slider(visible=True)
    return text_value, year_value, keep_visible


def update_year(year):
    """Rebuild the judge checkbox group for the selected ``year``.

    Every judge listed for ``year`` in ``unique_judges_by_year`` is offered
    and pre-selected.

    Entries in that table are NumPy arrays from ``groupby(...).unique()``,
    except 1994, which is reassigned as a plain ``list``. ``list(...)``
    works for both; ``.tolist()`` would raise AttributeError on the list.
    Years missing from the table give an empty group.
    """
    judges = list(unique_judges_by_year.get(int(year), []))
    return gr.CheckboxGroup(
        judges,
        value=list(judges),
        label="Select Judges",
    )


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            drop = gr.Dropdown(
                choices=sorted(choices),
                label="Per Curiam Opinions",
                info="Select a per curiam opinion to use as input",
            )
            year = gr.Slider(
                1994,
                2020,
                step=1,
                label="Year",
                info="Select the year of the opinion if you manually pass the opinion below",
            )
            # The table entry may be a NumPy array or a list; pass plain lists
            # so the component always receives the same type.
            initial_judges = list(unique_judges_by_year[year.value])
            exc_judg = gr.CheckboxGroup(
                initial_judges,
                value=list(initial_judges),
                label="Select Judges",
                info="Select judges to consider in prediction",
            )

            opinion = gr.Textbox(
                label="Opinion", info="Paste opinion text here or select from dropdown"
            )
        with gr.Column():
            with gr.Row():
                clear_btn = gr.Button("Clear")
                greet_btn = gr.Button("Predict")
            # gr.Label is the current component; gr.outputs.* is a deprecated
            # namespace that no longer exists in Gradio 4.
            op_level = gr.Label(num_top_classes=9, label="Predicted author of opinion")

    # Refresh the judge list whenever the year changes. `.change` fires after
    # the user finishes dragging and also when another handler sets the year,
    # so no separate `.release` listener is needed.
    year.change(
        update_year,
        inputs=[year],
        outputs=[exc_judg],
    )
    # set_input returns (text, year, slider update); the slider is listed
    # twice so it receives both the new value and the visibility update.
    drop.select(set_input, inputs=drop, outputs=[opinion, year, year])

    greet_btn.click(
        fn=greet,
        inputs=[opinion, exc_judg],
        outputs=[op_level],
    )

    # Reset the textbox, the slider (back to 1994), the dropdown and the prediction.
    clear_btn.click(
        fn=lambda: [None, 1994, gr.Slider(visible=True), None, None],
        outputs=[opinion, year, year, drop, op_level],
    )


if __name__ == "__main__":
    demo.launch(debug=True)