File size: 4,311 Bytes
1689fdb
 
 
 
 
 
 
b2c2fa2
 
 
87c129a
1689fdb
52ca21c
 
1689fdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c02759
1689fdb
 
 
 
 
 
 
 
 
32d3bb8
1689fdb
 
 
1869205
ec7e02b
1869205
1689fdb
 
6c02759
c615442
1689fdb
 
 
 
 
 
 
 
 
 
 
 
 
32d3bb8
1689fdb
 
 
 
7c259fa
b2c2fa2
 
c615442
ab85ffe
c615442
b2c2fa2
1689fdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372194e
 
 
1689fdb
 
 
 
 
 
 
b2c2fa2
 
1689fdb
b2c2fa2
5751f79
 
 
1689fdb
c615442
1689fdb
 
b2c2fa2
1689fdb
 
c615442
1689fdb
32e270a
c615442
 
228599b
32e270a
 
 
 
1689fdb
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from huggingface_hub import hf_hub_download
import re
from PIL import Image
import requests
from nougat.dataset.rasterize import rasterize_paper
from transformers import NougatProcessor, VisionEncoderDecoderModel
import torch
import gradio as gr
import uuid
import os
import spaces

# Hugging Face checkpoint shared by the processor and the model.
_CHECKPOINT = "facebook/nougat-small"

# Run on the GPU when torch reports one as available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the processor and the encoder-decoder model once, at import time,
# and move the model's weights onto the selected device.
processor = NougatProcessor.from_pretrained(_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT)
model.to(device)


def get_pdf(pdf_link, timeout=60):
  """Download the PDF at ``pdf_link`` into a uniquely named file in the cwd.

  Args:
    pdf_link: URL of the PDF to fetch.
    timeout: seconds ``requests`` waits on the connection/read before
      giving up (previously there was no limit and a stalled server hung
      the request indefinitely).

  Returns:
    Path of the written PDF file.

  Raises:
    requests.RequestException: on network errors or timeouts, and (as
      ``requests.HTTPError``) when the server does not answer 200 OK.
  """
  # uuid makes the name unique so concurrent downloads do not clobber each other.
  unique_filename = os.path.join(
      os.getcwd(), f"downloaded_paper_{uuid.uuid4().hex}.pdf"
  )

  response = requests.get(pdf_link, timeout=timeout)

  if response.status_code != 200:
    print("Failed to download the PDF.")
    # Returning the path here would hand the caller a file that was never
    # written; fail loudly with the reason instead.
    raise requests.HTTPError(
        f"Failed to download the PDF from {pdf_link} "
        f"(HTTP {response.status_code}).",
        response=response,
    )

  with open(unique_filename, 'wb') as pdf_file:
    pdf_file.write(response.content)
  print("PDF downloaded successfully.")
  return unique_filename


@spaces.GPU
def predict(image):
  """Transcribe a single rasterized PDF page into Nougat markdown.

  Args:
    image: anything ``PIL.Image.open`` accepts (a path or a binary
      file-like object); the caller passes one element of
      ``rasterize_paper(..., return_pil=True)``.

  Returns:
    The decoded page text after ``processor.post_process_generation``.
  """
  # Prepare the page image for the model. The context manager releases the
  # decoded image once the processor has turned it into a tensor.
  with Image.open(image) as page:
    pixel_values = processor(page, return_tensors="pt").pixel_values

  # Generate the transcription: at least 1 and up to 1500 new tokens per
  # page, with the tokenizer's unknown token banned from the output.
  outputs = model.generate(
      pixel_values.to(device),
      min_length=1,
      max_new_tokens=1500,
      bad_words_ids=[[processor.tokenizer.unk_token_id]],
  )

  # One page in, one sequence out: take the first (only) decoded sequence.
  page_sequence = processor.batch_decode(outputs, skip_special_tokens=True)[0]
  return processor.post_process_generation(page_sequence, fix_markdown=False)



