Spaces:
Sleeping
Sleeping
import os | |
import streamlit as st | |
from transformers import AutoModel, AutoTokenizer | |
from PIL import Image | |
import uuid | |
@st.cache_resource
def load_model(model_name):
    """Load the tokenizer and GOT OCR model for the selected UI option.

    Decorated with ``st.cache_resource`` so each model is downloaded and
    initialised once per server process (keyed on ``model_name``) instead of
    on every "Run OCR" click.

    Args:
        model_name: One of the option labels offered in the "Select Model"
            selectbox.

    Returns:
        A ``(tokenizer, model)`` tuple with the model in eval mode.

    Raises:
        ValueError: If ``model_name`` is not one of the known options.
    """
    if model_name == "OCR for english or hindi (runs on CPU)":
        tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
        model = AutoModel.from_pretrained(
            'srimanth-d/GOT_CPU',
            trust_remote_code=True,
            use_safetensors=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        model.eval()  # no device_map / .cuda(): stays on CPU
    elif model_name == "OCR for english (runs on GPU)":
        tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
        model = AutoModel.from_pretrained(
            'ucaslcl/GOT-OCR2_0',
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            device_map='cuda',
            use_safetensors=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        model.eval().cuda()  # Load model on GPU
    else:
        # Without this branch the function would fail later with a confusing
        # UnboundLocalError on `tokenizer`.
        raise ValueError(f"Unknown model option: {model_name!r}")
    return tokenizer, model
def run_GOT(image, tokenizer, model):
    """Run GOT plain-text OCR on a PIL image.

    ``model.chat`` is called with an image *path*, so the image is first
    written to a temporary PNG that is always removed afterwards.

    Args:
        image: PIL image to recognise.
        tokenizer: Tokenizer paired with ``model`` (as returned by ``load_model``).
        model: Model exposing ``chat(tokenizer, image_path, ocr_type=...)``.

    Returns:
        The value returned by ``model.chat`` on success, or the string
        ``"Error: <message>"`` if saving the image or running the model raised.
    """
    import tempfile

    # Unique file in the system temp directory rather than the current
    # working directory.
    fd, image_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        # PNG cannot store CMYK, which some uploaded JPEG files use.
        if image.mode == "CMYK":
            image = image.convert("RGB")
        # Saving inside `try` means a failed or partial write is still removed
        # by `finally` and reported like any other OCR failure.
        image.save(image_path)
        return model.chat(tokenizer, image_path, ocr_type='ocr')  # plain-text OCR
    except Exception as e:
        # Report the failure to the caller instead of rendering it here; the
        # caller already displays error results, so rendering here doubled it.
        return f"Error: {e}"
    finally:
        # Clean up the temporary image
        if os.path.exists(image_path):
            os.remove(image_path)
def highlight_keyword(text, keyword):
    """Return HTML-safe ``text`` with each occurrence of ``keyword`` wrapped in ``<mark>``.

    The result is rendered with ``unsafe_allow_html=True`` and ``text`` comes
    from OCR of a user-uploaded image. It is therefore HTML-escaped, so that
    tag-like text in the image is displayed literally rather than interpreted
    as markup.

    Matching is case-sensitive and non-overlapping, left to right (the same
    semantics as ``str.replace``). A falsy ``keyword`` highlights nothing.

    Args:
        text: Plain text to display.
        keyword: Substring to highlight; empty or ``None`` disables highlighting.

    Returns:
        Escaped text with ``<mark>...</mark>`` around every match.
    """
    import html

    if not keyword:
        return html.escape(text, quote=False)
    # Split on the raw keyword and escape the pieces separately, so a keyword
    # can never match inside an entity produced by escaping (e.g. "amp" in "&amp;").
    marked = f"<mark>{html.escape(keyword, quote=False)}</mark>"
    return marked.join(html.escape(part, quote=False) for part in text.split(keyword))
# Streamlit App
st.set_page_config(page_title="GOT-OCR Multilingual Demo", layout="wide")

# Two columns: upload/preview on the left, model choice and OCR output on the right.
left_col, right_col = st.columns(2)

with left_col:
    uploaded_image = st.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])

with right_col:
    # Model selection in the right column
    model_option = st.selectbox(
        "Select Model",
        ["OCR for english or hindi (runs on CPU)", "OCR for english (runs on GPU)"],
    )

if uploaded_image is not None:
    image = Image.open(uploaded_image)
    with left_col:
        st.image(image, caption='Uploaded Image', use_column_width=True)

    # Streamlit re-executes this whole script on every widget interaction, and
    # st.button() is True only on the run triggered by the click itself. The
    # OCR result must therefore live in session_state; otherwise typing a
    # keyword below (which causes a rerun) would make the result disappear.
    # The stored result is tagged with the file and model it came from so it
    # is not shown for a different upload or model selection.
    current_key = (uploaded_image.name, uploaded_image.size, model_option)

    with right_col:
        if st.button("Run OCR"):
            with st.spinner("Processing..."):
                # Load the selected model
                tokenizer, model = load_model(model_option)
                # Run OCR and remember the result across reruns
                st.session_state["ocr_result"] = run_GOT(image, tokenizer, model)
                st.session_state["ocr_key"] = current_key

        if st.session_state.get("ocr_key") == current_key:
            result_text = st.session_state["ocr_result"]
            # run_GOT reports failures as a string starting with "Error: ".
            # Checking the prefix avoids treating OCR text that merely contains
            # the word "Error" as a failure.
            if result_text.startswith("Error: "):
                st.error(result_text)
            else:
                # Keyword input for search
                keyword = st.text_input("Enter a keyword to highlight")
                # Highlight keyword in the extracted text
                highlighted_text = highlight_keyword(result_text, keyword)
                # Display the extracted text
                st.markdown(highlighted_text, unsafe_allow_html=True)