import streamlit as st from PIL import Image from pdf2image import convert_from_path from byaldi import RAGMultiModalModel from transformers import Qwen2VLForConditionalGeneration, AutoProcessor from qwen_vl_utils import process_vision_info import torch import time # For generating unique index names import json import re device = "cuda" if torch.cuda.is_available() else "cpu" # Initialize Qwen2-VL model and processor @st.cache_resource def load_models(): # Load RAG MultiModalModel and Qwen2-VL model RAG = RAGMultiModalModel.from_pretrained("vidore/colpali") model = Qwen2VLForConditionalGeneration.from_pretrained( "Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True, torch_dtype=torch.bfloat16 ).to(device).eval() processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True) return RAG, model, processor RAG, model, processor = load_models() # Step 1: Upload the file st.title("OCR extraction") uploaded_file = st.file_uploader("Upload a PDF or Image", type=["pdf", "png", "jpg", "jpeg"]) # Initialize a session state to store extracted text so it persists across reruns if "extracted_text" not in st.session_state: st.session_state.extracted_text = None if uploaded_file is not None: file_type = uploaded_file.name.split('.')[-1].lower() # Step 2: Convert PDF to image (if the input is a PDF) if file_type == "pdf": st.write("Converting PDF to image...") images = convert_from_path(uploaded_file) image_to_process = images[0] else: # For images (png/jpg), just open the image directly image_to_process = Image.open(uploaded_file) # Step 3: Display the uploaded image or PDF st.image(image_to_process, caption="Uploaded document", use_column_width=True) # Step 4: Dynamically create a unique index name using timestamp unique_index_name = f"image_index_{int(time.time())}" # Generate unique index name using current timestamp # Step 5: Perform text extraction only if it's a new file if st.session_state.extracted_text is None: st.write(f"Indexing document with RAG (index name: {unique_index_name})...") image_path = "uploaded_image.png" # Temporary save path image_to_process.save(image_path) RAG.index( input_path=image_path, index_name=unique_index_name, # Use unique index name store_collection_with_index=False, overwrite=False ) # Step 6: Perform text extraction text_query = "Extract all english text and hindi text from the document" st.write("Searching the document using RAG...") results = RAG.search(text_query, k=1) # Prepare the messages for text and image input messages = [ { "role": "user", "content": [ {"type": "image", "image": image_to_process}, {"type": "text", "text": text_query}, ], } ] # Prepare and process image and text inputs text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) image_inputs, video_inputs = process_vision_info(messages) inputs = processor( text=[text_input], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt", ) inputs = inputs.to(device) # Generate text output from the image using Qwen2-VL st.write("Generating text...") generated_ids = model.generate(**inputs, max_new_tokens=100) generated_ids_trimmed = [ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] output_text = processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False ) # Step 7: Store the extracted text in session state st.session_state.extracted_text = output_text[0] # Step 8: Display the extracted text in JSON format extracted_text = st.session_state.extracted_text structured_text = {"extracted_text": extracted_text} st.subheader("Extracted Text (JSON Format):") st.json(structured_text) # Step 9: Implement a search functionality on already extracted text if st.session_state.extracted_text: with st.form(key='search_form'): search_query = st.text_input("Enter keyword to search within the extracted text:") search_button = st.form_submit_button("Search") if search_button and search_query: # Perform case-insensitive search and highlight the matches extracted_text = st.session_state.extracted_text # Use already extracted text matches = re.finditer(re.escape(search_query), extracted_text, re.IGNORECASE) highlighted_text = extracted_text result = '' for match in matches: start, end = match.span() result = "**" + highlighted_text[start:end] + "**" st.subheader("Search Results:") if result == '': st.markdown('Not forund') st.markdown(result)