Prompt option
Browse files
app.py
CHANGED
@@ -8,7 +8,7 @@ from PIL import Image
|
|
8 |
from io import BytesIO
|
9 |
from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig, DonutProcessor, DonutImageProcessor, AutoTokenizer
|
10 |
|
11 |
-
def run_prediction(sample, model, processor):
|
12 |
|
13 |
pixel_values = processor(np.array(
|
14 |
sample,
|
@@ -18,7 +18,7 @@ def run_prediction(sample, model, processor):
|
|
18 |
with torch.no_grad():
|
19 |
outputs = model.generate(
|
20 |
pixel_values.to(device),
|
21 |
-
decoder_input_ids=processor.tokenizer(
|
22 |
do_sample=True,
|
23 |
top_p=0.92,
|
24 |
top_k=5,
|
@@ -52,7 +52,9 @@ with st.sidebar:
|
|
52 |
if uploaded_file is not None:
|
53 |
# To read file as bytes:
|
54 |
image_bytes_data = uploaded_file.getvalue()
|
55 |
-
image_upload = Image.open(BytesIO(image_bytes_data))
|
|
|
|
|
56 |
|
57 |
if image_upload:
|
58 |
image = image_upload
|
@@ -87,6 +89,6 @@ with st.spinner(f'Processing the document ...'):
|
|
87 |
model.to(device)
|
88 |
|
89 |
st.info(f'Parsing document')
|
90 |
-
parsed_info = run_prediction(image.convert("RGB"), model, processor)
|
91 |
st.text(f'\nDocument:')
|
92 |
st.text_area('Output text', value=parsed_info, height=800)
|
|
|
8 |
from io import BytesIO
|
9 |
from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig, DonutProcessor, DonutImageProcessor, AutoTokenizer
|
10 |
|
11 |
+
def run_prediction(sample, model, processor, prompt):
|
12 |
|
13 |
pixel_values = processor(np.array(
|
14 |
sample,
|
|
|
18 |
with torch.no_grad():
|
19 |
outputs = model.generate(
|
20 |
pixel_values.to(device),
|
21 |
+
decoder_input_ids=processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(device),
|
22 |
do_sample=True,
|
23 |
top_p=0.92,
|
24 |
top_k=5,
|
|
|
52 |
if uploaded_file is not None:
|
53 |
# To read file as bytes:
|
54 |
image_bytes_data = uploaded_file.getvalue()
|
55 |
+
image_upload = Image.open(BytesIO(image_bytes_data))
|
56 |
+
|
57 |
+
prompt = st.selectbox('Prompt', ('<s><s_pretraining>', '<s><s_plain>', '<s><s_hierarchical>'), index=2)
|
58 |
|
59 |
if image_upload:
|
60 |
image = image_upload
|
|
|
89 |
model.to(device)
|
90 |
|
91 |
st.info(f'Parsing document')
|
92 |
+
parsed_info = run_prediction(image.convert("RGB"), model, processor, prompt)
|
93 |
st.text(f'\nDocument:')
|
94 |
st.text_area('Output text', value=parsed_info, height=800)
|