Spaces:
Sleeping
Sleeping
init
Browse files- app-ocr.py +82 -0
- app.py +60 -4
- requirements.txt +12 -1
app-ocr.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import os
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
import spaces
|
6 |
+
import torch
|
7 |
+
from transformers import AutoModel, AutoTokenizer
|
8 |
+
|
9 |
+
# Hugging Face Hub identifier of the OCR checkpoint loaded below.
model_name = "ucaslcl/GOT-OCR2_0"

# Use the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The checkpoint is loaded with trust_remote_code=True, i.e. code shipped
# with the checkpoint is executed at load time.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map=device,
)
# Switch to inference mode, then make sure the weights sit on `device`.
model = model.eval()
model = model.to(device)
|
15 |
+
|
16 |
+
|
17 |
+
@spaces.GPU()
def ocr_process(image, got_mode, ocr_color="", ocr_box="", progress=gr.Progress()):
    """Run OCR on an image and return text or an HTML snippet for display.

    Args:
        image: Path of the image file to process, or ``None`` when no image
            was provided.
        got_mode: Mode label. The substrings ``"plain"`` / ``"format"``
            select plain or rendered output, and ``"multi-crop"`` selects
            ``model.chat_crop`` instead of ``model.chat``.
        ocr_color: Forwarded to ``model.chat`` as ``ocr_color``
            (ignored by the multi-crop path).
        ocr_box: Forwarded to ``model.chat`` as ``ocr_box``
            (ignored by the multi-crop path).
        progress: Progress callback, invoked as ``progress(fraction, desc=...)``.

    Returns:
        In plain mode, whatever the model call returns. In format mode, a
        string holding a download link and an ``<iframe>`` preview, both
        pointing at a base64 ``data:`` URI of the rendered HTML. On any
        failure, a string starting with "错误".
    """
    if image is None:
        return "错误:未提供图片"

    try:
        image_path = image
        use_multi_crop = "multi-crop" in got_mode

        progress(0, desc="开始处理...")

        if "plain" in got_mode:
            progress(0.3, desc="执行OCR识别...")
            if use_multi_crop:
                res = model.chat_crop(tokenizer, image_path, ocr_type="ocr")
            else:
                res = model.chat(tokenizer, image_path, ocr_type="ocr", ocr_box=ocr_box, ocr_color=ocr_color)
            progress(1, desc="处理完成")
            return res

        if "format" in got_mode:
            result_path = f"{os.path.splitext(image_path)[0]}_result.html"
            progress(0.3, desc="执行OCR识别...")
            try:
                if use_multi_crop:
                    model.chat_crop(tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path)
                else:
                    model.chat(tokenizer, image_path, ocr_type="format", ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)

                progress(0.7, desc="生成结果...")
                if not os.path.exists(result_path):
                    # The mode was recognised but no render file appeared;
                    # report that instead of "unknown OCR mode".
                    return "错误: 未生成渲染结果文件"
                with open(result_path, "r", encoding="utf-8") as f:
                    html_content = f.read()
            finally:
                # Best-effort cleanup so render files do not pile up next to
                # the uploads and a stale file is never picked up later.
                try:
                    os.remove(result_path)
                except OSError:
                    pass

            encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
            data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
            preview = f'<iframe src="{data_uri}" width="100%" height="600px"></iframe>'
            download_link = f'<a href="{data_uri}" download="result.html">下载完整结果</a>'
            progress(1, desc="处理完成")
            return f"{download_link}\n\n{preview}"

        return "错误: 未知的OCR模式"
    except Exception as e:
        # UI boundary: surface any failure as text instead of crashing the app.
        return f"错误: {str(e)}"
|
57 |
+
|
58 |
+
|
59 |
+
# Mode labels offered in the dropdown; ocr_process dispatches on the
# substrings "plain" / "format" / "multi-crop" contained in them.
mode_choices = [
    "plain texts OCR",
    "format texts OCR",
    "plain multi-crop OCR",
    "format multi-crop OCR",
    "plain fine-grained OCR",
    "format fine-grained OCR",
]

