Divyansh12 commited on
Commit
b5ef879
·
verified ·
1 Parent(s): 06672c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -57
app.py CHANGED
@@ -1,34 +1,48 @@
1
- import os
2
- import streamlit as st
3
  from transformers import AutoModel, AutoTokenizer
 
4
  from PIL import Image
 
 
5
  import uuid
6
 
7
- # Cache the model loading function using @st.cache_resource
8
- @st.cache_resource
9
- def load_model(model_name):
10
- if model_name == "OCR for english or hindi (runs on CPU)":
11
- tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
12
- model = AutoModel.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
13
- model.eval() # Load model on CPU
14
- elif model_name == "OCR for english (runs on GPU)":
15
- tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
16
- model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
17
- model.eval().cuda() # Load model on GPU
18
- return tokenizer, model
 
 
 
 
 
 
 
 
 
 
19
 
20
  # Function to run the GOT model for multilingual OCR
21
- @st.cache_data
22
- def run_GOT(_image, _tokenizer, _model):
23
  unique_id = str(uuid.uuid4())
24
  image_path = f"{unique_id}.png"
25
 
26
- _image.save(image_path) # Save the image using the underscore variable
 
27
 
28
  try:
29
- # Use the model to extract text
30
- res = _model.chat(_tokenizer, image_path, ocr_type='ocr') # Extract plain text
31
- return res
 
 
 
32
  except Exception as e:
33
  return f"Error: {str(e)}"
34
  finally:
@@ -37,48 +51,44 @@ def run_GOT(_image, _tokenizer, _model):
37
  os.remove(image_path)
38
 
39
  # Function to highlight keyword in text
40
- def highlight_keyword(text, keyword):
41
- if keyword:
42
- highlighted_text = text.replace(keyword, f"<mark>{keyword}</mark>")
43
- return highlighted_text
44
- return text
45
 
46
  # Streamlit App
47
- st.set_page_config(page_title="GOT-OCR Multilingual Demo", layout="wide")
 
48
 
49
- # Creating two columns
50
- left_col, right_col = st.columns(2)
51
-
52
- with left_col:
53
- uploaded_image = st.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])
54
-
55
- with right_col:
56
- # Model selection in the right column
57
- model_option = st.selectbox("Select Model", ["OCR for english or hindi (runs on CPU)", "OCR for english (runs on GPU)"])
58
 
59
  if uploaded_image:
60
  image = Image.open(uploaded_image)
 
61
 
62
- with left_col:
63
- st.image(image, caption='Uploaded Image', use_column_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- with right_col:
66
- if st.button("Run OCR"):
67
- with st.spinner("Processing..."):
68
- # Load the selected model (cached using @st.cache_resource)
69
- tokenizer, model = load_model(model_option)
70
-
71
- # Run OCR and cache the result using @st.cache_data
72
- result_text = run_GOT(image, tokenizer, model) # Pass the original image here
73
-
74
- if "Error" not in result_text:
75
- # Keyword input for search
76
- keyword = st.text_input("Enter a keyword to highlight")
77
-
78
- # Highlight keyword in the extracted text
79
- highlighted_text = highlight_keyword(result_text, keyword)
80
-
81
- # Display the extracted text
82
- st.markdown(highlighted_text, unsafe_allow_html=True)
83
- else:
84
- st.error(result_text)
 
 
 
1
  from transformers import AutoModel, AutoTokenizer
2
+ import streamlit as st
3
  from PIL import Image
4
+ import re
5
+ import os
6
  import uuid
7
 
8
+ # Load the model and tokenizer only once
9
+ if "model" not in st.session_state or "tokenizer" not in st.session_state:
10
+ @st.cache_resource
11
+ def load_model(model_name):
12
+ if model_name == "OCR for English or Hindi (CPU)":
13
+ tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
14
+ model = AutoModel.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
15
+ model = model.eval()
16
+ elif model_name == "OCR for English (GPU)":
17
+ tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
18
+ model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
19
+ model = model.eval().to('cuda')
20
+ return model, tokenizer
21
+
22
+ # Load and store in session state
23
+ model_option = st.selectbox("Select Model", ["OCR for English or Hindi (CPU)", "OCR for English (GPU)"])
24
+ model, tokenizer = load_model(model_option)
25
+ st.session_state["model"] = model
26
+ st.session_state["tokenizer"] = tokenizer
27
+ else:
28
+ model = st.session_state["model"]
29
+ tokenizer = st.session_state["tokenizer"]
30
 
31
  # Function to run the GOT model for multilingual OCR
32
+ def run_ocr(image, model, tokenizer):
 
33
  unique_id = str(uuid.uuid4())
34
  image_path = f"{unique_id}.png"
35
 
36
+ # Save image to disk
37
+ image.save(image_path)
38
 
39
  try:
40
+ # Use the model to extract text from the image
41
+ res = model.chat(tokenizer, image_path, ocr_type='ocr')
42
+ if isinstance(res, str):
43
+ return res
44
+ else:
45
+ return str(res)
46
  except Exception as e:
47
  return f"Error: {str(e)}"
48
  finally:
 
51
  os.remove(image_path)
52
 
53
  # Function to highlight keyword in text
54
+ def highlight_text(text, search_term):
55
+ if not search_term:
56
+ return text
57
+ pattern = re.compile(re.escape(search_term), re.IGNORECASE)
58
+ return pattern.sub(lambda m: f'<span style="background-color: yellow;">{m.group()}</span>', text)
59
 
60
  # Streamlit App
61
+ st.title("GOT-OCR Multilingual Demo")
62
+ st.write("Upload an image for OCR")
63
 
64
+ # Upload image
65
+ uploaded_image = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
 
 
 
 
 
 
 
66
 
67
  if uploaded_image:
68
  image = Image.open(uploaded_image)
69
+ st.image(image, caption='Uploaded Image', use_column_width=True)
70
 
71
+ if st.button("Run OCR"):
72
+ with st.spinner("Processing..."):
73
+ # Run OCR and store the result in session state
74
+ result_text = run_ocr(image, model, tokenizer)
75
+ if "Error" not in result_text:
76
+ st.session_state["extracted_text"] = result_text # Store the result in session state
77
+ else:
78
+ st.error(result_text)
79
+
80
+ # Display the extracted text if it exists in session state
81
+ if "extracted_text" in st.session_state:
82
+ extracted_text = st.session_state["extracted_text"]
83
+
84
+ st.subheader("Extracted Text:")
85
+ st.text(extracted_text) # Display the raw extracted text
86
+
87
+ # Keyword input for search
88
+ search_term = st.text_input("Enter a word or phrase to highlight:")
89
 
90
+ # Highlight keyword in the extracted text
91
+ if search_term:
92
+ highlighted_text = highlight_text(extracted_text, search_term)
93
+ # Display the highlighted text using markdown
94
+ st.markdown(highlighted_text, unsafe_allow_html=True)