Divyansh12 committed
Commit fa9edbf · verified · 1 Parent(s): 429f75e

Update app.py

Files changed (1)
  1. app.py +119 -49
app.py CHANGED
@@ -1,59 +1,129 @@
 
 import streamlit as st
-from PIL import Image
-import re
 from transformers import AutoModel, AutoTokenizer
-
-st.set_page_config(page_title="OCR Application", page_icon="🖼️", layout="wide")
-
-@st.cache_resource
-def load_model():
-    # Load the tokenizer and model for processing images
-    tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
-    model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
-    return tokenizer, model
-
-def extract_text(image, tokenizer, model):
-    # Preprocess the image and extract text using the model
-    inputs = tokenizer(images=image, return_tensors="pt")  # Adjust based on how the model expects inputs
-    generated_ids = model.generate(**inputs)
-    extracted_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    return extracted_text
-
-def highlight_matches(text, keywords):
-    # Highlight keywords in the extracted text
-    pattern = re.compile(f"({re.escape(keywords)})", re.IGNORECASE)
-    highlighted_text = pattern.sub(r"<mark>\1</mark>", text)
-    return highlighted_text
-
-def main():
-    st.title("OCR Text Extractor using Qwen Model")
-
-    # Load model and tokenizer
-    tokenizer, model = load_model()
-
-    # Upload Image
-    uploaded_file = st.file_uploader("Upload an image for OCR", type=["png", "jpg", "jpeg"])
-
-    if uploaded_file:
-        image = Image.open(uploaded_file)
-        st.image(image, caption="Uploaded Image", use_column_width=True)
-
-        # Extract text from the image using the model
-        with st.spinner("Extracting text from the image..."):
-            extracted_text = extract_text(image, tokenizer, model)
-
-        st.subheader("Extracted Text")
-        st.text_area("Text from Image", extracted_text, height=300)
-
-        # Keyword search
-        st.subheader("Keyword Search")
-        keywords = st.text_input("Enter keywords to search:")
-
-        if st.button("Search"):
-            highlighted_text = highlight_matches(extracted_text, keywords)
-            st.subheader("Search Results")
-            st.markdown(highlighted_text, unsafe_allow_html=True)
-
-if __name__ == "__main__":
-    main()
+from PIL import Image
+import os
+import base64
+import uuid
+import time
+import shutil
+from pathlib import Path
+
+# Load the tokenizer and model once at startup; inference runs on CPU
+tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
+model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True)
+model = model.eval()  # evaluation mode; the model stays on the CPU
+
+# Folders for uploaded images and rendered results
+UPLOAD_FOLDER = "./uploads"
+RESULTS_FOLDER = "./results"
+
+for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
+    if not os.path.exists(folder):
+        os.makedirs(folder)
+
+# Run the GOT model in the selected mode
+def run_GOT(image, got_mode, fine_grained_mode="", ocr_color="", ocr_box=""):
+    unique_id = str(uuid.uuid4())
+    image_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.png")
+    result_path = os.path.join(RESULTS_FOLDER, f"{unique_id}.html")
+
+    image.save(image_path)
+
+    try:
+        if got_mode == "plain texts OCR":
+            res = model.chat(tokenizer, image_path, ocr_type='ocr')
+            return res, None
+        elif got_mode == "format texts OCR":
+            res = model.chat(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
+        elif got_mode == "plain multi-crop OCR":
+            res = model.chat_crop(tokenizer, image_path, ocr_type='ocr')
+            return res, None
+        elif got_mode == "format multi-crop OCR":
+            res = model.chat_crop(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
+        elif got_mode == "plain fine-grained OCR":
+            res = model.chat(tokenizer, image_path, ocr_type='ocr', ocr_box=ocr_box, ocr_color=ocr_color)
+            return res, None
+        elif got_mode == "format fine-grained OCR":
+            res = model.chat(tokenizer, image_path, ocr_type='format', ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)
+
+        res_markdown = res
+
+        if "format" in got_mode and os.path.exists(result_path):
+            with open(result_path, 'r') as f:
+                html_content = f.read()
+            encoded_html = base64.b64encode(html_content.encode('utf-8')).decode('utf-8')
+            iframe_src = f"data:text/html;base64,{encoded_html}"
+            iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
+            return res_markdown, iframe
+        else:
+            return res_markdown, None
+    except Exception as e:
+        return f"Error: {str(e)}", None
+    finally:
+        if os.path.exists(image_path):
+            os.remove(image_path)
+
+# Remove uploads and results older than one hour
+def cleanup_old_files():
+    current_time = time.time()
+    for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
+        for file_path in Path(folder).glob('*'):
+            if current_time - file_path.stat().st_mtime > 3600:  # 1 hour
+                file_path.unlink()
+
+# Streamlit app
+st.set_page_config(page_title="GOT-OCR-2.0 Demo", layout="wide")
+
+st.markdown("""
+<h2> <span style="color: #ff6600">General OCR Theory</span>: Towards OCR-2.0 via a Unified End-to-end Model</h2>
+<a href="https://huggingface.co/ucaslcl/GOT-OCR2_0">[😊 Hugging Face]</a>
+<a href="https://arxiv.org/abs/2409.01704">[📜 Paper]</a>
+<a href="https://github.com/Ucas-HaoranWei/GOT-OCR2.0/">[🌟 GitHub]</a>
+""", unsafe_allow_html=True)
+
+st.markdown("""
+🔥🔥🔥 This is the official online demo of the GOT-OCR-2.0 model!
+### Demo Guidelines
+- Upload an image below, choose one GOT mode, and click "Submit" to run the model. Images with more characters take longer to process.
+- **plain texts OCR & format texts OCR**: these two modes perform OCR on the whole image.
+- **plain multi-crop OCR & format multi-crop OCR**: for images with more complex content, these modes produce higher-quality results.
+- **plain fine-grained OCR & format fine-grained OCR**: these modes restrict OCR to fine-grained regions of the input image. A region can be given as box coordinates, or marked in red, green, or blue.
+""")
+
+uploaded_image = st.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])
+
+if uploaded_image:
+    image = Image.open(uploaded_image)
+    st.image(image, caption='Uploaded Image', use_column_width=True)
+
+    got_mode = st.selectbox("Choose one mode of GOT", [
+        "plain texts OCR",
+        "format texts OCR",
+        "plain multi-crop OCR",
+        "format multi-crop OCR",
+        "plain fine-grained OCR",
+        "format fine-grained OCR",
+    ])
+
+    fine_grained_mode = None
+    ocr_color = ""
+    ocr_box = ""
+
+    if "fine-grained" in got_mode:
+        fine_grained_mode = st.selectbox("Fine-grained type", ["box", "color"])
+        if fine_grained_mode == "box":
+            ocr_box = st.text_input("Input box: [x1,y1,x2,y2]", value="[0,0,100,100]")
+        elif fine_grained_mode == "color":
+            ocr_color = st.selectbox("Color list", ["red", "green", "blue"])
+
+    if st.button("Submit"):
+        with st.spinner("Processing..."):
+            result_text, html_result = run_GOT(image, got_mode, fine_grained_mode, ocr_color, ocr_box)
+            st.text_area("GOT Output", result_text, height=200)
+
+            if html_result:
+                st.markdown(html_result, unsafe_allow_html=True)
+
+# Clean up old files on each rerun
+cleanup_old_files()
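
For reference, the rewritten app drives OCR entirely through the custom `chat`/`chat_crop` methods that ship with the checkpoint's remote code; the old version's `tokenizer(images=...)` / `model.generate` calls are not how this model is invoked. Below is a minimal sketch of the same calls outside Streamlit, assuming the remote code's dependencies (e.g. torch) are installed; `sample.png` and `./result.html` are placeholder paths, not files from this commit.

from transformers import AutoModel, AutoTokenizer

# Same loading pattern as the new app.py; trust_remote_code pulls in
# GOT's custom chat/chat_crop methods.
tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True).eval()

# "plain texts OCR": OCR over the whole image.
text = model.chat(tokenizer, 'sample.png', ocr_type='ocr')

# "plain fine-grained OCR": restrict OCR to a box, using the same
# [x1,y1,x2,y2] string format the app collects from the user.
boxed = model.chat(tokenizer, 'sample.png', ocr_type='ocr', ocr_box='[0,0,100,100]')

# "format texts OCR": render the formatted result to an HTML file, which
# the app then base64-encodes into a data: URI and shows in an iframe.
formatted = model.chat(tokenizer, 'sample.png', ocr_type='format', render=True, save_render_file='./result.html')

print(text)

Rendering to a file and embedding it as a base64 data: URI lets the Space show the formatted output without serving the HTML file separately, which is why run_GOT returns the iframe markup as a string.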