File size: 2,141 Bytes
f80cc50
8ddc567
72f0cff
71645c3
72f0cff
f80cc50
71645c3
 
 
 
fb7fb6c
f80cc50
8ddc567
f80cc50
8ddc567
 
fb7fb6c
 
 
 
 
 
 
8ddc567
fb7fb6c
8ddc567
fb7fb6c
8ddc567
 
 
 
 
 
 
 
 
 
f80cc50
71645c3
fb7fb6c
71645c3
 
f80cc50
fb7fb6c
 
 
 
 
 
 
71645c3
8ddc567
 
fb7fb6c
 
 
 
 
8ddc567
 
 
fb7fb6c
71645c3
8ddc567
 
fb7fb6c
8ddc567
 
 
f80cc50
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import gradio as gr
from transformers import pipeline
from utils import *
from datasets import load_dataset

# Text-classification pipeline returning scores for the top 13 labels.
pipe = pipeline(model="raminass/scotus-v10", top_k=13, padding=True, truncation=True)
# Bound to `dataset` (not `all`) so the builtin all() is not shadowed.
dataset = load_dataset("raminass/full_opinions_1994_2020")
df = pd.DataFrame(dataset["train"])

# Dropdown entries for per-curiam opinions: (case name, [text, year_filed]).
# Zipping the columns avoids building a Series per row as iterrows() does;
# Series iteration still yields plain Python scalars.
per_curiam = df[df.category == "per_curiam"]
choices = [
    (str(case_name), [text, year_filed])
    for case_name, text, year_filed in zip(
        per_curiam["case_name"], per_curiam["text"], per_curiam["year_filed"]
    )
]

# Number of pre-allocated paragraph (text, prediction) slots in the UI.
max_textboxes = 100


# https://www.gradio.app/guides/controlling-layout
def greet(opinion, year):
    """Predict the author of an opinion, overall and per text chunk.

    Args:
        opinion: Opinion text entered by the user.
        year: Filing year from the slider. Candidate judges are the distinct
            ``author_name`` values of non-per-curiam rows in ``df`` filed in
            that year.

    Returns:
        A flat list matching the click handler's outputs: the overall
        prediction first, then one (Textbox, Label) update pair for each of
        the ``max_textboxes`` paragraph slots. Filled slots are made visible
        and the remaining slots are hidden.
    """
    judges_l = (
        df[(df["year_filed"] == year) & (df["category"] != "per_curiam")]
        .author_name.unique()
        .tolist()
    )

    # NOTE(review): remove_citations / chunk_data come from utils; presumably
    # they strip citations and split the text into paragraph-sized chunks.
    chunks = chunk_data(remove_citations(opinion))["text"].to_list()
    # result[0] feeds the "Overall" label; result[1][i] is chunk i's prediction.
    result = average_text(chunks, pipe, judges_l)

    # The UI has only max_textboxes slots, and Gradio requires exactly one
    # value per wired output. Without this cap, a long opinion would make the
    # padding count negative (so no padding) and overflow the outputs.
    # Every chunk still contributes to the overall prediction above.
    shown = min(len(chunks), max_textboxes)

    updates = []
    for i, chunk in enumerate(chunks[:shown]):
        updates.append(gr.Textbox(chunk, visible=True))
        updates.append(gr.Label(value=result[1][i], visible=True))
    padding = [gr.Textbox(visible=False), gr.Label(visible=False)] * (
        max_textboxes - shown
    )
    return [result[0]] + updates + padding


def set_input(drop):
    """Unpack a selected dropdown value into the opinion/year inputs.

    Args:
        drop: Value of the selected dropdown choice. The module's choices list
            stores it as ``[opinion_text, year_filed]``.

    Returns:
        ``(opinion_text, year_filed)``, for the Opinion textbox and Year slider.
    """
    opinion_text = drop[0]
    year_filed = drop[1]
    return opinion_text, year_filed


with gr.Blocks() as demo:
    with gr.Row():
        # Inputs: opinion text, filing year (restricts candidate judges), and
        # a picker that pre-fills both from a per-curiam opinion.
        with gr.Column():
            opinion = gr.Textbox(label="Opinion")
            year = gr.Slider(1994, 2020, label="Year")
            drop = gr.Dropdown(choices=sorted(choices))
            greet_btn = gr.Button("Predict")
        # gr.outputs.* is the legacy Interface API (deprecated in Gradio 3,
        # removed in Gradio 4); gr.Label is the Blocks equivalent.
        # num_top_classes=13 matches the pipeline's top_k.
        op_level = gr.Label(num_top_classes=13, label="Overall")

    # Fixed pool of max_textboxes (text, prediction) row pairs, hidden at
    # start. greet() returns an update for every slot, so its output list
    # must line up with this flat [text, label, text, label, ...] order.
    textboxes = []
    for i in range(max_textboxes):
        with gr.Row():
            t = gr.Textbox(f"Textbox {i}", visible=False, label=f"Paragraph {i+1} Text")
            par_level = gr.Label(
                num_top_classes=5, label=f"Paragraph {i+1} Prediction", visible=False
            )
        textboxes.extend((t, par_level))

    # Choosing a per-curiam case copies its text and year into the inputs.
    drop.select(set_input, inputs=drop, outputs=[opinion, year])

    greet_btn.click(
        fn=greet,
        inputs=[opinion, year],
        outputs=[op_level] + textboxes,
    )


if __name__ == "__main__":
    demo.launch()