File size: 2,539 Bytes
f80cc50
8ddc567
72f0cff
71645c3
72f0cff
f80cc50
71645c3
 
 
 
926ab2a
 
f80cc50
8ddc567
f80cc50
8ddc567
 
fb7fb6c
 
 
 
 
 
 
8ddc567
fb7fb6c
8ddc567
fb7fb6c
8ddc567
 
 
 
 
 
 
 
 
 
f80cc50
71645c3
926ab2a
71645c3
 
f80cc50
fb7fb6c
 
 
926ab2a
fb7fb6c
926ab2a
 
 
fb7fb6c
71645c3
8ddc567
 
fb7fb6c
 
 
 
 
8ddc567
 
 
926ab2a
71645c3
8ddc567
 
fb7fb6c
8ddc567
 
 
926ab2a
 
 
 
 
 
f80cc50
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gradio as gr
from transformers import pipeline
from utils import *
from datasets import load_dataset

# Hugging Face pipeline for the raminass/scotus-v10 model; the call options
# (top_k, padding, truncation) are fixed once here and reused for every request.
pipe = pipeline(model="raminass/scotus-v10", top_k=13, padding=True, truncation=True)

# Bound to `opinions_dataset` rather than `all` so the builtin all() is not
# shadowed for the rest of the module.
opinions_dataset = load_dataset("raminass/full_opinions_1994_2020")
df = pd.DataFrame(opinions_dataset["train"])

# Dropdown entries as (label, value) pairs: the label is the case name and the
# value is [opinion text, year filed]. Only per-curiam opinions whose text is
# longer than 1000 characters are offered.
choices = [
    (f"{row['case_name']}", [row["text"], row["year_filed"]])
    for _, row in df[df.category == "per_curiam"].iterrows()
    if len(row["text"]) > 1000
]

# Number of (Textbox, Label) output rows pre-built in the UI for per-paragraph
# results.
max_textboxes = 100


# https://www.gradio.app/guides/controlling-layout
def greet(opinion, year):
    """Run the authorship prediction for an opinion, overall and per chunk.

    Args:
        opinion: Opinion text to analyse.
        year: Filing year; candidate authors are the ``author_name`` values
            in ``df`` for that year, excluding ``per_curiam`` rows.

    Returns:
        A list of exactly ``1 + 2 * max_textboxes`` values: ``result[0]``
        (the overall output of ``average_text``) followed by Textbox/Label
        pairs, visible for each displayed chunk and hidden for the rest.
    """
    judges_l = (
        df[(df["year_filed"] == year) & (df["category"] != "per_curiam")]
        .author_name.unique()
        .tolist()
    )

    chunks = chunk_data(remove_citations(opinion))["text"].to_list()
    # result[0] feeds the overall label and result[1][i] the label for chunk i.
    result = average_text(chunks, pipe, judges_l)

    # Only max_textboxes rows exist in the UI. Cap the displayed chunks so the
    # returned list always matches the number of wired output components; an
    # uncapped count made the padding multiplier negative and returned too
    # many values for long opinions.
    shown = min(len(chunks), max_textboxes)

    wrt_boxes = []
    for i in range(shown):
        wrt_boxes.append(gr.Textbox(chunks[i], visible=True))
        wrt_boxes.append(gr.Label(value=result[1][i], visible=True))

    # Hide every row that has no chunk to display.
    hidden_rows = [gr.Textbox(visible=False), gr.Label(visible=False)] * (
        max_textboxes - shown
    )
    return [result[0]] + wrt_boxes + hidden_rows


def set_input(drop):
    """Turn a dropdown selection into values for the opinion and year inputs.

    Args:
        drop: The selected dropdown value; element 0 and element 1 are
            returned as the first two outputs.

    Returns:
        A 3-tuple: ``drop[0]``, ``drop[1]``, and a ``gr.Slider`` update that
        hides the slider.
    """
    opinion_text = drop[0]
    filed_year = drop[1]
    hidden_slider = gr.Slider(visible=False)
    return opinion_text, filed_year, hidden_slider


with gr.Blocks() as demo:
    # Top row: input controls on the left, overall prediction on the right.
    with gr.Row():
        with gr.Column():
            opinion = gr.Textbox(label="Opinion")
            year = gr.Slider(1994, 2020, step=1, label="Year")
            drop = gr.Dropdown(choices=sorted(choices))
            with gr.Row():
                clear_btn = gr.Button("Clear")
                greet_btn = gr.Button("Predict")
        # Use gr.Label directly: the gr.outputs namespace is deprecated and
        # absent from newer Gradio releases, and the per-paragraph labels
        # below already use gr.Label.
        op_level = gr.Label(num_top_classes=13, label="Overall")

    # Pre-build a fixed pool of hidden (Textbox, Label) rows; the predict
    # callback reveals as many as it needs. The list is kept flat and
    # interleaved: [text_0, label_0, text_1, label_1, ...].
    textboxes = []
    for i in range(max_textboxes):
        with gr.Row():
            t = gr.Textbox(f"Textbox {i}", visible=False, label=f"Paragraph {i+1} Text")
            par_level = gr.Label(
                num_top_classes=5, label=f"Paragraph {i+1} Prediction", visible=False
            )
        textboxes.append(t)
        textboxes.append(par_level)

    # `year` appears twice in outputs on purpose: the first slot receives the
    # new value, the second receives the visibility update.
    drop.select(set_input, inputs=drop, outputs=[opinion, year, year])

    greet_btn.click(
        fn=greet,
        inputs=[opinion, year],
        outputs=[op_level] + textboxes,
    )

    # Reset inputs, show the year slider again, and hide every paragraph row.
    clear_btn.click(
        fn=lambda: [None, 1994, gr.Slider(visible=True), None, None]
        + [gr.Textbox(visible=False), gr.Label(visible=False)] * max_textboxes,
        outputs=[opinion, year, year, drop, op_level] + textboxes,
    )


if __name__ == "__main__":
    demo.launch()