def inference(pdf_file, pdf_link):
  """Run Nougat OCR over a PDF and return its markdown plus a file copy.

  Args:
    pdf_file: value of the ``gr.File`` input, or None when nothing was
      uploaded. The path is read from its ``.name`` attribute when present,
      otherwise the value itself is used as the path.
    pdf_link: URL of a PDF to download when no file was uploaded; may be
      empty, whitespace-only, or None.

  Returns:
    A ``(content, output_path)`` pair, one value per wired output component.
    ``content`` is the concatenated page markdown with the LaTeX inline
    math delimiters rewritten to ``$`` and the display delimiters to ``$$``;
    ``output_path`` is a markdown file holding the same text. When there is
    nothing to process, ``content`` is a user-facing message and
    ``output_path`` is None.
  """
  downloaded_file = None
  if pdf_file is None:
    if not pdf_link or not pdf_link.strip():
      print("No file is uploaded and No link is provided")
      # Two output components are wired to this handler, so both need a value.
      return "No data provided. Upload a pdf file or provide a pdf link and try again!", None
    try:
      file_name = downloaded_file = get_pdf(pdf_link.strip())
    except OSError as err:
      # requests' exceptions derive from OSError (IOError), so this covers
      # network/HTTP failures as well as failing to write the local copy.
      print(f"Failed to download the PDF: {err}")
      return f"Could not download the PDF from the provided link: {err}", None
  else:
    # Gradio file values expose their path as `.name`; plain path strings
    # are accepted as well.
    file_name = getattr(pdf_file, "name", pdf_file)

  try:
    images = rasterize_paper(file_name, return_pil=True)
    # rasterize_paper's failure behaviour is not visible here; treat a
    # None/empty result as "nothing to transcribe".
    if not images:
      return "Could not read any pages from the PDF.", None
    # Infer every page and concatenate the per-page markdown in page order.
    sequence = "".join(predict(image) for image in images)
  finally:
    # The downloaded copy is only needed until the pages are transcribed.
    # Removal is best-effort: a leftover temp file must not fail the request.
    if downloaded_file is not None:
      try:
        os.remove(downloaded_file)
      except OSError:
        pass

  # Rewrite LaTeX math delimiters to dollar signs (inline -> $, display -> $$),
  # presumably so the gr.Markdown output renders the math.
  content = (
      sequence.replace(r'\(', '$')
      .replace(r'\)', '$')
      .replace(r'\[', '$$')
      .replace(r'\]', '$$')
  )

  # A fresh name per call keeps concurrent requests from overwriting each
  # other's result file.
  output_path = os.path.join(os.getcwd(), f"output_{uuid.uuid4().hex}.md")
  with open(output_path, "w", encoding="utf-8") as f:
    f.write(content)

  return content, output_path


# Styles the rendered-markdown output box (elem_id='mkd' below): fixed
# height with its own scrollbar and a light border.
css = """
  #mkd {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
  }
"""

with gr.Blocks(css=css) as demo:
  # Page header; every <center>/<h*> tag is explicitly closed.
  gr.HTML("<h1><center>Nougat: Neural Optical Understanding for Academic Documents ๐Ÿซ</center></h1>")
  gr.HTML("<h3><center>Lukas Blecher et al. <a href='https://arxiv.org/pdf/2308.13418.pdf' target='_blank'>Paper</a>, <a href='https://facebookresearch.github.io/nougat/'>Project</a></center></h3>")
  gr.HTML("<h3><center>This demo is based on transformers implementation of Nougat ๐Ÿค—</center></h3>")

  # Instructions: upload a file OR provide a link.
  with gr.Row():
    gr.Markdown('<h4><center>Upload a PDF</center></h4>')
    gr.Markdown('<h4><center><i>OR</i></center></h4>')
    gr.Markdown('<h4><center>Provide a PDF link</center></h4>')

  # Inputs and action buttons.
  with gr.Row(equal_height=True):
    pdf_file = gr.File(label='PDF ๐Ÿ“‘', file_count='single', scale=1)
    pdf_link = gr.Textbox(placeholder='Enter an arxiv link here', label='Link to Paper๐Ÿ”—', scale=1)
  with gr.Row():
    btn = gr.Button('Run Nougat ๐Ÿซ')
  with gr.Row():
    clr = gr.Button('Clear Inputs & Outputs ๐Ÿงผ')

  # Outputs: the rendered markdown (styled by the #mkd rule) and the file.
  output_headline = gr.Markdown("## PDF converted to markup language through Nougat-OCR๐Ÿ‘‡")
  with gr.Row():
    parsed_output = gr.Markdown(elem_id='mkd', value='Output Text ๐Ÿ“')
    output_file = gr.File(file_types = ["txt"], label="Output File ๐Ÿ“‘")

  btn.click(inference, [pdf_file, pdf_link], [parsed_output, output_file])
  # Reset both inputs and both outputs; one separate update object per component.
  clr.click(lambda : (gr.update(value=None),
                      gr.update(value=None),
                      gr.update(value=None),
                      gr.update(value=None)),
            [],
            [pdf_file, pdf_link, parsed_output, output_file]
           )
  # Examples are pre-computed at startup (cache_examples=True).
  gr.Examples(
      [["nougat.pdf", ""], [None, "https://arxiv.org/pdf/2308.08316.pdf"]],
      inputs = [pdf_file, pdf_link],
      outputs = [parsed_output, output_file],
      fn=inference,
      cache_examples=True,
      label='Click on any Examples below to get Nougat OCR results quickly:'
  )


demo.queue()
demo.launch(debug=True)