# Author: Yiwahpsp — commit 6c4a874 "Start app v.1" (verified), 3.56 kB.
# (Header captured from the hosting file viewer; kept as a comment so the module parses.)
import io
import pandas as pd
import plotly_express as px
import streamlit as st
import torch
import torch.nn.functional as F
import numpy as np
from easyocr import Reader
from PIL import Image
from transformers import (
LayoutLMv3ImageProcessor,
LayoutLMv3ForSequenceClassification,
LayoutLMv3Processor,
LayoutLMv3TokenizerFast,
)
# Run on the first CUDA GPU when torch can see one, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

# Base checkpoint whose tokenizer is used by create_processor().
MICROSOFT_MODEL_NAME = "microsoft/layoutlmv3-base"
# Fine-tuned sequence-classification checkpoint loaded by create_model().
MODEL_NAME = "curiousily/layoutlmv3-financial-document-classification"
def create_bounding_box(bbox_data, width_scale: float, height_scale: float):
    """Convert an OCR polygon into a LayoutLMv3 ``[left, top, right, bottom]`` box.

    The axis-aligned rectangle enclosing every ``(x, y)`` vertex of
    ``bbox_data`` is scaled by ``width_scale`` / ``height_scale`` (``predict``
    passes ``1000 / width`` and ``1000 / height``), truncated to ``int`` and
    clamped into the 0-1000 grid LayoutLMv3 expects.

    Args:
        bbox_data: Iterable of ``(x, y)`` points, e.g. the polygon easyocr
            returns for a detected word.
        width_scale: Multiplier applied to x coordinates.
        height_scale: Multiplier applied to y coordinates.

    Returns:
        ``[left, top, right, bottom]`` as ints in ``[0, 1000]``.

    Raises:
        ValueError: If ``bbox_data`` is empty or a point is not a pair.
    """
    xs, ys = zip(*bbox_data)

    def _scale(value, scale):
        # OCR polygons can extend slightly past the image border; values
        # outside 0-1000 break LayoutLMv3's 2-D position embedding lookup,
        # so clamp after scaling.
        return min(max(int(value * scale), 0), 1000)

    return [
        _scale(min(xs), width_scale),
        _scale(min(ys), height_scale),
        _scale(max(xs), width_scale),
        _scale(max(ys), height_scale),
    ]
@st.cache_resource
def create_ocr_reader():
    """Return the English-only easyocr ``Reader`` shared across reruns.

    ``st.cache_resource`` memoises the result, so Streamlit reruns reuse the
    same reader rather than constructing a new one each time.
    """
    languages = ["en"]
    return Reader(languages)
@st.cache_resource
def create_processor():
    """Return the cached LayoutLMv3 processor (image processor + tokenizer).

    The image processor is created with ``apply_ocr=False`` because words and
    boxes are produced by easyocr and passed in explicitly by ``predict``.
    """
    tokenizer = LayoutLMv3TokenizerFast.from_pretrained(MICROSOFT_MODEL_NAME)
    image_processor = LayoutLMv3ImageProcessor(apply_ocr=False)
    return LayoutLMv3Processor(image_processor, tokenizer)
@st.cache_resource
def create_model():
    """Load the fine-tuned classifier once, in eval mode, on ``DEVICE``."""
    classifier = LayoutLMv3ForSequenceClassification.from_pretrained(MODEL_NAME)
    # eval() disables dropout etc. for deterministic inference.
    classifier.eval()
    return classifier.to(DEVICE)
def predict(
    image: Image.Image,
    reader: Reader,
    processor: LayoutLMv3Processor,
    model: LayoutLMv3ForSequenceClassification,
):
    """Classify a document image with easyocr + LayoutLMv3.

    The image is OCR'd with ``reader``. Each recognised word and its box
    (rescaled to LayoutLMv3's 0-1000 grid) is encoded together with the pixels
    by ``processor`` (truncated/padded to 512 tokens) and scored by ``model``.

    Args:
        image: Uploaded document image; any PIL mode is accepted and it is
            converted to RGB first.
        reader: easyocr reader used for text detection and recognition.
        processor: LayoutLMv3 processor built with ``apply_ocr=False``, since
            words and boxes are supplied here.
        model: Sequence-classification model; inputs are moved to ``DEVICE``.

    Returns:
        ``(predicted_class, probabilities)``: the argmax class index as an
        ``int`` and the softmax probabilities for every class, as a list of
        floats in class-index order.
    """
    # PNG uploads may be RGBA / palette / greyscale; the image processor works
    # on three channels, so normalise the mode once up front.
    rgb_image = image.convert("RGB")

    # easyocr only special-cases PIL *JPEG* images, so other PIL images (e.g.
    # PNG uploads) are rejected. It does accept numpy arrays, treating 3-channel
    # input as OpenCV-style BGR, so flip RGB -> BGR as easyocr does for JPEGs.
    ocr_input = np.ascontiguousarray(np.asarray(rgb_image)[:, :, ::-1])
    ocr_result = reader.readtext(ocr_input)

    width, height = rgb_image.size
    width_scale = 1000 / width
    height_scale = 1000 / height

    words = []
    boxes = []
    for bbox, word, _confidence in ocr_result:
        words.append(word)
        boxes.append(create_bounding_box(bbox, width_scale, height_scale))

    encoding = processor(
        rgb_image,
        words,
        boxes=boxes,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    with torch.inference_mode():
        output = model(
            input_ids=encoding["input_ids"].to(DEVICE),
            attention_mask=encoding["attention_mask"].to(DEVICE),
            bbox=encoding["bbox"].to(DEVICE),
            pixel_values=encoding["pixel_values"].to(DEVICE),
        )

    # A single image is encoded, so the batch dimension has one row.
    logits = output.logits[0]
    predicted_class = int(logits.argmax())
    probabilities = F.softmax(logits, dim=-1).tolist()
    return predicted_class, probabilities
# Cached resources: built once per server process, reused on every rerun.
reader = create_ocr_reader()
processor = create_processor()
model = create_model()

upload_file = st.file_uploader("Upload Document Image", ["jpg", "png"])
if upload_file is not None:
    bytes_data = io.BytesIO(upload_file.getvalue())
    image = Image.open(bytes_data)
    st.image(image, "Your Document Image")

    predicted_class, probabilities = predict(image, reader, processor, model)
    print("Predicted class:", predicted_class)
    print("Probabilities:", probabilities)

    predicted_label = model.config.id2label[predicted_class]
    st.markdown(f"Predicted document type: **{predicted_label}**")

    # make chart
    # Look labels up by class index so each bar lines up with its probability;
    # id2label's dict iteration order is not guaranteed to be index order.
    labels = [model.config.id2label[i] for i in range(len(probabilities))]
    df_predictions = pd.DataFrame({
        "Document": labels,
        "confidence": probabilities,
    })
    fig = px.bar(df_predictions, x="Document", y="confidence", title="Document Type Confidence")
    st.plotly_chart(fig, use_container_width=True)