import os
import uuid

import streamlit as st
from PIL import Image
from transformers import AutoModel, AutoTokenizer


# Cache model loading so each checkpoint is downloaded and initialised only once
@st.cache_resource
def load_model(model_name):
    if model_name == "OCR for english or hindi (runs on CPU)":
        tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
        model = AutoModel.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True,
                                          use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
        model.eval()  # Inference mode; this checkpoint stays on CPU
    elif model_name == "OCR for english (runs on GPU)":
        tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
        model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True,
                                          low_cpu_mem_usage=True, device_map='cuda',
                                          use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
        model.eval().cuda()  # Inference mode on GPU
    return tokenizer, model


# Run the GOT model for OCR. The tokenizer and model arguments are prefixed
# with an underscore so st.cache_data does not try to hash them (they are
# unhashable); the result is cached per uploaded image.
@st.cache_data
def run_GOT(image, _tokenizer, _model):
    unique_id = str(uuid.uuid4())
    image_path = f"{unique_id}.png"
    image.save(image_path)
    st.write(f"Saved image to {image_path}")
    try:
        # Use the model to extract plain text from the saved image
        res = _model.chat(_tokenizer, image_path, ocr_type='ocr')
        st.write(f"Raw result: {res}")  # Debug output
        return res
    except Exception as e:
        st.error(f"Error: {str(e)}")  # Display any errors
        return f"Error: {str(e)}"
    finally:
        # Clean up the saved image
        if os.path.exists(image_path):
            os.remove(image_path)


# Wrap each occurrence of the keyword in <mark> tags so st.markdown
# (called with unsafe_allow_html=True) renders it highlighted
def highlight_keyword(text, keyword):
    if keyword:
        return text.replace(keyword, f"<mark>{keyword}</mark>")
    return text


# Streamlit App
st.set_page_config(page_title="GOT-OCR Multilingual Demo", layout="wide")

# Create two columns: image upload on the left, model selection and results on the right
left_col, right_col = st.columns(2)

with left_col:
    uploaded_image = st.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])

with right_col:
    model_option = st.selectbox("Select Model", ["OCR for english or hindi (runs on CPU)",
                                                 "OCR for english (runs on GPU)"])

if uploaded_image:
    image = Image.open(uploaded_image)
    with left_col:
        st.image(image, caption='Uploaded Image', use_column_width=True)
    with right_col:
        if st.button("Run OCR"):
            with st.spinner("Processing..."):
                # Load the selected model (cached by @st.cache_resource)
                tokenizer, model = load_model(model_option)
                # Run OCR (result cached by @st.cache_data) and keep it in
                # session state so it survives the rerun triggered by the
                # keyword text input below
                st.session_state["ocr_result"] = run_GOT(image, tokenizer, model)

        result_text = st.session_state.get("ocr_result")
        if result_text:
            if "Error" not in result_text:
                # Keyword input for search
                keyword = st.text_input("Enter a keyword to highlight")
                # Highlight the keyword in the extracted text and display it
                highlighted_text = highlight_keyword(result_text, keyword)
                st.markdown(highlighted_text, unsafe_allow_html=True)
            else:
                st.error(result_text)