import io
import pandas as pd
import plotly.express as px
import streamlit as st
import torch
import torch.nn.functional as F
import numpy as np
from easyocr import Reader
from PIL import Image
from transformers import (
    LayoutLMv3ImageProcessor,
    LayoutLMv3ForSequenceClassification,
    LayoutLMv3Processor,
    LayoutLMv3TokenizerFast,
)


DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
MICROSOFT_MODEL_NAME = "microsoft/layoutlmv3-base"
MODEL_NAME = "curiousily/layoutlmv3-financial-document-classification"

def create_bounding_box(bbox_data, width_scale: float, height_scale: float):
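    """Convert an EasyOCR four-point polygon into a [left, top, right, bottom] box
    scaled to the 0-1000 coordinate range that LayoutLMv3 expects."""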
    xs = []
    ys = []
    for x, y in bbox_data:
        xs.append(x)
        ys.append(y)
        
    left = int(min(xs) * width_scale)
    top = int(min(ys) * height_scale)
    right = int(max(xs) * width_scale)
    bottom = int(max(ys) * height_scale)
    
    return [left, top, right, bottom]

@st.cache_resource
def create_ocr_reader():
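    """Create the English EasyOCR reader once and cache it across Streamlit reruns."""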
    return Reader(["en"])

@st.cache_resource
def create_processor():
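    """Build a LayoutLMv3 processor with its built-in OCR disabled, since words and boxes are supplied by EasyOCR."""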
    feature_extractor = LayoutLMv3ImageProcessor(apply_ocr=False)
    tokenizer = LayoutLMv3TokenizerFast.from_pretrained(MICROSOFT_MODEL_NAME)
    return LayoutLMv3Processor(feature_extractor, tokenizer)

@st.cache_resource
def create_model():
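    """Load the fine-tuned document classifier in eval mode on the selected device."""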
    model = LayoutLMv3ForSequenceClassification.from_pretrained(MODEL_NAME)
    return model.eval().to(DEVICE)

def predict(image: Image.Image, reader: Reader, processor: LayoutLMv3Processor, model: LayoutLMv3ForSequenceClassification):
    """Run OCR, encode the words, boxes and page image, and return the predicted class id with per-class probabilities."""
    # EasyOCR works on NumPy arrays, so convert the PIL image before running OCR
    ocr_result = reader.readtext(np.array(image))
    
    width, height = image.size
    width_scale = 1000 / width
    height_scale = 1000 / height
    
    words = []
    boxes = []
    
    for bbox, word, confidence in ocr_result:
        words.append(word)
        boxes.append(create_bounding_box(bbox, width_scale, height_scale))
        
    encoding = processor(image, words, boxes=boxes, max_length=512, padding="max_length", truncation=True, return_tensors="pt")
        
    with torch.inference_mode():
        output = model(
            input_ids=encoding["input_ids"].to(DEVICE),
            attention_mask=encoding["attention_mask"].to(DEVICE),
            bbox=encoding["bbox"].to(DEVICE),
            pixel_values=encoding["pixel_values"].to(DEVICE),
        )
        
    logits = output.logits
    predicted_class = logits.argmax()
    probabilities = F.softmax(logits, dim=-1).flatten().tolist()
    
    return predicted_class.cpu().item(), probabilities

reader = create_ocr_reader()
processor = create_processor()
model = create_model()
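
# Streamlit UI: upload a document image, classify it, and show per-class confidence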

upload_file = st.file_uploader("Upload Document Image", ["jpg", "png"])
if upload_file is not None:
    bytes_data = io.BytesIO(upload_file.getvalue())
    # ensure a 3-channel RGB image (PNG uploads can be RGBA or grayscale)
    image = Image.open(bytes_data).convert("RGB")
    st.image(image, caption="Your Document Image")
    
    predicted_class, probabilities = predict(image, reader, processor, model)
    
    predicted_label = model.config.id2label[predicted_class]
    
    st.markdown(f"Predicted document type: **{predicted_label}**")
    
    # make chart
    df_predictions = pd.DataFrame({
        "Document": list(model.config.id2label.values()),
        "Confidence": probabilities,
    })

    fig = px.bar(df_predictions, x="Document", y="Confidence", title="Document Type Confidence")
    st.plotly_chart(fig, use_container_width=True)