Divyansh12 committed on
Commit 2dd690c · verified · 1 Parent(s): c90dcb5

Update app.py

Files changed (1): app.py (+31, -55)
app.py CHANGED
@@ -1,3 +1,4 @@
+
 from transformers import AutoModel, AutoTokenizer
 import streamlit as st
 from PIL import Image
@@ -5,61 +6,44 @@ import re
 import os
 import uuid
 
+# Set the page layout to wide
+st.set_page_config(layout="wide")
+
 # Load the model and tokenizer only once
-if "model" not in st.session_state or "tokenizer" not in st.session_state:
-    @st.cache_resource
-    def load_model(model_name):
-        if model_name == "OCR for English or Hindi (CPU)":
-            tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
-            model = AutoModel.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
-            model = model.eval()
-        elif model_name == "OCR for English (GPU)":
-            tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
-            model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
-            model = model.eval().to('cuda')
-        return model, tokenizer
+@st.cache_resource
+def load_model(model_name):
+    if model_name == "OCR on CPU":
+        tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
+        model = AutoModel.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id).eval()
+    else:
+        tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
+        model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id).eval().to('cuda')
+    return model, tokenizer
 
-    # Load and store in session state
-    model_option = "OCR for English or Hindi (CPU)"  # Default value for loading purposes
-    model, tokenizer = load_model(model_option)
-    st.session_state["model"] = model
-    st.session_state["tokenizer"] = tokenizer
-else:
-    model = st.session_state["model"]
-    tokenizer = st.session_state["tokenizer"]
+if "model" not in st.session_state or "tokenizer" not in st.session_state:
+    model, tokenizer = load_model("OCR for English or Hindi (CPU)")
+    st.session_state.update({"model": model, "tokenizer": tokenizer})
 
 # Function to run the GOT model for multilingual OCR
 def run_ocr(image, model, tokenizer):
-    unique_id = str(uuid.uuid4())
-    image_path = f"{unique_id}.png"
-
-    # Save image to disk
+    image_path = f"{uuid.uuid4()}.png"
     image.save(image_path)
 
     try:
-        # Use the model to extract text from the image
         res = model.chat(tokenizer, image_path, ocr_type='ocr')
-        if isinstance(res, str):
-            return res
-        else:
-            return str(res)
+        return res if isinstance(res, str) else str(res)
     except Exception as e:
         return f"Error: {str(e)}"
     finally:
-        # Clean up the saved image
-        if os.path.exists(image_path):
-            os.remove(image_path)
+        os.remove(image_path)
 
 # Function to highlight keyword in text
 def highlight_text(text, search_term):
-    if not search_term:
-        return text
-    pattern = re.compile(re.escape(search_term), re.IGNORECASE)
-    return pattern.sub(lambda m: f'<span style="background-color: red;">{m.group()}</span>', text)
+    return re.sub(re.escape(search_term), lambda m: f'<span style="background-color: red;">{m.group()}</span>', text, flags=re.IGNORECASE) if search_term else text
 
 # Streamlit App
-st.title("GOT-OCR Multilingual Demo")
-st.write("Upload an image for OCR")
+st.title(":blue[Object character recognition Application]")
+st.write("Give your Image")
 
 # Create two columns
 col1, col2 = st.columns(2)
@@ -73,30 +57,22 @@ with col1:
 
 # Right column - Model selection, options, and displaying extracted text
 with col2:
-    model_option = st.selectbox("Select Model", ["OCR for English or Hindi (CPU)", "OCR for English (GPU)"])
+    model_option = st.selectbox("Select Model", ["OCR on CPU", "OCR on GPU"])
 
-    if st.button("Run OCR"):
-        with st.spinner("Processing..."):
-            # Run OCR and store the result in session state
-            if uploaded_image:
+    if st.button("DO OCR "):
+        if uploaded_image:
+            with st.spinner("Processing..."):
+                model, tokenizer = load_model(model_option)
                 result_text = run_ocr(image, model, tokenizer)
                 if "Error" not in result_text:
-                    st.session_state["extracted_text"] = result_text  # Store the result in session state
+                    st.session_state["extracted_text"] = result_text
                 else:
                     st.error(result_text)
-            else:
-                st.error("Please upload an image before running OCR.")
+        else:
+            st.error("Please upload an image before running OCR.")
 
     # Display the extracted text if it exists in session state
    if "extracted_text" in st.session_state:
-        extracted_text = st.session_state["extracted_text"]
-
-        # Keyword input for search
         search_term = st.text_input("Enter a word or phrase to highlight:")
-
-        # Highlight keyword in the extracted text
-        highlighted_text = highlight_text(extracted_text, search_term)
-
-        # Display the highlighted text using markdown
         st.subheader("Extracted Text:")
-        st.markdown(f'<div style="white-space: pre-wrap;">{highlighted_text}</div>', unsafe_allow_html=True)
+        st.markdown(f'<div style="white-space: pre-wrap;">{highlight_text(st.session_state["extracted_text"], search_term)}</div>', unsafe_allow_html=True)
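
Note on the refactor (not part of the commit): the new flow calls load_model(model_option) inside the button handler on every click and relies on Streamlit's @st.cache_resource to make that cheap, since the decorator memoizes the returned objects per distinct argument value, so each model is loaded at most once per process. A minimal sketch of that behavior, using a hypothetical load_resource stand-in rather than the real GOT models:

import streamlit as st

@st.cache_resource
def load_resource(name: str):
    # Body runs once per distinct `name`; later calls with the same
    # argument return the cached object instead of reloading it.
    print(f"loading {name}")
    return {"name": name}  # stand-in for a heavy (model, tokenizer) pair

a = load_resource("OCR on CPU")
b = load_resource("OCR on CPU")  # cache hit: no second "loading" message
assert a is b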
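
The collapsed highlight_text one-liner keeps the previous behavior: matching is case-insensitive, and the original casing of each match is preserved because the replacement uses m.group(). A quick usage sketch (the expected output in the comment is illustrative, not taken from the commit):

import re

def highlight_text(text, search_term):
    return re.sub(re.escape(search_term),
                  lambda m: f'<span style="background-color: red;">{m.group()}</span>',
                  text, flags=re.IGNORECASE) if search_term else text

print(highlight_text("Hello World", "world"))
# -> Hello <span style="background-color: red;">World</span>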