kosmos2_5 / ocr.py
update readme
b175204
raw
history blame
2.3 kB
import re
import torch
import requests
from PIL import Image, ImageDraw
from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration
# Checkpoint to load. "microsoft/kosmos-2.5" is an alternative repo id that
# the original script assigned and then immediately overwrote.
repo = "kirp/kosmos2_5"
device = "cuda:0"
dtype = torch.bfloat16

model = Kosmos2_5ForConditionalGeneration.from_pretrained(
    repo, device_map=device, torch_dtype=dtype
)
processor = AutoProcessor.from_pretrained(repo)

# Fetch the sample image. A timeout keeps the script from hanging forever on
# a stalled connection, and raise_for_status() surfaces HTTP errors directly
# instead of letting PIL fail later on a non-image error body.
url = "https://huggingface.co/kirp/kosmos2_5/resolve/main/receipt_00008.png"
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

# "<ocr>" is the task prompt; post_process() also recognises a "<md>" prompt.
prompt = "<ocr>"
inputs = processor(text=prompt, images=image, return_tensors="pt")
# batch input
# inputs = processor(text=[prompt, prompt], images=[image,image], return_tensors="pt")

# The processor output carries "height" and "width" entries that are removed
# here so they are not forwarded to generate().
# NOTE(review): presumably these are the resized image dimensions that the
# predicted box coordinates refer to -- confirm against the processor docs.
height, width = inputs.pop("height"), inputs.pop("width")
raw_width, raw_height = image.size
# Factors used later to map predicted box coordinates back to the original
# image size.
scale_height = raw_height / height
scale_width = raw_width / width

# Move every tensor to the target device, then cast the patch tensor to the
# dtype the model was loaded with.
inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)

generated_ids = model.generate(
    **inputs,
    max_new_tokens=1024,
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
def post_process(y, scale_height, scale_width, *, task_prompt=None):
    """Convert raw generated text into "x0,y0,x1,y0,x1,y1,x0,y1,text" records.

    Args:
        y: Decoded model output. Every occurrence of the task prompt is
            stripped from it before parsing.
        scale_height: Multiplier applied to the y coordinates of each box.
        scale_width: Multiplier applied to the x coordinates of each box.
        task_prompt: Prompt the text was generated with. When omitted, the
            module-level ``prompt`` is used, which matches the original
            behaviour; passing it explicitly makes the function usable
            without that global.

    Returns:
        For a prompt containing ``"<md>"``, the text with the prompt removed.
        Otherwise a single string that concatenates one record per valid box:
        the four corners in the order (x0,y0), (x1,y0), (x1,y1), (x0,y1) as
        scaled integers, followed by the text that came after that box tag.
        The text is appended verbatim, so any line break between records
        comes from the generated text itself. Boxes with ``x0 >= x1`` or
        ``y0 >= y1`` are dropped together with their text.
    """
    if task_prompt is None:
        # Backward-compatible default: fall back to the script's global prompt.
        task_prompt = prompt

    y = y.replace(task_prompt, "")
    if "<md>" in task_prompt:
        return y

    pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
    box_tags = re.findall(pattern, y)
    # split() yields one more piece than there are tags; the first piece is
    # whatever precedes the first box and is discarded, so texts[i] is the
    # text following box_tags[i].
    texts = re.split(pattern, y)[1:]

    records = []
    for tag, text in zip(box_tags, texts):
        x0, y0, x1, y1 = (int(number) for number in re.findall(r"\d+", tag))
        if x0 >= x1 or y0 >= y1:
            # Degenerate (empty or inverted) box: skip it and its text.
            continue
        x0 = int(x0 * scale_width)
        y0 = int(y0 * scale_height)
        x1 = int(x1 * scale_width)
        y1 = int(y1 * scale_height)
        records.append(f"{x0},{y0},{x1},{y0},{x1},{y1},{x0},{y1},{text}")
    return "".join(records)
# Turn the first generated sequence into "x0,y0,x1,y0,x1,y1,x0,y1,text" records
# and show them.
output_text = post_process(generated_text[0], scale_height, scale_width)
print(output_text)

# Outline every record's quadrilateral on the source image.
draw = ImageDraw.Draw(image)
lines = output_text.split("\n")
for line in lines:
    line = line.split(",")
    # A record needs at least the eight corner coordinates; anything shorter
    # (e.g. a blank line) is ignored.
    if len(line) >= 8:
        line = [int(value) for value in line[:8]]
        draw.polygon(line, outline="red")

image.save("output.png")