Spaces:
Running
on
Zero
Running
on
Zero
# Markdown shown at the top of the demo page.
# Raw string: the text contains the LaTeX delimiters \( \) \[ \], which are
# invalid escape sequences in a normal string literal (a warning on recent
# Python versions). The resulting runtime value is unchanged.
ABOUT = r"""
# TB-OCR Preview 0.1 Unofficial Demo
This is an unofficial demo of [yifeihu/TB-OCR-preview-0.1](https://huggingface.co/yifeihu/TB-OCR-preview-0.1).
Overview of TB-OCR:
> TB-OCR-preview (Text Block OCR), created by [Yifei Hu](https://x.com/hu_yifei), is an end-to-end OCR model handling text, math latex, and markdown formats all at once. The model takes a block of text as the input and returns clean markdown output. Headers are marked with `##`. Math expressions are guaranteed to be wrapped in brackets `\( inline math \) \[ display math \]` for easier parsing. This model does not require line-detection or math formula detection.
(From the [model card](https://huggingface.co/yifeihu/TB-OCR-preview-0.1))
"""
# check out https://huggingface.co/microsoft/Phi-3.5-vision-instruct for more details | |
import torch, spaces | |
from transformers import AutoModelForCausalLM, AutoProcessor | |
from PIL import Image | |
import requests | |
import os | |
# os.system('pip install -U flash-attn') | |
# Hub id of the checkpoint loaded for both the model and its processor.
model_id = "yifeihu/TB-OCR-preview-0.1"

# Prefer CUDA when torch reports it as available; otherwise run on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if DEVICE.type == "cpu":
    # Tell visitors up front why responses are slow on this deployment.
    ABOUT = ABOUT + (
        "\n\n### ⚠️ This demo is running on CPU ⚠️\n\n"
        "This demo is running on CPU, it will be very slow. "
        "Consider duplicating it or running it locally to skip the queue "
        "and for faster response times."
    )

# Options left disabled, kept here for reference:
#   _attn_implementation='flash_attention_2'
#   load_in_4bit=True  (optional: load the model in 4-bit mode to save memory)
_MODEL_KWARGS = {
    "device_map": DEVICE,
    "trust_remote_code": True,
    "torch_dtype": "auto",
}
model = AutoModelForCausalLM.from_pretrained(model_id, **_MODEL_KWARGS)

_PROCESSOR_KWARGS = {
    "trust_remote_code": True,
    "num_crops": 16,
}
processor = AutoProcessor.from_pretrained(model_id, **_PROCESSOR_KWARGS)
def phi_ocr(image_url):
    """Run the OCR model on one image file and return its markdown output.

    Args:
        image_url: Path of the image file to transcribe. ``None`` or an
            empty value means no image was supplied.

    Returns:
        The decoded model response, cut at the first ``<image_end>`` marker.
        An empty string when no image was supplied.
    """
    if not image_url:
        # Nothing to transcribe (e.g. the button was clicked with an empty
        # image input); avoid failing inside Image.open.
        return ""

    question = "Convert the text to markdown format."
    prompt_message = [{
        'role': 'user',
        'content': f'<|image_1|>\n{question}',
    }]
    prompt = processor.tokenizer.apply_chat_template(
        prompt_message, tokenize=False, add_generation_prompt=True
    )

    # Keep the file open only while the processor consumes the image, so the
    # underlying file handle is always released.
    with Image.open(image_url) as image:
        inputs = processor(prompt, [image], return_tensors="pt").to(DEVICE)

    # Greedy decoding. No temperature is passed: it only applies to sampling
    # and is ignored (with a warning) when do_sample is False.
    generation_args = {
        "max_new_tokens": 1024,
        "do_sample": False,
    }
    generate_ids = model.generate(
        **inputs,
        eos_token_id=processor.tokenizer.eos_token_id,
        **generation_args,
    )

    # Drop the prompt tokens so only the newly generated tokens are decoded.
    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
    response = processor.batch_decode(
        generate_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    # Remove the image_end token and anything after it.
    return response.split("<image_end>")[0]
import gradio as gr

# Page layout: the ABOUT markdown on top, then a row with two columns —
# image input plus OCR button on one side, rendered markdown on the other.
with gr.Blocks() as demo:
    gr.Markdown(ABOUT)
    with gr.Row():
        with gr.Column():
            img = gr.Image(label="Input image", type="filepath")
            btn = gr.Button("OCR")
        with gr.Column():
            out = gr.Markdown()
    # Clicking the button sends the image file path to phi_ocr and renders
    # the returned text in the output markdown component.
    btn.click(phi_ocr, inputs=img, outputs=out)

# Only start the server when executed as a script, so importing this module
# (e.g. to reuse `demo` or `phi_ocr`) does not launch the app.
if __name__ == "__main__":
    demo.queue().launch()