# QalamV0.2 / app.py
# Last change: commit be911f6 "Update app.py" by gagan3012 (verified)
import functools
import os
import re
import tempfile
import uuid
from io import BytesIO

import openai
import pytesseract
import requests
import streamlit as st
import torch
from nougat.dataset.rasterize import rasterize_paper
from PIL import Image
from streamlit_cropper import st_cropper
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, DonutProcessor, NougatProcessor
def get_pdf(pdf_link, timeout=60):
    """Download the PDF at ``pdf_link`` into the current working directory.

    The file gets a unique name (``downloaded_paper_<hex>.pdf``) so that
    simultaneous downloads do not overwrite each other. The caller owns the
    file and is responsible for deleting it.

    Args:
        pdf_link: URL of the PDF to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` can block the app indefinitely.

    Returns:
        Absolute path of the downloaded file.

    Raises:
        requests.HTTPError: if the server does not answer with status 200.
        requests.RequestException: on connection errors or timeouts.
    """
    unique_filename = os.path.join(
        os.getcwd(), f"downloaded_paper_{uuid.uuid4().hex}.pdf")
    response = requests.get(pdf_link, timeout=timeout)
    if response.status_code != 200:
        print("Failed to download the PDF.")
        # Fail here instead of returning the path of a file that was never
        # written, which would only surface later as a confusing error.
        raise requests.HTTPError(
            f"Failed to download the PDF (HTTP {response.status_code}).",
            response=response,
        )
    with open(unique_filename, "wb") as pdf_file:
        pdf_file.write(response.content)
    print("PDF downloaded successfully.")
    return unique_filename
def predict_arabic(img, model_name="UBC-NLP/Qalam"):
    """Transcribe an image of Arabic text with a TrOCR-style model.

    Args:
        img: Image to transcribe; anything with ``.convert("RGB")`` such as
            a PIL image (the cropper output in this app).
        model_name: Hugging Face id of the checkpoint loaded with
            ``TrOCRProcessor`` / ``VisionEncoderDecoderModel``.

    Returns:
        The decoded text of the first generated sequence (``max_length=256``),
        with special tokens removed.

    NOTE(review): the processor and model are loaded from ``model_name`` on
    every call; caching them (e.g. with ``st.cache_resource``) would avoid a
    reload each time "Run OCR" is pressed.
    """
    # Use the GPU when one is available, like the other predictors do.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = TrOCRProcessor.from_pretrained(model_name)
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    model.to(device)

    pixel_values = processor(img.convert("RGB"), return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values.to(device), max_length=256)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
def predict_english(img, model_name="naver-clova-ix/donut-base-finetuned-cord-v2"):
    """Run a Donut checkpoint on ``img`` and return its output as plain text.

    The UI uses this for the Korean and Chinese options. Decoding starts from
    the ``<s_cord-v2>`` task prompt; afterwards the EOS/PAD tokens and every
    remaining ``<...>`` tag are stripped so only the recognised text is left.

    Args:
        img: Image to transcribe; anything with ``.convert("RGB")``.
        model_name: Hugging Face id of the Donut checkpoint.

    Returns:
        The tag-free, whitespace-trimmed decoded sequence.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"
    donut_processor = DonutProcessor.from_pretrained(model_name)
    donut_model = VisionEncoderDecoderModel.from_pretrained(model_name)
    donut_model.to(target)
    tok = donut_processor.tokenizer

    # The task prompt is fed as the decoder prefix to select the task.
    prompt_ids = tok(
        "<s_cord-v2>", add_special_tokens=False, return_tensors="pt").input_ids
    pixels = donut_processor(img.convert("RGB"), return_tensors="pt").pixel_values

    generation_kwargs = dict(
        decoder_input_ids=prompt_ids.to(target),
        max_length=donut_model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[tok.unk_token_id]],
        return_dict_in_generate=True,
    )
    generated = donut_model.generate(pixels.to(target), **generation_kwargs)

    decoded = donut_processor.batch_decode(generated.sequences)[0]
    for special in (tok.eos_token, tok.pad_token):
        decoded = decoded.replace(special, "")
    # Drop every remaining <...> tag, leaving only the text content.
    return re.sub(r"<.*?>", "", decoded).strip()
@functools.lru_cache(maxsize=2)
def _load_nougat(model_name):
    """Load the Nougat processor/model once and put the model on the GPU if present.

    Cached because ``inference_nougat`` calls ``predict_nougat`` once per PDF
    page; without the cache the weights would be re-read for every page.
    NOTE: Streamlit re-executes this script on each rerun, which recreates
    this cache; ``st.cache_resource`` would be needed to keep the model
    across reruns.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = NougatProcessor.from_pretrained(model_name)
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    # The inputs are moved to `device` below, so the model must live there too
    # (otherwise generate() fails with a device mismatch on GPU hosts).
    model.to(device)
    return processor, model


