File size: 3,282 Bytes
05c2581
b5ef879
fa9edbf
b5ef879
 
fa9edbf
f07599d
2dd690c
 
 
b5ef879
2dd690c
 
d42b279
2dd690c
 
 
 
 
 
b5ef879
2dd690c
d42b279
2dd690c
f07599d
3d71449
b5ef879
2dd690c
b5ef879
f07599d
fa9edbf
b5ef879
2dd690c
fa9edbf
f604f09
fa9edbf
2dd690c
fa9edbf
3d71449
b5ef879
2dd690c
fa9edbf
 
d42b279
c1b0853
fa9edbf
2d8087a
 
 
 
 
 
 
 
 
f604f09
2d8087a
 
d42b279
77415fc
0cc2c1c
2dd690c
 
 
2d8087a
 
2dd690c
2d8087a
 
2dd690c
 
b5ef879
2d8087a
 
 
d42b279
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import html
import os
import re
import tempfile
import uuid

import streamlit as st
from PIL import Image
from transformers import AutoModel, AutoTokenizer

# Set the page layout to wide so the two side-by-side columns created further
# down (st.columns(2)) can use the full browser width.
st.set_page_config(layout="wide")

# Load the model and tokenizer only once
@st.cache_resource
def load_model(model_name):
    """Load (and cache via st.cache_resource) the GOT OCR model and tokenizer.

    Args:
        model_name: the UI label chosen in the model selector.
            ``"OCR for English or Hindi (CPU)"`` selects the
            ``srimanth-d/GOT_CPU`` checkpoint; any other label selects
            ``ucaslcl/GOT-OCR2_0``, which is loaded onto CUDA.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.
    """
    is_cpu_choice = model_name == "OCR for English or Hindi (CPU)"
    repo_id = 'srimanth-d/GOT_CPU' if is_cpu_choice else 'ucaslcl/GOT-OCR2_0'

    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)

    # Options shared by both checkpoints; the GPU checkpoint is additionally
    # loaded with low_cpu_mem_usage and placed on CUDA via device_map.
    from_pretrained_kwargs = dict(
        trust_remote_code=True,
        use_safetensors=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    if not is_cpu_choice:
        from_pretrained_kwargs.update(low_cpu_mem_usage=True, device_map='cuda')

    model = AutoModel.from_pretrained(repo_id, **from_pretrained_kwargs).eval()
    if not is_cpu_choice:
        model = model.to('cuda')
    return model, tokenizer

# On the first run of a session (when session state lacks either handle), load
# the default CPU model through the cached loader and keep both handles in
# session state.
# NOTE(review): in the visible code nothing reads st.session_state["model"] or
# ["tokenizer"]; the "Run OCR" button calls load_model() directly, so this
# mainly pre-warms the cache.
if not ("model" in st.session_state and "tokenizer" in st.session_state):
    model, tokenizer = load_model("OCR for English or Hindi (CPU)")
    st.session_state["model"] = model
    st.session_state["tokenizer"] = tokenizer

# Function to run the GOT model for multilingual OCR
def run_ocr(image, model, tokenizer):
    """Run OCR on an image with a GOT model and return the recognized text.

    ``model.chat`` takes an image *path*, so the image is first written as a
    PNG into a private temporary directory. That directory is removed
    afterwards in every case, including when saving or inference fails.

    Args:
        image: PIL image to recognize.
        model: model object exposing ``chat(tokenizer, image_path, ocr_type=...)``.
        tokenizer: tokenizer passed through unchanged to ``model.chat``.

    Returns:
        The recognized text (a non-string result is converted with ``str``),
        or ``"Error: <message>"`` if preparing the image or running the model
        raises.
    """
    # Modes Pillow can write as PNG. Anything else (e.g. CMYK from some JPEGs)
    # would make image.save() fail, so it is converted to RGB first.
    png_modes = ("1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA")

    # Use a throwaway directory rather than the working directory: nothing is
    # left behind even if saving fails part-way.
    with tempfile.TemporaryDirectory() as tmp_dir:
        image_path = os.path.join(tmp_dir, "input.png")
        try:
            if image.mode not in png_modes:
                image = image.convert("RGB")
            image.save(image_path)
            res = model.chat(tokenizer, image_path, ocr_type='ocr')
            return res if isinstance(res, str) else str(res)
        except Exception as e:
            # Errors are reported in-band; the UI checks for this prefix.
            return f"Error: {str(e)}"

# Function to highlight keyword in text
def highlight_text(text, search_term):
    """Return *text* as HTML with every occurrence of *search_term* highlighted.

    Matching is literal (regex metacharacters in the term have no special
    meaning) and case-insensitive. Each match keeps its original casing and is
    wrapped in a red-background ``<span>``.

    All text, matched or not, is HTML-escaped. The result is meant to be
    embedded as raw HTML (e.g. ``st.markdown(..., unsafe_allow_html=True)``),
    and OCR output may contain ``<``, ``>`` or ``&`` that would otherwise be
    parsed as markup.

    Args:
        text: plain text to render.
        search_term: phrase to highlight; falsy means "highlight nothing".

    Returns:
        An HTML-safe string.
    """
    if not search_term:
        return html.escape(text, quote=False)
    # A capturing group makes re.split keep the matches, so the pieces
    # alternate non-match, match, non-match, ... and odd indices are the hits.
    pieces = re.split(f"({re.escape(search_term)})", text, flags=re.IGNORECASE)
    return "".join(
        f'<span style="background-color: red;">{html.escape(piece, quote=False)}</span>'
        if index % 2
        else html.escape(piece, quote=False)
        for index, piece in enumerate(pieces)
    )

# Streamlit App
st.title(":blue[Optical Character Recognition Application]")
st.write("upload image for ocr")

# Two side-by-side columns: image input on the left, controls and results on the right.
col1, col2 = st.columns(2)

# Left column - Display the uploaded image
with col1:
    uploaded_image = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
    if uploaded_image:
        image = Image.open(uploaded_image)
        st.image(image, caption='Uploaded Image', use_column_width=True)

# Right column - Model selection, options, and displaying extracted text
with col2:
    model_option = st.selectbox("Select Model", ["OCR for English or Hindi (CPU)", "OCR for English (GPU)"])

    if st.button("Run OCR"):
        if uploaded_image:
            with st.spinner("Processing..."):
                try:
                    model, tokenizer = load_model(model_option)
                except Exception as exc:
                    # Surface load failures (presumably e.g. the GPU model on a
                    # host without CUDA) as a message instead of a traceback.
                    st.error(f"Could not load model: {exc}")
                else:
                    result_text = run_ocr(image, model, tokenizer)
                    # run_ocr reports failures as "Error: <message>". Check that
                    # prefix rather than searching for "Error" anywhere, which
                    # would discard real OCR text that merely contains the word.
                    if result_text.startswith("Error: "):
                        st.error(result_text)
                    else:
                        st.session_state["extracted_text"] = result_text
        else:
            st.error("Please upload an image before running OCR.")

    # Display the extracted text if it exists in session state
    if "extracted_text" in st.session_state:
        search_term = st.text_input("Enter a word or phrase to highlight:")
        st.subheader("Extracted Text:")
        st.markdown(f'<div style="white-space: pre-wrap;">{highlight_text(st.session_state["extracted_text"], search_term)}</div>', unsafe_allow_html=True)