Spaces:
Sleeping
Sleeping
VinayHajare
committed on
Create ocr_processor.py
Browse files- ocr_processor.py +208 -0
ocr_processor.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from typing import Dict, Any, List, Union
|
3 |
+
import os
|
4 |
+
import base64
|
5 |
+
import requests
|
6 |
+
from tqdm import tqdm
|
7 |
+
import concurrent.futures
|
8 |
+
from pathlib import Path
|
9 |
+
import cv2
|
10 |
+
from pdf2image import convert_from_path
|
11 |
+
|
12 |
+
class OCRProcessor:
    """Extract text from images and PDFs using an Ollama vision model.

    Images are base64-encoded and sent to the Ollama ``/api/generate``
    endpoint; the model's text response is returned in one of several
    output formats (markdown, plain text, JSON, structured, key/value).
    """

    # Prompt templates keyed by output format. Defined once at class level
    # so the dict is not rebuilt on every process_image() call.
    _PROMPTS: Dict[str, str] = {
        "markdown": """Please look at this image and extract all the text content. Format the output in markdown:
- Use headers (# ## ###) for titles and sections
- Use bullet points (-) for lists
- Use proper markdown formatting for emphasis and structure
- Preserve the original text hierarchy and formatting as much as possible""",

        "text": """Please look at this image and extract all the text content.
Provide the output as plain text, maintaining the original layout and line breaks where appropriate.
Include all visible text from the image.""",

        "json": """Please look at this image and extract all the text content. Structure the output as JSON with these guidelines:
- Identify different sections or components
- Use appropriate keys for different text elements
- Maintain the hierarchical structure of the content
- Include all visible text from the image""",

        "structured": """Please look at this image and extract all the text content, focusing on structural elements:
- Identify and format any tables
- Extract lists and maintain their structure
- Preserve any hierarchical relationships
- Format sections and subsections clearly""",

        "key_value": """Please look at this image and extract text that appears in key-value pairs:
- Look for labels and their associated values
- Extract form fields and their contents
- Identify any paired information
- Present each pair on a new line as 'key: value'"""
    }

    def __init__(self, model_name: str = "llama3.2-vision:11b",
                 base_url: str = "http://localhost:11434/api/generate",
                 max_workers: int = 1):
        """
        Args:
            model_name: Ollama model tag to query.
            base_url: Full URL of the Ollama generate endpoint.
            max_workers: Thread-pool size used by process_batch().
        """
        self.model_name = model_name
        self.base_url = base_url
        self.max_workers = max_workers

    def _encode_image(self, image_path: str) -> str:
        """Return the file's raw bytes as a base64 string."""
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    def _preprocess_image(self, image_path: str) -> str:
        """Preprocess an image before OCR and return the path of the result.

        Steps:
        - Convert the first page of a PDF to a JPEG if needed
        - Enhance contrast (CLAHE)
        - Reduce noise

        Args:
            image_path: Path to the source image or PDF.

        Returns:
            Path to a new temporary ``*_preprocessed.jpg`` file; the caller
            is responsible for deleting it.

        Raises:
            ValueError: If the PDF yields no pages or the image is unreadable.
        """
        # Handle PDF files: only the first page is processed.
        temp_path: str = ""
        if image_path.lower().endswith('.pdf'):
            pages = convert_from_path(image_path)
            if not pages:
                raise ValueError("Could not convert PDF to image")
            temp_path = f"{image_path}_temp.jpg"
            pages[0].save(temp_path, 'JPEG')
            image_path = temp_path

        try:
            # Read image
            image = cv2.imread(image_path)
            if image is None:
                raise ValueError(f"Could not read image at {image_path}")

            # Convert to grayscale
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

            # CLAHE enhances local contrast without over-amplifying noise
            # the way plain histogram equalization would.
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            enhanced = clahe.apply(gray)

            # Denoise
            denoised = cv2.fastNlMeansDenoising(enhanced)

            # Auto-rotate if needed
            # TODO: Implement rotation detection and correction

            # Save preprocessed image
            preprocessed_path = f"{image_path}_preprocessed.jpg"
            cv2.imwrite(preprocessed_path, denoised)
            return preprocessed_path
        finally:
            # Bug fix: the intermediate PDF page JPEG was previously never
            # deleted (only the *_preprocessed.jpg result was cleaned up).
            if temp_path and os.path.exists(temp_path):
                os.remove(temp_path)

    def process_image(self, image_path: str, format_type: str = "markdown", preprocess: bool = True) -> str:
        """Process an image and extract text in the specified format.

        Args:
            image_path: Path to the image file.
            format_type: One of ["markdown", "text", "json", "structured", "key_value"].
                Unknown values fall back to "text".
            preprocess: Whether to apply image preprocessing.

        Returns:
            The extracted text, or an ``"Error processing image: ..."`` string
            on failure — this method never raises.
        """
        try:
            temp_output = None
            if preprocess:
                image_path = self._preprocess_image(image_path)
                temp_output = image_path

            try:
                image_base64 = self._encode_image(image_path)
            finally:
                # Bug fix: remove the temporary preprocessed file even if
                # encoding fails (the original cleanup leaked it on error).
                if temp_output is not None and os.path.exists(temp_output):
                    os.remove(temp_output)

            # Get the appropriate prompt (fall back to plain text).
            prompt = self._PROMPTS.get(format_type, self._PROMPTS["text"])

            # Prepare the request payload
            payload = {
                "model": self.model_name,
                "prompt": prompt,
                "stream": False,
                "images": [image_base64]
            }

            # Make the API call to Ollama. The timeout guards against an
            # unresponsive server hanging a (possibly batched) caller forever.
            response = requests.post(self.base_url, json=payload, timeout=300)
            response.raise_for_status()  # Raise an exception for bad status codes

            result = response.json().get("response", "")

            if format_type == "json":
                try:
                    # Re-serialize so valid JSON comes back consistently indented.
                    return json.dumps(json.loads(result), indent=2)
                except json.JSONDecodeError:
                    # Model output was not valid JSON; return it untouched.
                    return result

            return result
        except Exception as e:
            # Deliberate best-effort contract: callers receive an error string
            # instead of an exception so a batch run never aborts.
            return f"Error processing image: {str(e)}"

    def process_batch(
        self,
        input_path: Union[str, List[str]],
        format_type: str = "markdown",
        recursive: bool = False,
        preprocess: bool = True
    ) -> Dict[str, Any]:
        """Process multiple images in parallel.

        Args:
            input_path: Path to a directory, a single file, or a list of paths.
            format_type: Output format passed to process_image().
            recursive: Whether to search directories recursively.
            preprocess: Whether to apply image preprocessing.

        Returns:
            Dictionary with "results" (path -> extracted text), "errors"
            (path -> error message), and "statistics" counters.
        """
        # Collect all image paths
        image_paths: List[Path] = []
        if isinstance(input_path, str):
            base_path = Path(input_path)
            if base_path.is_dir():
                pattern = '**/*' if recursive else '*'
                for ext in ['.png', '.jpg', '.jpeg', '.pdf', '.tiff']:
                    image_paths.extend(base_path.glob(f'{pattern}{ext}'))
            else:
                image_paths = [base_path]
        else:
            image_paths = [Path(p) for p in input_path]

        results: Dict[str, str] = {}
        errors: Dict[str, str] = {}

        # Fan the work out over a thread pool (requests are I/O-bound),
        # updating a progress bar as futures complete.
        with tqdm(total=len(image_paths), desc="Processing images") as pbar:
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                future_to_path = {
                    executor.submit(self.process_image, str(path), format_type, preprocess): path
                    for path in image_paths
                }

                for future in concurrent.futures.as_completed(future_to_path):
                    path = future_to_path[future]
                    try:
                        results[str(path)] = future.result()
                    except Exception as e:
                        # process_image() swallows its own errors, so this only
                        # catches unexpected executor-level failures.
                        errors[str(path)] = str(e)
                    pbar.update(1)

        return {
            "results": results,
            "errors": errors,
            "statistics": {
                "total": len(image_paths),
                "successful": len(results),
                "failed": len(errors)
            }
        }
|