GAS17 committed on
Commit
d53445d
verified
1 Parent(s): 0516ee1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -29
app.py CHANGED
@@ -1,38 +1,92 @@
1
- import gradio as gr
2
- from doctr.io import DocumentFile
3
- from doctr.models import ocr_predictor
4
 
5
- # Cargar el modelo preentrenado
6
- model = ocr_predictor(pretrained=True)
7
 
8
- def process_file(file):
9
- """Procesa un archivo (PDF o imagen) con docTR y retorna el texto extraído."""
10
- if file is None:
11
- return "Por favor, sube un archivo."
12
-
13
- # Leer el archivo subido
14
- doc = DocumentFile.from_pdf(file.name) if file.name.endswith('.pdf') else DocumentFile.from_images(file.name)
15
 
16
- # Realizar OCR
17
- result = model(doc)
18
 
19
- # Extraer el texto y retornarlo
20
- extracted_text = "\n".join([block['text'] for page in result.pages for block in page['blocks']])
21
- return extracted_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- # Configuración de la interfaz de Gradio
24
- with gr.Blocks() as demo:
25
- gr.Markdown("## OCR con docTR")
26
- gr.Markdown("Sube un archivo PDF o una imagen para extraer texto utilizando un modelo preentrenado de docTR.")
27
 
28
- with gr.Row():
29
- input_file = gr.File(label="Subir archivo (PDF o imagen)")
30
- output_text = gr.Textbox(label="Texto extraído", lines=10)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- process_button = gr.Button("Procesar archivo")
33
 
34
- process_button.click(fn=process_file, inputs=[input_file], outputs=[output_text])
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- # Ejecutar la app
37
- if __name__ == "__main__":
38
- demo.launch()
 
1
+ # Copyright (C) 2021-2024, Mindee.
 
 
2
 
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
 
6
+ import numpy as np
7
+ import torch
 
 
 
 
 
8
 
9
+ from doctr.models import ocr_predictor
10
+ from doctr.models.predictor import OCRPredictor
11
 
12
# Names of the text-detection architectures listed by this module
# (presumably the values accepted as `det_arch` by `load_predictor` -- confirm
# against the caller that builds the UI).
DET_ARCHS = [
    "fast_base",
    "fast_small",
    "fast_tiny",
    "db_resnet50",
    "db_resnet34",
    "db_mobilenet_v3_large",
    "linknet_resnet18",
    "linknet_resnet34",
    "linknet_resnet50",
]
# Names of the text-recognition architectures listed by this module
# (presumably the values accepted as `reco_arch` by `load_predictor` -- confirm
# against the caller that builds the UI).
RECO_ARCHS = [
    "crnn_vgg16_bn",
    "crnn_mobilenet_v3_small",
    "crnn_mobilenet_v3_large",
    "master",
    "sar_resnet31",
    "vitstr_small",
    "vitstr_base",
    "parseq",
]
33
 
 
 
 
 
34
 
35
def load_predictor(
    det_arch: str,
    reco_arch: str,
    assume_straight_pages: bool,
    straighten_pages: bool,
    export_as_straight_boxes: bool,
    disable_page_orientation: bool,
    disable_crop_orientation: bool,
    bin_thresh: float,
    box_thresh: float,
    device: torch.device,
) -> OCRPredictor:
    """Load a pretrained OCR predictor from doctr.models and configure it.

    The predictor is built with ``pretrained=True``, moved to ``device``, and
    the two thresholds are then written onto the detection model's
    postprocessor.

    Args:
        det_arch: detection architecture
        reco_arch: recognition architecture
        assume_straight_pages: whether to assume straight pages or not
        straighten_pages: whether to straighten rotated pages or not
        export_as_straight_boxes: whether to export boxes as straight or not
        disable_page_orientation: whether to disable page orientation or not
        disable_crop_orientation: whether to disable crop orientation or not
        bin_thresh: binarization threshold for the segmentation map
        box_thresh: minimal objectness score to consider a box
        device: torch.device, the device to load the predictor on

    Returns:
        instance of OCRPredictor
    """
    predictor = ocr_predictor(
        det_arch,
        reco_arch,
        pretrained=True,
        assume_straight_pages=assume_straight_pages,
        straighten_pages=straighten_pages,
        export_as_straight_boxes=export_as_straight_boxes,
        # Orientation detection is requested exactly when pages are NOT
        # assumed to be straight.
        detect_orientation=not assume_straight_pages,
        disable_page_orientation=disable_page_orientation,
        disable_crop_orientation=disable_crop_orientation,
    ).to(device)
    # The thresholds are not constructor arguments here; they are overridden
    # on the detection model's postprocessor after construction.
    predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh
    predictor.det_predictor.model.postprocessor.box_thresh = box_thresh
    return predictor
76
 
 
77
 
78
def forward_image(predictor: OCRPredictor, image: np.ndarray, device: torch.device) -> np.ndarray:
    """Forward an image through the predictor's detection model.

    Only the detection stage is run: the image is passed through the
    detection pre-processor, the first resulting batch is moved to ``device``
    and fed to the detection model, and the model's ``"out_map"`` output is
    returned.

    Args:
        predictor: instance of OCRPredictor
        image: image to process
        device: torch.device, the device to process the image on

    Returns:
        segmentation map: the ``"out_map"`` entry of the detection model
        output, moved to CPU and converted to a NumPy array
    """
    # Inference only: gradient tracking is disabled for the whole forward pass.
    with torch.no_grad():
        # The pre-processor takes a list of images; a single image is wrapped
        # in a one-element list and only the first batch is used.
        processed_batches = predictor.det_predictor.pre_processor([image])
        out = predictor.det_predictor.model(processed_batches[0].to(device), return_model_output=True)
        seg_map = out["out_map"].to("cpu").numpy()

    return seg_map