def predict_nougat(img, model_name="facebook/nougat-small"):
    """Transcribe one page image with Nougat and return the decoded text.

    Args:
        img: Page image; anything with ``.convert("RGB")`` such as a PIL image.
        model_name: Hugging Face id of the Nougat checkpoint.

    Returns:
        The decoded sequence with special tokens removed. Nougat's markdown
        post-processing (``processor.post_process_generation``) is not applied.
    """
    processor, model = _load_nougat(model_name)
    device = model.device
    pixel_values = processor(img.convert("RGB"), return_tensors="pt",
                             data_format="channels_first").pixel_values
    # Generate at most 1500 new tokens for the page.
    outputs = model.generate(
        pixel_values.to(device),
        min_length=1,
        max_new_tokens=1500,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
    )
    return processor.batch_decode(outputs, skip_special_tokens=True)[0]
def inference_nougat(pdf_file, pdf_link):
    r"""Run Nougat OCR on every page of a PDF and return the combined text.

    Args:
        pdf_file: Uploaded PDF, or ``None`` to fall back to ``pdf_link``.
            Objects with a ``getvalue()`` method (such as Streamlit's
            ``UploadedFile``) are copied to a temporary ``.pdf`` file first,
            because their ``.name`` is only the client-side file name, not a
            path on disk. Any other object is read from ``pdf_file.name``.
        pdf_link: URL to download when ``pdf_file`` is ``None``.

    Returns:
        The per-page transcriptions joined by blank lines, with the LaTeX
        delimiters ``\(`` / ``\)`` rewritten to ``$`` and ``\[`` / ``\]`` to
        ``$$`` (the dollar-sign form used by Markdown math). If neither a file
        nor a link is given, an instruction message is returned instead.
    """
    if pdf_file is None and not pdf_link:
        print("No file is uploaded and No link is provided")
        return "No data provided. Upload a pdf file or provide a pdf link and try again!"

    # Temporary copy (downloaded or uploaded) that is deleted after rasterizing.
    temp_path = None
    if pdf_file is None:
        file_name = temp_path = get_pdf(pdf_link)
    elif hasattr(pdf_file, "getvalue"):
        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
            tmp.write(pdf_file.getvalue())
        file_name = temp_path = tmp.name
    else:
        file_name = pdf_file.name

    try:
        pages = rasterize_paper(file_name, return_pil=True)
    finally:
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                # Best-effort cleanup: a leftover temp file must not mask the
                # OCR result or an exception raised by rasterize_paper.
                pass

    page_texts = []
    for page in pages:
        # NOTE(review): rasterize_paper(..., return_pil=True) appears to yield
        # encoded in-memory images (BytesIO) rather than PIL images — confirm.
        # Both are handled so predict_nougat can call .convert() on the page.
        image = page if isinstance(page, Image.Image) else Image.open(page)
        page_texts.append(predict_nougat(image))

    # A blank line between pages keeps the last paragraph of one page from
    # running into the first paragraph of the next.
    content = "\n\n".join(page_texts)
    return (
        content.replace(r"\(", "$")
        .replace(r"\)", "$")
        .replace(r"\[", "$$")
        .replace(r"\]", "$$")
    )
def predict_tesseract(img, lang=None):
    """Run Tesseract OCR on ``img`` and return the recognised text.

    Args:
        img: A PIL image (e.g. the cropper output), or a path / binary file
            object that ``PIL.Image.open`` can read. PIL images are used as-is;
            passing one to ``Image.open`` would raise.
        lang: Optional Tesseract language code(s), e.g. ``"fra"`` or
            ``"eng+fra"``. ``None`` keeps Tesseract's default language. The
            matching traineddata must be installed on the host.

    Returns:
        The text returned by ``pytesseract.image_to_string``.
    """
    image = img if isinstance(img, Image.Image) else Image.open(img)
    return pytesseract.image_to_string(image, lang=lang)
# Disable the legacy `deprecation.showfileUploaderEncoding` notice.
# NOTE(review): this option may no longer exist in recent Streamlit
# releases — confirm against the pinned Streamlit version.
st.set_option('deprecation.showfileUploaderEncoding', False)

