# Copyright (c) 2022, Lawrence Livermore National Security, LLC.
# All rights reserved.
# See the top-level LICENSE and NOTICE files for details.
# LLNL-CODE-838964
# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception

import json
import warnings
from pathlib import Path

import cv2
import numpy as np
import torch
from detectron2.config import CfgNode as CN
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import ColorMode, Visualizer
from pdf2image import convert_from_path
from pdfminer.layout import LTTextLineHorizontal, LTTextBoxHorizontal, LTAnno, LTChar
from PIL import Image
from tokenizers.pre_tokenizers import Whitespace

from dit_object_detection.ditod import add_vit_config
import base_utils

warnings.filterwarnings("ignore")
dit_path = Path('DiT_Extractor/dit_object_detection')

# Configure the DiT cascade detector fine-tuned on PubLayNet.
cfg = get_cfg()
add_vit_config(cfg)
cfg.merge_from_file(str(dit_path / "publaynet_configs/cascade/cascade_dit_base.yaml"))
cfg.MODEL.WEIGHTS = "https://layoutlm.blob.core.windows.net/dit/dit-fts/publaynet_dit-b_cascade.pth"
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
predictor = DefaultPredictor(cfg)

# PubLayNet label set; thing_map maps each class name to its index,
# e.g. {"text": 0, "title": 1, ...}.
thing_classes = ["text", "title", "list", "table", "figure"]
thing_map = dict(map(reversed, enumerate(thing_classes)))
md = MetadataCatalog.get(cfg.DATASETS.TEST[0])
md.set(thing_classes=thing_classes)
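
# Note (assumption based on detectron2's DefaultPredictor): the checkpoint at
# cfg.MODEL.WEIGHTS is downloaded and cached locally the first time the
# predictor is constructed, so the initial run needs network access.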


def get_pdf_image(pdf_file, page):
    # Render a single page (1-indexed) of the PDF at 200 DPI.
    # convert_from_path returns a list containing one PIL image.
    image = convert_from_path(pdf_file, dpi=200, first_page=page, last_page=page)
    return image


def get_characters(subelement):
    # Collect (bbox, text) pairs for every character in a pdfminer text line.
    all_chars = []
    if isinstance(subelement, LTTextLineHorizontal):
        for char in subelement:
            if isinstance(char, LTChar):
                all_chars.append((char.bbox, char.get_text()))
            elif isinstance(char, LTAnno) and all_chars:
                # LTAnno carries no bbox (it is an inserted space/newline), so
                # give it a zero-width slice at the right edge of the previous
                # character.
                bbox = all_chars[-1][0]
                bbox = (bbox[2], bbox[1], bbox[2], bbox[3])
                all_chars.append((bbox, char.get_text()))
    return all_chars
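

# The function below relies on two bbox helpers imported from base_utils. As a
# reading aid only, here is a hypothetical sketch of what they are expected to
# do (the real implementations in DiT_Extractor/base_utils.py may differ):
#
#     def partial_overlaps(box_a, box_b):
#         # True when the two (x0, y0, x1, y1) boxes intersect at all.
#         return (box_a[0] <= box_b[2] and box_b[0] <= box_a[2] and
#                 box_a[1] <= box_b[3] and box_b[1] <= box_a[3])
#
#     def union(box_a, box_b):
#         # Smallest box containing both input boxes.
#         return [min(box_a[0], box_b[0]), min(box_a[1], box_b[1]),
#                 max(box_a[2], box_b[2]), max(box_a[3], box_b[3])]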


def get_dit_preds(pdf, score_threshold=0.5):
    page_count = base_utils.get_pdf_page_count(pdf)
    page_sizes = base_utils.get_page_sizes(pdf)
    sections = {}
    viz_images = []
    page_words = base_utils.get_pdf_words(pdf)
    for page in range(1, page_count + 1):
        image = get_pdf_image(pdf, page)
        # The predictor expects a numpy array of the rendered page image
        image = np.array(image[0])
        # Get prediction
        output = predictor(image)["instances"]
        output = output.to('cpu')
        # Visualize predictions
        v = Visualizer(image[:, :, ::-1],
                       md,
                       scale=1.0,
                       instance_mode=ColorMode.SEGMENTATION)
        result = v.draw_instance_predictions(output)
        result_image = result.get_image()[:, :, ::-1]
        viz_img = Image.fromarray(result_image)
        viz_images.append(viz_img)
        words = page_words[page - 1]
        # Convert from image size to PDF page size
        pdf_dimensions = page_sizes[page - 1][2:]
        # Swap height/width: detectron2 reports image_size as (height, width)
        pdf_image_size = (output.image_size[1], output.image_size[0])
        scale = np.array(pdf_dimensions) / np.array(pdf_image_size)
        scale_box = np.hstack((scale, scale))
        # Words are in page coordinates
        block_id = 0
        sections[page - 1] = []
        draw = image.copy()
        for box_t, clazz, score in zip(output.get('pred_boxes'), output.get('pred_classes'), output.get('scores')):
            if score < score_threshold:
                continue
            box = box_t.numpy()
            # Flip along Y axis (image origin is top-left, PDF origin is bottom-left)
            box[1] = pdf_image_size[1] - box[1]
            box[3] = pdf_image_size[1] - box[3]
            # Scale from image pixels to PDF units
            scaled = box * scale_box
            # Reorder into (x0, y0, x1, y1) with y0 < y1
            scaled = [scaled[0], scaled[3], scaled[2], scaled[1]]
            if clazz != thing_map['text']:
                continue
            start = [int(x) for x in box[0:2].tolist()]
            end = [int(x) for x in box[2:4].tolist()]
            out = {}
            for word in words.copy():
                if base_utils.partial_overlaps(word[0:4], scaled):
                    if out == {}:
                        block_id += 1
                        out['coord'] = word[0:4]
                        out['subelements'] = []
                        out['type'] = 'content_block'
                        out['id'] = block_id
                        out['text'] = ''
                    out['coord'] = base_utils.union(out['coord'], word[0:4])
                    out['text'] = out['text'] + word[4].get_text()
                    characters = get_characters(word[4])
                    out['subelements'].append(characters)
                    words.remove(word)
            if len(out) != 0:
                sections[page - 1].append(out)
    # Write final annotation
    out_name = Path(pdf).name[:-4] + ".json"
    with open(out_name, 'w', encoding='utf8') as json_out:
        json.dump(sections, json_out, ensure_ascii=False, indent=4)
    return viz_images
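

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original pipeline; "example.pdf" is a
# hypothetical path). get_dit_preds() writes "<pdf name>.json" to the current
# working directory and returns one annotated PIL image per page.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    sample_pdf = "example.pdf"  # hypothetical input, replace with a real PDF
    page_visualizations = get_dit_preds(sample_pdf, score_threshold=0.5)
    for page_number, viz in enumerate(page_visualizations, start=1):
        viz.save(f"dit_page_{page_number}.png")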