|
import streamlit as st |
|
from PIL import Image |
|
from pdf2image import convert_from_path |
|
from byaldi import RAGMultiModalModel |
|
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor |
|
from qwen_vl_utils import process_vision_info |
|
import torch |
|
import time |
|
import json |
|
import re |
|
|
|
# Pick the torch device string once for the whole script: the GPU when CUDA
# reports itself as available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
|
@st.cache_resource
def load_models():
    """Instantiate the retriever, the vision-language model and its processor.

    The function is wrapped in ``st.cache_resource`` so the returned objects
    can be reused instead of being rebuilt on each call.

    Returns:
        A ``(RAG, model, processor)`` tuple:
        the ``RAGMultiModalModel`` loaded from ``"vidore/colpali"``, the
        ``Qwen2VLForConditionalGeneration`` model moved to the module-level
        ``device`` and switched to eval mode, and the matching
        ``AutoProcessor``.
    """
    # Both the model and its processor come from the same checkpoint.
    qwen_checkpoint = "Qwen/Qwen2-VL-7B-Instruct"

    retriever = RAGMultiModalModel.from_pretrained("vidore/colpali")

    vision_language_model = Qwen2VLForConditionalGeneration.from_pretrained(
        qwen_checkpoint,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    vision_language_model = vision_language_model.to(device)
    vision_language_model = vision_language_model.eval()

    chat_processor = AutoProcessor.from_pretrained(
        qwen_checkpoint,
        trust_remote_code=True,
    )

    return retriever, vision_language_model, chat_processor
|
|
|
# Obtain the retriever, the vision-language model and its processor.
RAG, model, processor = load_models()

# Page header followed by the upload widget; the widget only accepts the
# extensions listed here.
st.title("OCR extraction")

_accepted_extensions = ["pdf", "png", "jpg", "jpeg"]
uploaded_file = st.file_uploader(
    "Upload a PDF or Image",
    type=_accepted_extensions,
)

# Seed the session slot that holds the extraction result so later code can
# read it unconditionally; None means "nothing extracted yet".
if "extracted_text" not in st.session_state:
    st.session_state["extracted_text"] = None
|
|
|
# Handle an uploaded document: turn it into a single image, run the
# retrieval + vision-language pipeline once per upload, and show the
# extracted text as JSON.
if uploaded_file is not None:
    file_type = uploaded_file.name.split('.')[-1].lower()

    # The extraction result is kept in session state so reruns do not repeat
    # the expensive model call.  Remember which upload produced it and drop
    # the stored text when a different file arrives; otherwise the text of
    # the first document would be shown for every later upload.
    upload_signature = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get("extracted_source") != upload_signature:
        st.session_state["extracted_source"] = upload_signature
        st.session_state["extracted_text"] = None

    if file_type == "pdf":
        # Imported here because the file header only imports
        # convert_from_path from pdf2image.
        from pdf2image import convert_from_bytes

        st.write("Converting PDF to image...")
        # The uploader returns an in-memory file object rather than a path on
        # disk, so the PDF has to be converted from its raw bytes.
        images = convert_from_bytes(uploaded_file.getvalue())
        if not images:
            st.error("The uploaded PDF does not contain any pages.")
            st.stop()
        # Only the first page is processed.
        image_to_process = images[0]
    else:
        image_to_process = Image.open(uploaded_file)

    st.image(image_to_process, caption="Uploaded document", use_column_width=True)

    if st.session_state.extracted_text is None:
        # A timestamp-based name is used for the index created by this run,
        # in combination with overwrite=False below.
        unique_index_name = f"image_index_{int(time.time())}"
        st.write(f"Indexing document with RAG (index name: {unique_index_name})...")

        # The retriever indexes from a path, so the image is written to disk.
        image_path = "uploaded_image.png"
        image_to_process.save(image_path)

        RAG.index(
            input_path=image_path,
            index_name=unique_index_name,
            store_collection_with_index=False,
            overwrite=False,
        )

        text_query = "Extract all english text and hindi text from the document"
        st.write("Searching the document using RAG...")
        # NOTE(review): the retrieval hits are not used below; the uploaded
        # image itself is what gets passed to the model.
        results = RAG.search(text_query, k=1)

        # Single-turn chat message carrying the image and the instruction.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_to_process},
                    {"type": "text", "text": text_query},
                ],
            }
        ]

        text_input = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)

        inputs = processor(
            text=[text_input],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(device)

        st.write("Generating text...")
        # NOTE(review): generation is capped at 100 new tokens, so long
        # documents will come back truncated; confirm this limit is intended.
        generated_ids = model.generate(**inputs, max_new_tokens=100)

        # Slice each output sequence by the length of its input sequence so
        # only the newly generated part is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        st.session_state.extracted_text = output_text[0]

    extracted_text = st.session_state.extracted_text
    structured_text = {"extracted_text": extracted_text}

    st.subheader("Extracted Text (JSON Format):")
    st.json(structured_text)
|
|
|
|
|
# Keyword search over the extracted text.  The form is only offered once
# there is non-empty text to search in.
if st.session_state.extracted_text:
    with st.form(key='search_form'):
        search_query = st.text_input("Enter keyword to search within the extracted text:")
        search_button = st.form_submit_button("Search")

    # Surrounding whitespace is dropped so a blank query is ignored and the
    # bold markers below always sit directly against the matched text.
    keyword = search_query.strip()

    if search_button and keyword:
        extracted_text = st.session_state.extracted_text

        # The keyword is escaped so it is matched literally, case-insensitively.
        pattern = re.compile(re.escape(keyword), re.IGNORECASE)

        # Wrap every occurrence in markdown bold inside the full text, keeping
        # the casing found in the document, and count the hits so the
        # "no match" case can be reported separately.
        highlighted_text, match_count = pattern.subn(
            lambda match: f"**{match.group(0)}**",
            extracted_text,
        )

        st.subheader("Search Results:")
        if match_count == 0:
            st.markdown('Not found')
        else:
            st.markdown(highlighted_text)