"""Gradio demo for UniMERNet formula recognition (tiny / small / base models).

On import this script pins ``transformers``, downloads the three UniMERNet
checkpoints from the Hugging Face Hub and writes their local paths into the
``cfg_*.yaml`` files in the current working directory. Run as a script, it
loads all three models and serves the web UI on 0.0.0.0:7860.
"""
import os
import subprocess
import sys

# Pin transformers before anything (e.g. the unimernet package) imports it.
# pip is invoked through the running interpreter so the package lands in *this*
# environment; failures are tolerated, as with the previous os.system call.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "-U", "transformers==4.44.2"],
    check=False,
)

import shutil
import torch
import argparse
import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download


def _fill_model_dir(cfg_path, model_dir):
    """Replace every ``MODEL_DIR`` placeholder in *cfg_path* with *model_dir*.

    Done in Python rather than with ``sed``: *model_dir* is an absolute path
    containing ``/``, which broke a ``s/MODEL_DIR/.../g`` expression, and it
    was interpolated into a shell command unescaped. The file is rewritten only
    when the placeholder is present, so repeated runs leave it unchanged.
    """
    with open(cfg_path, "r", encoding="utf-8") as f:
        text = f.read()
    if "MODEL_DIR" in text:
        with open(cfg_path, "w", encoding="utf-8") as f:
            f.write(text.replace("MODEL_DIR", model_dir))


# == download weights and point the configs at them ==
tiny_model_dir = snapshot_download('wanderkid/unimernet_tiny')
small_model_dir = snapshot_download('wanderkid/unimernet_small')
base_model_dir = snapshot_download('wanderkid/unimernet_base')
_fill_model_dir("cfg_tiny.yaml", tiny_model_dir)
_fill_model_dir("cfg_small.yaml", small_model_dir)
_fill_model_dir("cfg_base.yaml", base_model_dir)

# Make the parent of the working directory importable (presumably the repo
# root containing the ``unimernet`` package — confirm against deployment).
sys.path.insert(0, os.path.join(os.getcwd(), ".."))
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor

# Filled in under ``__main__``: model type ("tiny"/"small"/"base") ->
# (model, vis_processor) built from that model's own config.
MODELS = {}


def load_model_and_processor(cfg_path):
    """Build a UniMERNet model and its evaluation image processor.

    Args:
        cfg_path: Path to a ``cfg_*.yaml`` configuration file.

    Returns:
        ``(model, vis_processor)`` where ``vis_processor`` is the
        ``formula_image_eval`` processor configured from
        ``datasets.formula_rec_eval.vis_processor.eval`` in the same config.
    """
    args = argparse.Namespace(cfg_path=cfg_path, options=None)
    cfg = Config(args)
    task = tasks.setup_task(cfg)
    model = task.build_model(cfg)
    # Inference only: disable dropout & co. (a no-op if the builder already
    # returned the model in eval mode).
    model.eval()
    vis_processor = load_processor(
        'formula_image_eval',
        cfg.config.datasets.formula_rec_eval.vis_processor.eval,
    )
    return model, vis_processor


def recognize_image(input_img, model_type):
    """Recognize the formula in *input_img* and return its LaTeX source.

    Args:
        input_img: Image array from the ``gr.Image`` component, or ``None``
            when nothing has been uploaded.
        model_type: ``"tiny"``, ``"small"`` or ``"base"``; any other value
            falls back to the tiny model.

    Returns:
        The first predicted LaTeX string (``output["pred_str"][0]``).

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
    """
    if input_img is None:
        raise gr.Error("Please upload an image first.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model_type not in MODELS:
        model_type = "tiny"
    model, processor = MODELS[model_type]
    # ``.to`` is a no-op once the model already lives on ``device``.
    model = model.to(device)

    if len(input_img.shape) == 3:
        # NOTE(review): reverses the channel axis of colour inputs (Gradio
        # supplies RGB) — presumably to match the preprocessing the models
        # were trained with; confirm before changing.
        input_img = input_img[:, :, ::-1].copy()
    img = Image.fromarray(input_img)
    image = processor(img).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model.generate({"image": image})
    latex_code = output["pred_str"][0]
    return latex_code


def gradio_reset():
    """Clear both the input image and the predicted-LaTeX textbox."""
    return gr.update(value=None), gr.update(value=None)


if __name__ == "__main__":
    root_path = os.path.abspath(os.getcwd())

    # == load models ==
    # Each model keeps the processor built from its own config, instead of all
    # three sharing whichever processor happened to be loaded last.
    for model_name in ("tiny", "small", "base"):
        MODELS[model_name] = load_model_and_processor(
            os.path.join(root_path, f"cfg_{model_name}.yaml")
        )
    print("== load all models ==")

    with open("header.html", "r", encoding="utf-8") as file:
        header = file.read()

    example_root = os.path.join(os.path.dirname(__file__), "examples")
    # Sorted so the examples gallery has a stable order across runs.
    example_files = sorted(
        os.path.join(example_root, fname)
        for fname in os.listdir(example_root)
        if fname.endswith("png")
    )

    with gr.Blocks() as demo:
        gr.HTML(header)

        with gr.Row():
            with gr.Column():
                model_type = gr.Radio(
                    choices=["tiny", "small", "base"],
                    value="tiny",
                    label="Model Type",
                    interactive=True,
                )
                input_img = gr.Image(label=" ", interactive=True)
                with gr.Row():
                    clear = gr.Button("Clear")
                    predict = gr.Button(value="Recognize", interactive=True, variant="primary")

                with gr.Accordion("Examples:"):
                    gr.Examples(
                        examples=example_files,
                        inputs=input_img,
                    )

            with gr.Column():
                # Disabled button used purely as a heading for the output box.
                gr.Button(value="Predict Latex:", interactive=False)
                pred_latex = gr.Textbox(label='Latex', interactive=False)

        clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex])
        predict.click(recognize_image, inputs=[input_img, model_type], outputs=[pred_latex])

    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)