Divyansh12 committed on
Commit
f07599d
·
verified ·
1 Parent(s): 4d862d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -182
app.py CHANGED
@@ -1,184 +1,64 @@
1
- import io
2
- from typing import List
3
-
4
- import pypdfium2
5
  import streamlit as st
6
- from surya.detection import batch_text_detection
7
- from surya.layout import batch_layout_detection
8
- from surya.model.detection.model import load_model, load_processor
9
- from surya.model.recognition.model import load_model as load_rec_model
10
- from surya.model.recognition.processor import load_processor as load_rec_processor
11
- from surya.model.ordering.processor import load_processor as load_order_processor
12
- from surya.model.ordering.model import load_model as load_order_model
13
- from surya.ordering import batch_ordering
14
- from surya.postprocessing.heatmap import draw_polys_on_image
15
- from surya.ocr import run_ocr
16
- from surya.postprocessing.text import draw_text_on_image
17
- from PIL import Image
18
- from surya.languages import CODE_TO_LANGUAGE
19
- from surya.input.langs import replace_lang_with_code
20
- from surya.schema import OCRResult, TextDetectionResult, LayoutResult, OrderResult
21
- from surya.settings import settings
22
-
23
- @st.cache_resource()
24
- def load_det_cached():
25
- checkpoint = settings.DETECTOR_MODEL_CHECKPOINT
26
- return load_model(checkpoint=checkpoint), load_processor(checkpoint=checkpoint)
27
-
28
-
29
- @st.cache_resource()
30
- def load_rec_cached():
31
- return load_rec_model(), load_rec_processor()
32
-
33
-
34
- @st.cache_resource()
35
- def load_layout_cached():
36
- return load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT), load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
37
-
38
- @st.cache_resource()
39
- def load_order_cached():
40
- return load_order_model(), load_order_processor()
41
-
42
-
43
- def text_detection(img) -> (Image.Image, TextDetectionResult):
44
- pred = batch_text_detection([img], det_model, det_processor)[0]
45
- polygons = [p.polygon for p in pred.bboxes]
46
- det_img = draw_polys_on_image(polygons, img.copy())
47
- return det_img, pred
48
-
49
-
50
- def layout_detection(img) -> (Image.Image, LayoutResult):
51
- _, det_pred = text_detection(img)
52
- pred = batch_layout_detection([img], layout_model, layout_processor, [det_pred])[0]
53
- polygons = [p.polygon for p in pred.bboxes]
54
- labels = [p.label for p in pred.bboxes]
55
- layout_img = draw_polys_on_image(polygons, img.copy(), labels=labels)
56
- return layout_img, pred
57
-
58
-
59
- def order_detection(img) -> (Image.Image, OrderResult):
60
- _, layout_pred = layout_detection(img)
61
- bboxes = [l.bbox for l in layout_pred.bboxes]
62
- pred = batch_ordering([img], [bboxes], order_model, order_processor)[0]
63
- polys = [l.polygon for l in pred.bboxes]
64
- positions = [str(l.position) for l in pred.bboxes]
65
- order_img = draw_polys_on_image(polys, img.copy(), labels=positions, label_font_size=20)
66
- return order_img, pred
67
-
68
-
69
- # Function for OCR
70
- def ocr(img, langs: List[str]) -> (Image.Image, OCRResult):
71
- replace_lang_with_code(langs)
72
- img_pred = run_ocr([img], [langs], det_model, det_processor, rec_model, rec_processor)[0]
73
-
74
- bboxes = [l.bbox for l in img_pred.text_lines]
75
- text = [l.text for l in img_pred.text_lines]
76
- rec_img = draw_text_on_image(bboxes, text, img.size, langs, has_math="_math" in langs)
77
- return rec_img, img_pred
78
-
79
-
80
- def open_pdf(pdf_file):
81
- stream = io.BytesIO(pdf_file.getvalue())
82
- return pypdfium2.PdfDocument(stream)
83
-
84
-
85
- @st.cache_data()
86
- def get_page_image(pdf_file, page_num, dpi=96):
87
- doc = open_pdf(pdf_file)
88
- renderer = doc.render(
89
- pypdfium2.PdfBitmap.to_pil,
90
- page_indices=[page_num - 1],
91
- scale=dpi / 72,
92
- )
93
- png = list(renderer)[0]
94
- png_image = png.convert("RGB")
95
- return png_image
96
-
97
 
