File size: 3,107 Bytes
f80cc50
8ddc567
72f0cff
71645c3
8fbc997
72f0cff
5447c0b
1bfa557
 
5447c0b
71645c3
 
 
 
926ab2a
 
f80cc50
8fbc997
 
 
f80cc50
8ddc567
 
9c55aba
8ddc567
fb7fb6c
9c55aba
 
8ddc567
f80cc50
71645c3
9c55aba
 
 
 
 
8fbc997
 
a4466d4
9c55aba
71645c3
 
f80cc50
fb7fb6c
eecd399
9c55aba
 
a4466d4
 
9c55aba
 
 
8fbc997
9c55aba
 
0805275
9c55aba
 
8fbc997
 
f1d21d9
 
9c55aba
 
 
a4466d4
9c55aba
eecd399
926ab2a
 
 
9c55aba
 
fb7fb6c
8ddc567
9c55aba
 
 
 
 
 
 
 
 
 
926ab2a
71645c3
8ddc567
 
9c55aba
 
8ddc567
 
926ab2a
9c55aba
 
926ab2a
 
f80cc50
 
8fbc997
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import gradio as gr
from transformers import pipeline
from utils import *
from datasets import load_dataset
import json

# Hugging Face pipeline for the fine-tuned model; top_k=15 requests up to 15
# scored labels per input, with padding/truncation enabled for long chunks.
# Earlier checkpoints, kept for reference:
# pipe = pipeline(model="raminass/SCOTUS_AI_15", top_k=15, padding=True, truncation=True)
# pipe = pipeline(model="raminass/SCOTUS_AI_V15_CURCUIT", top_k=15, padding=True, truncation=True)
pipe = pipeline(model="raminass/SCOTUS_AI_V15_CURCUIT_V2", top_k=15, padding=True, truncation=True)

# Named `dataset` (not `all`) so the builtin all() is not shadowed for the
# rest of the module.
dataset = load_dataset("raminass/full_opinions_1994_2020")
# NOTE(review): `pd` is not imported here explicitly; presumably it arrives
# through `from utils import *` -- confirm.
df = pd.DataFrame(dataset["train"])

# Dropdown entries: (display name, [opinion text, year filed]) for every
# per curiam opinion whose text is longer than 1000 characters.
per_curiam = df[df.category == "per_curiam"]
choices = [
    (str(row["case_name"]), [row["text"], row["year_filed"]])
    for _, row in per_curiam.iterrows()
    if len(row["text"]) > 1000
]

# JSON object keys are always strings; convert them to int so the mapping can
# be indexed directly with the integer value of the year slider.
with open("j_year.json", "r", encoding="utf-8") as j:
    judges_by_year = {int(k): v for k, v in json.load(j).items()}


# https://www.gradio.app/guides/controlling-layout
def greet(opinion, judges_l):
    """Predict the author of an opinion among the selected justices.

    The opinion is passed through ``remove_citations`` and ``chunk_data``;
    the ``"text"`` column of the chunked result is handed to ``average_text``
    together with the module-level ``pipe`` and the justice selection.

    Args:
        opinion: Opinion text taken from the "Opinion" textbox.
        judges_l: Justices ticked in the "Select Justices" checkbox group.

    Returns:
        The first element of whatever ``average_text`` returns.
        NOTE(review): presumably a label -> score mapping, since it is fed to
        a Label component -- confirm against ``utils.average_text``.
    """
    cleaned_opinion = remove_citations(opinion)
    chunk_texts = chunk_data(cleaned_opinion)["text"].to_list()
    return average_text(chunk_texts, pipe, judges_l)[0]


def set_input(drop):
    """Fan the selected dropdown value out to the opinion box and year slider.

    Args:
        drop: Value of the selected dropdown entry; indexed as
            ``drop[0]`` (opinion text) and ``drop[1]`` (year filed).

    Returns:
        A 3-tuple: the opinion text, the year, and a ``gr.Slider`` update
        that makes the slider visible.
    """
    opinion_text = drop[0]
    year_filed = drop[1]
    return opinion_text, year_filed, gr.Slider(visible=True)


def update_year(year):
    """Rebuild the justice checkbox group for the given year.

    Args:
        year: Year used as the key into ``judges_by_year``.

    Returns:
        A ``gr.CheckboxGroup`` listing that year's justices, all pre-selected.

    Raises:
        KeyError: If ``year`` is not a key of ``judges_by_year``.
    """
    justices = judges_by_year[year]
    return gr.CheckboxGroup(
        choices=justices,
        value=justices,
        label="Select Justices",
    )


# UI layout: a wide left column with the inputs (dropdown, year, justices,
# opinion text) and a narrow right column with the buttons and the result.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=2):
            drop = gr.Dropdown(
                choices=sorted(choices),
                label="List of Per Curiam Opinions",
                info="Select a per curiam opinion from the dropdown menu and press the Predict Button",
            )
            year = gr.Slider(
                1994,
                2023,
                step=1,
                label="Year",
                info="Select the year of the opinion if you manually paste the opinion below",
            )
            # Initial justice list is taken from the slider's starting value.
            exc_judg = gr.CheckboxGroup(
                judges_by_year[year.value],
                value=judges_by_year[year.value],
                label="Select Justices",
                info="Select justices to consider in prediction",
            )

            opinion = gr.Textbox(
                label="Opinion", info="Paste opinion text here and press the Predict Button"
            )
        with gr.Column(scale=1):
            with gr.Row():
                clear_btn = gr.Button("Clear")
                greet_btn = gr.Button("Predict")
            # gr.Label replaces the legacy gr.outputs.Label, which is
            # deprecated and no longer exists in Gradio 4.
            op_level = gr.Label(
                num_top_classes=9, label="Predicted author of opinion"
            )

    # Both slider events refresh the justice checkboxes for the chosen year.
    year.release(
        update_year,
        inputs=[year],
        outputs=[exc_judg],
    )
    year.change(
        update_year,
        inputs=[year],
        outputs=[exc_judg],
    )
    # `year` is listed twice because set_input returns three values: the
    # opinion text, the year value, and a Slider visibility update.
    drop.select(set_input, inputs=drop, outputs=[opinion, year, year])

    greet_btn.click(
        fn=greet,
        inputs=[opinion, exc_judg],
        outputs=[op_level],
    )

    # Reset: empty the opinion, put the year back to 1994, clear the dropdown
    # selection and the prediction.
    clear_btn.click(
        fn=lambda: [None, 1994, gr.Slider(visible=True), None, None],
        outputs=[opinion, year, year, drop, op_level],
    )


# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()