Mageia commited on
Commit
a5dcb7e
1 Parent(s): 118959f
Files changed (3) hide show
  1. app-ocr.py +82 -0
  2. app.py +60 -4
  3. requirements.txt +12 -1
app-ocr.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import os
3
+
4
+ import gradio as gr
5
+ import spaces
6
+ import torch
7
+ from transformers import AutoModel, AutoTokenizer
8
+
9
+ model_name = "ucaslcl/GOT-OCR2_0"
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+
12
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
13
+ model = AutoModel.from_pretrained(model_name, trust_remote_code=True, device_map=device)
14
+ model = model.eval().to(device)
15
+
16
+
17
+ @spaces.GPU()
18
+ def ocr_process(image, got_mode, ocr_color="", ocr_box="", progress=gr.Progress()):
19
+ if image is None:
20
+ return "错误:未提供图片"
21
+
22
+ try:
23
+ image_path = image
24
+ result_path = f"{os.path.splitext(image_path)[0]}_result.html"
25
+
26
+ progress(0, desc="开始处理...")
27
+
28
+ if "plain" in got_mode:
29
+ progress(0.3, desc="执行OCR识别...")
30
+ if "multi-crop" in got_mode:
31
+ res = model.chat_crop(tokenizer, image_path, ocr_type="ocr")
32
+ else:
33
+ res = model.chat(tokenizer, image_path, ocr_type="ocr", ocr_box=ocr_box, ocr_color=ocr_color)
34
+ progress(1, desc="处理完成")
35
+ return res
36
+ elif "format" in got_mode:
37
+ progress(0.3, desc="执行OCR识别...")
38
+ if "multi-crop" in got_mode:
39
+ res = model.chat_crop(tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path)
40
+ else:
41
+ res = model.chat(tokenizer, image_path, ocr_type="format", ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)
42
+
43
+ progress(0.7, desc="生成结果...")
44
+ if os.path.exists(result_path):
45
+ with open(result_path, "r", encoding="utf-8") as f:
46
+ html_content = f.read()
47
+ encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
48
+ data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
49
+ preview = f'<iframe src="{data_uri}" width="100%" height="600px"></iframe>'
50
+ download_link = f'<a href="{data_uri}" download="result.html">下载完整结果</a>'
51
+ progress(1, desc="处理完成")
52
+ return f"{download_link}\n\n{preview}"
53
+
54
+ return "错误: 未知的OCR模式"
55
+ except Exception as e:
56
+ return f"错误: {str(e)}"
57
+
58
+
59
+ with gr.Blocks() as demo:
60
+ gr.Markdown("# OCR 图像识别")
61
+
62
+ with gr.Row():
63
+ image_input = gr.Image(type="filepath", label="上传图片")
64
+
65
+ got_mode = gr.Dropdown(
66
+ choices=["plain texts OCR", "format texts OCR", "plain multi-crop OCR", "format multi-crop OCR", "plain fine-grained OCR", "format fine-grained OCR"],
67
+ label="OCR模式",
68
+ value="plain texts OCR",
69
+ )
70
+
71
+ with gr.Row():
72
+ ocr_color = gr.Textbox(label="OCR颜色 (仅用于fine-grained模式)")
73
+ ocr_box = gr.Textbox(label="OCR边界框 (仅用于fine-grained模式)")
74
+
75
+ submit_button = gr.Button("开始OCR识别")
76
+
77
+ output = gr.HTML(label="识别结果")
78
+
79
+ submit_button.click(ocr_process, inputs=[image_input, got_mode, ocr_color, ocr_box], outputs=output)
80
+
81
+ if __name__ == "__main__":
82
+ demo.launch()
app.py CHANGED
@@ -1,8 +1,64 @@
1
- from fastapi import FastAPI
 
 
 
 
 
2
 
3
  app = FastAPI()
4
 
5
- @app.get("/")
6
- def greet_json():
7
- return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import os
3
+
4
+ import torch
5
+ from fastapi import FastAPI, File, Form, UploadFile
6
+ from transformers import AutoModel, AutoTokenizer
7
 
8
  app = FastAPI()
9
 
10
+ # 初始化模型
11
+ model_name = "ucaslcl/GOT-OCR2_0"
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+
14
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
15
+ model = AutoModel.from_pretrained(model_name, trust_remote_code=True, device_map=device)
16
+ model = model.eval().to(device)
17
+
18
+
19
+ # OCR处理函数
20
+ async def ocr_process(image_path, got_mode, ocr_color="", ocr_box=""):
21
+ try:
22
+ if "plain" in got_mode:
23
+ if "multi-crop" in got_mode:
24
+ res = model.chat_crop(tokenizer, image_path, ocr_type="ocr")
25
+ else:
26
+ res = model.chat(tokenizer, image_path, ocr_type="ocr", ocr_box=ocr_box, ocr_color=ocr_color)
27
+ return res
28
+ elif "format" in got_mode:
29
+ result_path = f"{os.path.splitext(image_path)[0]}_result.html"
30
+ if "multi-crop" in got_mode:
31
+ res = model.chat_crop(tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path)
32
+ else:
33
+ res = model.chat(tokenizer, image_path, ocr_type="format", ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)
34
+
35
+ if os.path.exists(result_path):
36
+ with open(result_path, "r", encoding="utf-8") as f:
37
+ html_content = f.read()
38
+ encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
39
+ return {"html_content": encoded_html}
40
+
41
+ return {"error": "未知的OCR模式"}
42
+ except Exception as e:
43
+ return {"error": str(e)}
44
 
45
+
46
+ @app.post("/ocr")
47
+ async def ocr_api(image: UploadFile = File(...), got_mode: str = Form(...), ocr_color: str = Form(""), ocr_box: str = Form("")):
48
+ # 保存上传的图片
49
+ image_path = f"temp_{image.filename}"
50
+ with open(image_path, "wb") as buffer:
51
+ buffer.write(await image.read())
52
+
53
+ # 处理OCR
54
+ result = await ocr_process(image_path, got_mode, ocr_color, ocr_box)
55
+
56
+ # 删除临时文件
57
+ os.remove(image_path)
58
+
59
+ return result
60
+
61
+
62
+ @app.get("/")
63
+ def read_root():
64
+ return {"message": "欢迎使用OCR API"}
requirements.txt CHANGED
@@ -1,3 +1,14 @@
1
  fastapi
2
  uvicorn[standard]
3
-
 
 
 
 
 
 
 
 
 
 
 
 
1
  fastapi
2
  uvicorn[standard]
3
+ PyMuPDF
4
+ verovio
5
+ gradio
6
+ numpy
7
+ modelscope
8
+ Pillow
9
+ tiktoken
10
+ transformers
11
+ torch
12
+ torchvision
13
+ accelerate
14
+ spaces