import os
import uuid

import streamlit as st
from PIL import Image
from transformers import AutoModel, AutoTokenizer
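
# Streamlit demo app for the GOT-OCR2.0 model: upload an image, run OCR, and highlight a keyword in the extracted text
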
# Cache the model loading function using @st.cache_resource
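# so the selected model/tokenizer pair is loaded once and reused across reruns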
@st.cache_resource
def load_model(model_name):
    if model_name == "OCR for English or Hindi (runs on CPU)":
        tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
        model = AutoModel.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
        model.eval()  # Keep the model on the CPU in inference mode
    elif model_name == "OCR for English (runs on GPU)":
        tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
        model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
        model.eval().cuda()  # Move the model to the GPU in inference mode
    return tokenizer, model

# Function to run the GOT model for multilingual OCR
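# The model's chat() helper expects an image file path, so the uploaded image is saved to a temporary file first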
@st.cache_data
def run_GOT(image, _tokenizer, _model):
    # Leading underscores tell Streamlit not to hash the tokenizer/model objects
    # when building the cache key; results are cached per uploaded image
    unique_id = str(uuid.uuid4())
    image_path = f"{unique_id}.png"
    image.save(image_path)
    st.write(f"Saved image to {image_path}")
    try:
        # Use the model to extract plain text from the saved image
        res = _model.chat(_tokenizer, image_path, ocr_type='ocr')
        st.write(f"Raw result: {res}")  # Debug output
        return res
    except Exception as e:
        st.error(f"Error: {str(e)}")  # Display any errors
        return f"Error: {str(e)}"
    finally:
        # Clean up the saved image
        if os.path.exists(image_path):
            os.remove(image_path)

# Function to highlight keyword in text
def highlight_keyword(text, keyword):
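    # Matching is case-sensitive; matches are wrapped in <mark> tags and rendered via st.markdown with unsafe_allow_html=True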
    if keyword:
        highlighted_text = text.replace(keyword, f"<mark>{keyword}</mark>")
        return highlighted_text
    return text

# Streamlit App
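# st.set_page_config must be the first Streamlit command executed when the script runs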
st.set_page_config(page_title="GOT-OCR Multilingual Demo", layout="wide")
# Creating two columns
left_col, right_col = st.columns(2)
with left_col:
    uploaded_image = st.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])
with right_col:
    # Model selection in the right column
    model_option = st.selectbox("Select Model", ["OCR for English or Hindi (runs on CPU)", "OCR for English (runs on GPU)"])

if uploaded_image:
    image = Image.open(uploaded_image)
    with left_col:
        st.image(image, caption='Uploaded Image', use_column_width=True)
    with right_col:
        if st.button("Run OCR"):
            with st.spinner("Processing..."):
                # Load the selected model (cached using @st.cache_resource)
                tokenizer, model = load_model(model_option)
                # Run OCR and cache the result using @st.cache_data
                result_text = run_GOT(image, tokenizer, model)
if "Error" not in result_text:
# Keyword input for search
keyword = st.text_input("Enter a keyword to highlight")
# Highlight keyword in the extracted text
highlighted_text = highlight_keyword(result_text, keyword)
# Display the extracted text
st.markdown(highlighted_text, unsafe_allow_html=True)
else:
st.error(result_text)