98
- @st.cache_data()
99
- def page_count(pdf_file):
100
- doc = open_pdf(pdf_file)
101
- return len(doc)
102
-
103
-
104
- st.set_page_config(layout="wide")
105
- col1, col2 = st.columns([.5, .5])
106
-
107
- det_model, det_processor = load_det_cached()
108
- rec_model, rec_processor = load_rec_cached()
109
- layout_model, layout_processor = load_layout_cached()
110
- order_model, order_processor = load_order_cached()
111
-
112
-
113
- st.markdown("""
114
- # Surya OCR Demo
115
-
116
- This app will let you try surya, a multilingual OCR model. It supports text detection + layout analysis in any language, and text recognition in 90+ languages.
117
-
118
- Notes:
119
- - This works best on documents with printed text.
120
- - Preprocessing the image (e.g. increasing contrast) can improve results.
121
- - If OCR doesn't work, try changing the resolution of your image (increase if below 2048px width, otherwise decrease).
122
- - This supports 90+ languages, see [here](https://github.com/VikParuchuri/surya/tree/master/surya/languages.py) for a full list.
123
-
124
- Find the project [here](https://github.com/VikParuchuri/surya).
125
- """)
126
-
127
- in_file = st.sidebar.file_uploader("PDF file or image:", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"])
128
- languages = st.sidebar.multiselect("Languages", sorted(list(CODE_TO_LANGUAGE.values())), default=[], max_selections=4, help="Select the languages in the image (if known) to improve OCR accuracy. Optional.")
129
-
130
- if in_file is None:
131
- st.stop()
132
-
133
- filetype = in_file.type
134
- whole_image = False
135
- if "pdf" in filetype:
136
- page_count = page_count(in_file)
137
- page_number = st.sidebar.number_input(f"Page number out of {page_count}:", min_value=1, value=1, max_value=page_count)
138
-
139
- pil_image = get_page_image(in_file, page_number)
140
- else:
141
- pil_image = Image.open(in_file).convert("RGB")
142
-
143
- text_det = st.sidebar.button("Run Text Detection")
144
- text_rec = st.sidebar.button("Run OCR")
145
- layout_det = st.sidebar.button("Run Layout Analysis")
146
- order_det = st.sidebar.button("Run Reading Order")
147
-
148
- if pil_image is None:
149
- st.stop()
150
-
151
- # Run Text Detection
152
- if text_det:
153
- det_img, pred = text_detection(pil_image)
154
- with col1:
155
- st.image(det_img, caption="Detected Text", use_column_width=True)
156
- st.json(pred.model_dump(exclude=["heatmap", "affinity_map"]), expanded=True)
157
-
158
-
159
- # Run layout
160
- if layout_det:
161
- layout_img, pred = layout_detection(pil_image)
162
- with col1:
163
- st.image(layout_img, caption="Detected Layout", use_column_width=True)
164
- st.json(pred.model_dump(exclude=["segmentation_map"]), expanded=True)
165
-
166
- # Run OCR
167
- if text_rec:
168
- rec_img, pred = ocr(pil_image, languages)
169
- with col1:
170
- st.image(rec_img, caption="OCR Result", use_column_width=True)
171
- json_tab, text_tab = st.tabs(["JSON", "Text Lines (for debugging)"])
172
- with json_tab:
173
- st.json(pred.model_dump(), expanded=True)
174
- with text_tab:
175
- st.text("\n".join([p.text for p in pred.text_lines]))
176
-
177
- if order_det:
178
- order_img, pred = order_detection(pil_image)
179
- with col1:
180
- st.image(order_img, caption="Reading Order", use_column_width=True)
181
- st.json(pred.model_dump(), expanded=True)
182
-
183
- with col2:
184
- st.image(pil_image, caption="Uploaded Image", use_column_width=True)
 
 
 
 
 
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
+ from PIL import Image
4
+ import re
5
+ from transformers import AutoModel, AutoTokenizer
6
+
7
+ st.set_page_config(page_title="OCR Application", page_icon="🖼️", layout="wide")
8
+ device = "cpu"
9
+
10
@st.cache_resource
def load_model():
    """Load the GOT-OCR2.0 tokenizer and model, both placed on the CPU.

    The function is wrapped in ``st.cache_resource``, Streamlit's cache for
    shared resources such as models.

    Returns:
        A ``(processor, model)`` tuple. ``processor`` is the tokenizer
        returned by ``AutoTokenizer.from_pretrained``; the name is kept so
        callers that unpack ``processor, model`` keep working.
    """
    checkpoint = 'ucaslcl/GOT-OCR2_0'
    processor = AutoTokenizer.from_pretrained(
        checkpoint, trust_remote_code=True, device_map='cpu'
    )
    model = AutoModel.from_pretrained(
        checkpoint,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        device_map='cpu',
        use_safetensors=True,
    )
    return processor, model
20
+
21
def extract_text(image, processor, model):
    """Run OCR on an uploaded image and return the recognised text.

    ``load_model`` returns the GOT-OCR2.0 *tokenizer* as ``processor``. A
    tokenizer cannot be called with ``images=...`` and has no
    ``pixel_values`` output, so the previous TrOCR-style pipeline failed on
    every call. GOT-OCR2.0 is driven through the ``chat`` method supplied by
    its remote code instead.

    Args:
        image: The PIL image to read (as produced by ``Image.open``).
        processor: The tokenizer returned by ``load_model``.
        model: The GOT-OCR2.0 model returned by ``load_model``.

    Returns:
        The text recognised in the image, as returned by ``model.chat``.
    """
    # Imported locally so this fix does not depend on the file-level imports.
    import os
    import tempfile

    # ``model.chat`` takes an image *file path*, not an in-memory image, so
    # the upload is written to a throwaway directory first. A directory is
    # used (rather than an open NamedTemporaryFile) so the file can be
    # reopened by name on every platform.
    with tempfile.TemporaryDirectory() as workdir:
        image_path = os.path.join(workdir, "upload.png")
        # Normalise to RGB so palette/alpha/CMYK uploads can be saved as PNG.
        image.convert("RGB").save(image_path, format="PNG")
        # NOTE(review): the upstream remote code may assume a CUDA device;
        # confirm that ``chat`` runs with the CPU placement used by
        # ``load_model``.
        return model.chat(processor, image_path, ocr_type="ocr")
27
+
28
def highlight_matches(text, keywords):
    """Return ``text`` as HTML with each occurrence of ``keywords`` marked.

    The search is a literal, case-insensitive substring match. The result is
    meant to be rendered as HTML, so everything that comes from ``text`` is
    HTML-escaped and only the ``<mark>`` tags added here are real markup.

    Args:
        text: The plain text to search (for example OCR output).
        keywords: The literal phrase to highlight. An empty value means
            "nothing to highlight".

    Returns:
        An HTML-safe string with every match wrapped in ``<mark>...</mark>``.
    """
    # Imported locally so this fix does not depend on the file-level imports.
    import html

    # An empty pattern matches at every position and would wrap the gaps
    # between all characters in empty <mark> tags.
    if not keywords:
        return html.escape(text, quote=False)

    pattern = re.compile(re.escape(keywords), re.IGNORECASE)

    # Match against the raw text and escape segment by segment; escaping
    # first would let a keyword such as "amp" match inside "&amp;".
    pieces = []
    cursor = 0
    for match in pattern.finditer(text):
        pieces.append(html.escape(text[cursor:match.start()], quote=False))
        pieces.append(f"<mark>{html.escape(match.group(0), quote=False)}</mark>")
        cursor = match.end()
    pieces.append(html.escape(text[cursor:], quote=False))
    return "".join(pieces)
33
+
34
def main():
    """Streamlit entry point: upload an image, OCR it, then search the text.

    Nothing beyond the title and the uploader is shown until a file has been
    uploaded; the search results are shown only when "Search" is pressed.
    """
    st.title("OCR Text Extractor using Hugging Face Model")

    # Model and tokenizer used for the OCR step below.
    processor, model = load_model()

    uploaded_file = st.file_uploader("Upload an image for OCR", type=["png", "jpg", "jpeg"])
    if not uploaded_file:
        return

    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # OCR step, run behind a spinner.
    with st.spinner("Extracting text from the image..."):
        extracted_text = extract_text(image, processor, model)

    st.subheader("Extracted Text")
    st.text_area("Text from Image", extracted_text, height=300)

    # Keyword search over the extracted text.
    st.subheader("Keyword Search")
    keywords = st.text_input("Enter keywords to search:")

    if not st.button("Search"):
        return

    highlighted_text = highlight_matches(extracted_text, keywords)
    st.subheader("Search Results")
    st.markdown(highlighted_text, unsafe_allow_html=True)


if __name__ == "__main__":
    main()