with gr.Blocks() as demo:
    gr.Markdown("# OCR 图像识别")

    with gr.Row():
        image_input = gr.Image(type="filepath", label="上传图片")

    # NOTE(review): the flattened source does not show whether this dropdown
    # was nested inside the row above; it is placed at Blocks level here.
    got_mode = gr.Dropdown(choices=mode_choices, label="OCR模式", value="plain texts OCR")

    with gr.Row():
        ocr_color = gr.Textbox(label="OCR颜色 (仅用于fine-grained模式)")
        ocr_box = gr.Textbox(label="OCR边界框 (仅用于fine-grained模式)")

    submit_button = gr.Button("开始OCR识别")

    output = gr.HTML(label="识别结果")

    # Wire the button to the OCR handler; the inputs follow its positional order.
    submit_button.click(
        fn=ocr_process,
        inputs=[image_input, got_mode, ocr_color, ocr_box],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()
|
app.py
CHANGED
@@ -1,8 +1,64 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
app = FastAPI()
|
4 |
|
5 |
-
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import os
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from fastapi import FastAPI, File, Form, UploadFile
|
6 |
+
from transformers import AutoModel, AutoTokenizer
|
7 |
|
8 |
app = FastAPI()
|
9 |
|
10 |
+
# Initialise the model.
model_name = "ucaslcl/GOT-OCR2_0"

# Use the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The checkpoint is loaded with trust_remote_code=True, i.e. code shipped
# with the checkpoint is executed at load time.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map=device,
)
# Switch to inference mode, then make sure the weights sit on `device`.
model = model.eval()
model = model.to(device)
|
17 |
+
|
18 |
+
|
19 |
+
# OCR processing function
|
20 |
+
async def ocr_process(image_path, got_mode, ocr_color="", ocr_box=""):
    """Run OCR on the image at ``image_path`` according to ``got_mode``.

    Args:
        image_path: Path of the image file to process.
        got_mode: Mode label. The substrings ``"plain"`` / ``"format"``
            select plain or rendered output, and ``"multi-crop"`` selects
            ``model.chat_crop`` instead of ``model.chat``.
        ocr_color: Forwarded to ``model.chat`` as ``ocr_color``
            (ignored by the multi-crop path).
        ocr_box: Forwarded to ``model.chat`` as ``ocr_box``
            (ignored by the multi-crop path).

    Returns:
        In plain mode, whatever the model call returns. In format mode,
        ``{"html_content": <base64 of the rendered HTML>}``. On any failure,
        ``{"error": <message>}``.
    """
    try:
        use_multi_crop = "multi-crop" in got_mode

        if "plain" in got_mode:
            if use_multi_crop:
                res = model.chat_crop(tokenizer, image_path, ocr_type="ocr")
            else:
                res = model.chat(tokenizer, image_path, ocr_type="ocr", ocr_box=ocr_box, ocr_color=ocr_color)
            return res

        if "format" in got_mode:
            result_path = f"{os.path.splitext(image_path)[0]}_result.html"
            try:
                if use_multi_crop:
                    model.chat_crop(tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path)
                else:
                    model.chat(tokenizer, image_path, ocr_type="format", ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)

                if os.path.exists(result_path):
                    with open(result_path, "r", encoding="utf-8") as f:
                        html_content = f.read()
                    encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
                    return {"html_content": encoded_html}
            finally:
                # Best-effort cleanup so render files do not accumulate on
                # the server and a stale file is never picked up later.
                try:
                    os.remove(result_path)
                except OSError:
                    pass

            # The mode was recognised but no render file appeared; report
            # that instead of "unknown OCR mode".
            return {"error": "未生成渲染结果文件"}

        return {"error": "未知的OCR模式"}
    except Exception as e:
        # API boundary: report any failure in the response body.
        return {"error": str(e)}
|
44 |
|
45 |
+
|
46 |
+
@app.post("/ocr")
async def ocr_api(image: UploadFile = File(...), got_mode: str = Form(...), ocr_color: str = Form(""), ocr_box: str = Form("")):
    """Handle ``POST /ocr``: store the upload, run OCR on it, return the result.

    Args:
        image: Uploaded image file.
        got_mode: OCR mode label, passed through to ``ocr_process``.
        ocr_color: Optional value passed through to ``ocr_process``.
        ocr_box: Optional value passed through to ``ocr_process``.

    Returns:
        Whatever ``ocr_process`` returns for the stored image.
    """
    # Function-scope import so the module's import block stays untouched.
    import tempfile

    # Only the extension of the client-supplied name is reused. The name
    # itself is untrusted (it may hold path separators or be missing) and
    # must not decide where the file is written.
    suffix = os.path.splitext(os.path.basename(image.filename or ""))[1]

    image_path = None
    try:
        # Save the uploaded image under a unique temporary name.
        with tempfile.NamedTemporaryFile("wb", prefix="temp_", suffix=suffix, delete=False) as buffer:
            image_path = buffer.name
            buffer.write(await image.read())

        # Run OCR on the stored file.
        return await ocr_process(image_path, got_mode, ocr_color, ocr_box)
    finally:
        # Remove the temporary file, even when writing or processing failed.
        if image_path is not None and os.path.exists(image_path):
            os.remove(image_path)
|
60 |
+
|
61 |
+
|
62 |
+
@app.get("/")
def read_root():
    """Return the greeting payload served at the API root."""
    greeting = "欢迎使用OCR API"
    return {"message": greeting}
|
requirements.txt
CHANGED
@@ -1,3 +1,14 @@
|
|
1 |
fastapi
|
2 |
uvicorn[standard]
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
fastapi
|
2 |
uvicorn[standard]
|
3 |
+
PyMuPDF
|
4 |
+
verovio
|
5 |
+
gradio
|
6 |
+
numpy
|
7 |
+
modelscope
|
8 |
+
Pillow
|
9 |
+
tiktoken
|
10 |
+
transformers
|
11 |
+
torch
|
12 |
+
torchvision
|
13 |
+
accelerate
|
14 |
+
spaces
|