Spaces:
Runtime error
Runtime error
import gradio as gr
import pandas as pd
from datasets import load_dataset
from transformers import pipeline

from utils import *
# Text-classification pipeline for the opinion-authorship model; top_k=13 asks
# it to return scores for up to 13 labels. Padding/truncation are enabled for
# the tokenizer so long chunks do not overflow the model's input length.
pipe = pipeline(model="raminass/scotus-v10", top_k=13, padding=True, truncation=True)

# Named ``dataset`` (not ``all``) so the builtin ``all()`` is not shadowed.
dataset = load_dataset("raminass/full_opinions_1994_2020")
df = pd.DataFrame(dataset["train"])

# Dropdown entries: (case name, [opinion text, year filed]) for every per
# curiam opinion longer than 1000 characters; shorter texts are skipped.
per_curiam_df = df[df.category == "per_curiam"]
choices = [
    (f"{case_name}", [text, year_filed])
    for case_name, text, year_filed in zip(
        per_curiam_df["case_name"],
        per_curiam_df["text"],
        per_curiam_df["year_filed"],
    )
    if len(text) > 1000
]

# Distinct opinion authors per year (rows authored "per_curiam" excluded);
# each value is the array returned by ``unique()``.
unique_judges_by_year = (
    df[df.author_name != "per_curiam"].groupby("year_filed")["author_name"].unique()
)

# Extra justices offered for 1994 (presumably absent from that year's data —
# TODO confirm). ``dict.fromkeys`` keeps order and drops any name that is
# already present, so the checkbox group never shows a duplicate. Note this
# entry becomes a plain list while the other years stay arrays.
additional_judges = ["Justice Breyer", "Justice Kennedy"]
unique_judges_by_year[1994] = list(
    dict.fromkeys([*unique_judges_by_year[1994], *additional_judges])
)
# Gradio layout guide (Rows/Columns, as used by the Blocks UI below): https://www.gradio.app/guides/controlling-layout
def greet(opinion, judges_l):
    """Predict the likely author of an opinion text.

    Citations are stripped from ``opinion`` and the text is split into chunks.
    The chunks, the model pipeline and the selected judges are then handed to
    ``average_text`` (a ``utils`` helper; presumably it averages the per-chunk
    predictions over the selected judges — see utils).

    Args:
        opinion: Text from the "Opinion" textbox. It may be ``None``, because
            the "Clear" button resets it to ``None``.
        judges_l: Judge names ticked in the "Select Judges" checkbox group.

    Returns:
        The first element of ``average_text``'s result, which is shown in the
        prediction Label. Returns ``None`` (clearing the Label) when there is
        no opinion text.
    """
    # Don't push None or whitespace-only text through the preprocessing
    # helpers; an empty prediction is the sensible answer for empty input.
    if opinion is None or not opinion.strip():
        return None
    chunks = chunk_data(remove_citations(opinion))["text"].to_list()
    result = average_text(chunks, pipe, judges_l)
    return result[0]
def set_input(drop):
    """Fill the form from the per curiam opinion picked in the dropdown.

    ``drop`` is expected to be the value half of a ``choices`` entry, i.e.
    ``[opinion text, year filed]``.

    Returns:
        A 3-tuple: the opinion text, its year, and a Slider update that makes
        the year slider visible.
    """
    opinion_text = drop[0]
    year_filed = drop[1]
    show_slider = gr.Slider(visible=True)
    return opinion_text, year_filed, show_slider
def update_year(year):
    """Rebuild the judge checkbox group for the selected year.

    Every judge recorded for ``year`` becomes an available choice, and all of
    them start out ticked.

    Args:
        year: Value of the year slider.

    Returns:
        A ``gr.CheckboxGroup`` update with the year's judges as both the
        choices and the selected value. It is empty if the year has no data.
    """
    # Use list() rather than .tolist(): the 1994 entry is replaced with a plain
    # Python list at module level, and list has no .tolist(), while the other
    # years hold arrays. .get() avoids a KeyError for years with no opinions.
    judges = list(unique_judges_by_year.get(year, []))
    return gr.CheckboxGroup(
        judges,
        value=list(judges),
        label="Select Judges",
    )
# UI layout: the left column holds the inputs (opinion dropdown, year slider,
# judge checkboxes, opinion text); the right column holds the buttons and the
# prediction label.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            drop = gr.Dropdown(
                choices=sorted(choices),
                label="Per Curiam Opinions",
                info="Select a per curiam opinion to use as input",
            )
            # Explicit starting value, so the initial checkbox contents below
            # are looked up for a known year (1994).
            year = gr.Slider(
                1994,
                2020,
                value=1994,
                step=1,
                label="Year",
                info="Select the year of the opinion if you manually pass the opinion below",
            )
            exc_judg = gr.CheckboxGroup(
                list(unique_judges_by_year[year.value]),
                value=list(unique_judges_by_year[year.value]),
                label="Select Judges",
                info="Select judges to consider in prediction",
            )
            opinion = gr.Textbox(
                label="Opinion", info="Paste opinion text here or select from dropdown"
            )
        with gr.Column():
            with gr.Row():
                clear_btn = gr.Button("Clear")
                greet_btn = gr.Button("Predict")
            # gr.outputs.* was removed in Gradio 4; gr.Label is the component.
            op_level = gr.Label(
                num_top_classes=9, label="Predicted author of opinion"
            )

    # .change fires for user edits and programmatic updates alike, so it alone
    # keeps the judge list in sync. The former extra .release listener only ran
    # update_year a second time.
    year.change(
        update_year,
        inputs=[year],
        outputs=[exc_judg],
    )
    # set_input returns (text, year value, slider visibility update), hence
    # ``year`` appearing twice in the outputs.
    drop.select(set_input, inputs=drop, outputs=[opinion, year, year])
    greet_btn.click(
        fn=greet,
        inputs=[opinion, exc_judg],
        outputs=[op_level],
    )
    # Reset the text, year (value and visibility), dropdown and prediction.
    clear_btn.click(
        fn=lambda: [None, 1994, gr.Slider(visible=True), None, None],
        outputs=[opinion, year, year, drop, op_level],
    )
def _launch():
    """Serve the Gradio app with debug mode enabled."""
    demo.launch(debug=True)


if __name__ == "__main__":
    _launch()