ajimeno commited on
Commit
6c0128c
·
1 Parent(s): ca53d7c

Prompt option

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -8,7 +8,7 @@ from PIL import Image
8
  from io import BytesIO
9
  from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig, DonutProcessor, DonutImageProcessor, AutoTokenizer
10
 
11
- def run_prediction(sample, model, processor):
12
 
13
  pixel_values = processor(np.array(
14
  sample,
@@ -18,7 +18,7 @@ def run_prediction(sample, model, processor):
18
  with torch.no_grad():
19
  outputs = model.generate(
20
  pixel_values.to(device),
21
- decoder_input_ids=processor.tokenizer("<s><s_plain>", add_special_tokens=False, return_tensors="pt").input_ids.to(device),
22
  do_sample=True,
23
  top_p=0.92,
24
  top_k=5,
@@ -52,7 +52,9 @@ with st.sidebar:
52
  if uploaded_file is not None:
53
  # To read file as bytes:
54
  image_bytes_data = uploaded_file.getvalue()
55
- image_upload = Image.open(BytesIO(image_bytes_data))
 
 
56
 
57
  if image_upload:
58
  image = image_upload
@@ -87,6 +89,6 @@ with st.spinner(f'Processing the document ...'):
87
  model.to(device)
88
 
89
  st.info(f'Parsing document')
90
- parsed_info = run_prediction(image.convert("RGB"), model, processor)
91
  st.text(f'\nDocument:')
92
  st.text_area('Output text', value=parsed_info, height=800)
 
8
  from io import BytesIO
9
  from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig, DonutProcessor, DonutImageProcessor, AutoTokenizer
10
 
11
+ def run_prediction(sample, model, processor, prompt):
12
 
13
  pixel_values = processor(np.array(
14
  sample,
 
18
  with torch.no_grad():
19
  outputs = model.generate(
20
  pixel_values.to(device),
21
+ decoder_input_ids=processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(device),
22
  do_sample=True,
23
  top_p=0.92,
24
  top_k=5,
 
52
  if uploaded_file is not None:
53
  # To read file as bytes:
54
  image_bytes_data = uploaded_file.getvalue()
55
+ image_upload = Image.open(BytesIO(image_bytes_data))
56
+
57
+ prompt = st.selectbox('Prompt', ('<s><s_pretraining>', '<s><s_plain>', '<s><s_hierarchical>'), index=2)
58
 
59
  if image_upload:
60
  image = image_upload
 
89
  model.to(device)
90
 
91
  st.info(f'Parsing document')
92
+ parsed_info = run_prediction(image.convert("RGB"), model, processor, prompt)
93
  st.text(f'\nDocument:')
94
  st.text_area('Output text', value=parsed_info, height=800)