Spaces:
Running
Running
File size: 2,264 Bytes
3ce6ec3 4159b3a 3ce6ec3 4159b3a 08dbced 4159b3a 3ce6ec3 08dbced 4159b3a 3ce6ec3 4159b3a 3ce6ec3 70b9cee 3ce6ec3 4159b3a 005856c 2bb3ce5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import re
import gradio as gr
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
# Pick the compute device first: CUDA when torch reports it, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The processor and the model are loaded from the same pretrained checkpoint;
# the model is moved to the selected device as soon as it is loaded.
processor = DonutProcessor.from_pretrained("to-be/ID_cards_v1")
model = VisionEncoderDecoderModel.from_pretrained("to-be/ID_cards_v1").to(device)
def process_document(image):
    """Run the Donut model on one image and return the parsed result.

    Args:
        image: The value delivered by the Gradio ``"image"`` input. It may be
            ``None`` if the request was submitted without an image.

    Returns:
        Whatever ``processor.token2json`` produces for the decoded sequence
        (rendered by the ``"json"`` output component).

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # Report a readable message to the user rather than handing a
        # missing image to the processor.
        raise gr.Error("Please upload an image before submitting.")

    tokenizer = processor.tokenizer

    # Encoder input: pixel tensor computed from the image.
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # Decoder input: the task start token, tokenized without extra special
    # tokens so generation begins right after it.
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    # Greedy generation (num_beams=1) up to the decoder's maximum position
    # count; the unknown token is banned from the output.
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # Post-process: decode, drop EOS/PAD markers, then strip only the first
    # <...> tag (the task start token) before converting to JSON.
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    return processor.token2json(sequence)
# Introductory text shown with the demo.
description = (
    "Gradio Demo for Donut, an instance of `VisionEncoderDecoderModel` "
    "fine-tuned on CORD (document parsing). "
    "To use it, simply upload your image and click 'submit', "
    "or click one of the examples to load them. "
    "Read more at the links below."
)

# HTML snippet linking to the Donut paper and the reference repository.
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2111.15664' target='_blank'>"
    "Donut: OCR-free Document Understanding Transformer</a>"
    " | "
    "<a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a>"
    "</p>"
)
# Web UI: a single image input whose result from process_document is shown
# as JSON. One bundled example image is offered; its output is not cached.
demo = gr.Interface(
    fn=process_document,
    inputs="image",
    outputs="json",
    title="Demo: Donut 🍩 for Document Parsing",
    description=description,
    article=article,
    examples=[["3001.jpg"]],
    cache_examples=False,
)

# Queue incoming requests (queue size capped at 20) and then start the app.
demo.queue(max_size=20, api_open=True)
demo.launch(share=True)