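# Gradio demo that flags redacted pages in an uploaded PDF: a fastai classifier
# scores each page, and an IceVision VFNet detector draws the redaction boxes
# for an optional downloadable report.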
import os
import tempfile

import fitz
import gradio as gr
import PIL
import skimage
from fastai.learner import load_learner
from fastai.vision.all import *
from fpdf import FPDF
from huggingface_hub import hf_hub_download
from icevision.all import *
from icevision.models.checkpoint import *
from PIL import Image as PILImage

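# Load the IceVision VFNet detection checkpoint used to localise individual
# redactions on a page, along with its class map and inference transforms.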
checkpoint_path = "./2022-01-15-vfnet-post-self-train.pth"
checkpoint_and_model = model_from_checkpoint(checkpoint_path)
model = checkpoint_and_model["model"]
model_type = checkpoint_and_model["model_type"]
class_map = checkpoint_and_model["class_map"]

img_size = checkpoint_and_model["img_size"]
valid_tfms = tfms.A.Adapter(
    [*tfms.A.resize_and_pad(img_size), tfms.A.Normalize()]
)


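# Download the fastai page classifier (redacted vs. unredacted) from the
# Hugging Face Hub.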
learn = load_learner(
    hf_hub_download("strickvl/redaction-classifier-fastai", "model.pkl")
)

labels = learn.dls.vocab


def predict(pdf, confidence, generate_file):
    filename_without_extension = pdf.name[:-4]  # strip the ".pdf" suffix
    document = fitz.open(pdf.name)
    results = []
    images = []
    tmp_dir = tempfile.gettempdir()
    # Page renders and the optional report are written to a per-document
    # directory under the system temp dir; make sure it exists.
    os.makedirs(
        os.path.join(tmp_dir, filename_without_extension), exist_ok=True
    )
    for page_num, page in enumerate(document, start=1):
        # Render the page to a bitmap and classify it with the fastai model.
        image_pixmap = page.get_pixmap()
        image = image_pixmap.tobytes()
        _, _, probs = learn.predict(image)
        results.append(
            {labels[i]: float(probs[i]) for i in range(len(labels))}
        )
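        # probs[0] is the probability of the first vocab label, assumed here
        # to be "redacted" (matching the per-page check further below).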
        if probs[0] > (confidence / 100):
            redaction_count = len(images)
            image_pixmap.save(
                os.path.join(
                    tmp_dir, filename_without_extension, f"page-{page_num}.png"
                )
            )
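            # Each carousel entry pairs a caption with the saved page image,
            # matching the Carousel(["text", "image"]) output below.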
            images.append(
                [
                    f"Redacted page #{redaction_count + 1} on page {page_num}",
                    os.path.join(
                        tmp_dir,
                        filename_without_extension,
                        f"page-{page_num}.png",
                    ),
                ]
            )

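    # Collect the 1-indexed numbers of pages whose "redacted" probability
    # clears the confidence threshold.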
    redacted_pages = [
        str(page + 1)
        for page in range(len(results))
        if results[page]["redacted"] > (confidence / 100)
    ]
    report = os.path.join(
        tmp_dir, filename_without_extension, "redacted_pages.pdf"
    )
    if generate_file:
        # Build a PDF report from the annotated page images.
        pdf = FPDF()
        pdf.set_auto_page_break(0)
        imagelist = sorted(
            [
                i
                for i in os.listdir(
                    os.path.join(tmp_dir, filename_without_extension)
                )
                if i.endswith("png")
            ]
        )
        for image in imagelist:
            with PILImage.open(
                os.path.join(tmp_dir, filename_without_extension, image)
            ) as img:
                size = img.size
                if size[0] > size[1]:
                    pdf.add_page("L")
                else:
                    pdf.add_page("P")
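                # Run the VFNet detector over the page image and draw the
                # predicted redaction boxes and labels onto it.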
                pred_dict = model_type.end2end_detect(
                    img,
                    valid_tfms,
                    model,
                    class_map=class_map,
                    detection_threshold=0.7,
                    display_label=True,
                    display_bbox=True,
                    return_img=True,
                    font_size=16,
                    label_color="#FF59D6",
                )
                pred_dict["img"].save(
                    os.path.join(
                        tmp_dir, filename_without_extension, f"pred-{image}"
                    )
                )
            pdf.image(
                os.path.join(
                    tmp_dir, filename_without_extension, f"pred-{image}"
                )
            )
        pdf.output(report, "F")

    text_output = f"A total of {len(redacted_pages)} pages were redacted. \n\n The redacted page numbers were: {', '.join(redacted_pages)}."

    if generate_file:
        return text_output, images, report
    else:
        return text_output, images, None


title = "Redaction Detector"

description = "A classifier trained on publicly released redacted (and unredacted) FOIA documents, using [fastai](https://github.com/fastai/fastai)."

with open("article.md") as f:
    article = f.read()

examples = [["test1.pdf", 80, False], ["test2.pdf", 80, False]]
interpretation = "default"
enable_queue = True
theme = "grass"
allow_flagging = "never"

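# Legacy (pre-3.x) Gradio Interface API: the file upload, confidence slider
# and "generate report" checkbox map onto predict()'s three arguments.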
demo = gr.Interface(
    fn=predict,
    inputs=[
        "file",
        gr.inputs.Slider(
            minimum=0,
            maximum=100,
            step=None,
            default=80,
            label="Confidence",
            optional=False,
        ),
        "checkbox",
    ],
    outputs=[
        gr.outputs.Textbox(label="Document Analysis"),
        gr.outputs.Carousel(["text", "image"], label="Redacted pages"),
        gr.outputs.File(label="Download redacted pages"),
    ],
    title=title,
    description=description,
    article=article,
    theme=theme,
    allow_flagging=allow_flagging,
    examples=examples,
    interpretation=interpretation,
)

demo.launch(
    cache_examples=True,
    enable_queue=enable_queue,
)