# Page-level configuration: title, icon, wide layout, open sidebar, menu links.
_menu_items = {
    'Get Help': 'https://www.extremelycoolapp.com/help',
    'Report a bug': "https://www.extremelycoolapp.com/bug",
    'About': "# This is a header. This is an *extremely* cool app!",
}
st.set_page_config(
    page_title="Ex-stream-ly Cool App",
    page_icon="🖊️",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items=_menu_items,
)
# Page title, then the sidebar controls for uploading and cropping a file.
st.header("Qalam: A Multilingual OCR System")
st.sidebar.header("Configuration and Image Upload")
st.sidebar.subheader("Adjust Image Enhancement Options")
img_file = st.sidebar.file_uploader(label='Upload a file',
                                    type=['png', 'jpg', "pdf"])
# Disabled: input_file = st.sidebar.text_input("Enter the file URL")
realtime_update = st.sidebar.checkbox(label="Update in Real Time", value=True)
# Disabled: box_color = st.sidebar.color_picker(label="Box Color", value='#0000FF')

# Only a free-form crop box is offered; "Free" means no fixed aspect ratio.
aspect_dict = {"Free": None}
aspect_choice = st.sidebar.radio(label="Aspect Ratio", options=list(aspect_dict))
aspect_ratio = aspect_dict[aspect_choice]

# Language selection; each language is served by one fixed model.
st.sidebar.subheader("Select OCR Language and Model")
Models = {
    "Arabic": "Qalam",
    "English": "Nougat",
    "French": "Tesseract",
    "Korean": "Donut",
    "Chinese": "Donut",
}
Lng = st.sidebar.selectbox(label="Language", options=list(Models))
st.sidebar.markdown(f"### Selected Model: {Models[Lng]}")
# Main area: images go through crop -> preview -> OCR with the model for the
# selected language; PDFs are transcribed page by page with Nougat.

# OCR function used for each language offered in the sidebar.
OCR_FUNCTIONS = {
    "Arabic": predict_arabic,
    "English": predict_nougat,
    "French": predict_tesseract,
    "Korean": predict_english,
    "Chinese": predict_english,
}


def _render_ocr_result(title, ocr_text):
    """Show ``ocr_text`` under ``title`` and offer it as a text-file download."""
    st.subheader(title)
    st.write(ocr_text)
    st.download_button('Download Text', BytesIO(ocr_text.encode()),
                       file_name='ocr_text.txt')


if img_file:
    if img_file.type != "application/pdf":
        img = Image.open(img_file)
        if not realtime_update:
            st.write("Double click to save crop")
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Input: Upload and Crop Your Image")
            # Interactive crop box; returns the cropped image.
            cropped_img = st_cropper(
                img,
                realtime_update=realtime_update,
                box_color="#FF0000",
                aspect_ratio=aspect_ratio,
                should_resize_image=True,
            )
        with col2:
            st.subheader("Output: Preview and Analyze")
            st.image(cropped_img)
            if st.button("Run OCR"):
                with st.spinner('Running OCR...'):
                    ocr_text = OCR_FUNCTIONS[Lng](cropped_img)
                _render_ocr_result(f"OCR Results for {Lng}", ocr_text)
    else:
        # PDFs always use Nougat, regardless of the selected language.
        if st.sidebar.button("Run OCR"):
            with st.spinner('Running OCR...'):
                ocr_text = inference_nougat(img_file, "")
            _render_ocr_result("OCR Results for the PDF file", ocr_text)
# openai.api_key = ""
# if "openai_model" not in st.session_state:
# st.session_state["openai_model"] = "gpt-3.5-turbo"
# if "messages" not in st.session_state:
# st.session_state.messages = []
# for message in st.session_state.messages:
# with st.chat_message(message["role"]):
# st.markdown(message["content"])
# if prompt := st.chat_input("How can I help?"):
# st.session_state.messages.append({"role": "user", "content": ocr_text + prompt})
# with st.chat_message("user"):
# st.markdown(prompt)
# with st.chat_message("assistant"):
# message_placeholder = st.empty()
# full_response = ""
# for response in openai.ChatCompletion.create(
# model=st.session_state["openai_model"],
# messages=[
# {"role": m["role"], "content": m["content"]}
# for m in st.session_state.messages
# ],
# stream=True,
# ):
# full_response += response.choices[0].delta.get("content", "")
# message_placeholder.markdown(full_response + "▌")
# message_placeholder.markdown(full_response)
# st.session_state.messages.append({"role": "assistant", "content": full_response})