import gradio as gr
import spaces
from transformers import AutoModel, AutoTokenizer
from PIL import Image
import numpy as np

# Load the GOT-OCR-2.0 tokenizer and weights once, at import time. The model
# repository provides its own modelling code, hence trust_remote_code=True.
tokenizer = AutoTokenizer.from_pretrained(
    'ucaslcl/GOT-OCR2_0',
    trust_remote_code=True,
)
model = (
    AutoModel.from_pretrained(
        'ucaslcl/GOT-OCR2_0',
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        device_map='cuda',
        use_safetensors=True,
    )
    .eval()  # inference mode
    .cuda()
)

def _normalize_ocr_box(ocr_box):
    """Validate a fine-grained OCR box string and return it in canonical form.

    Returns "" when no box was given. Otherwise the text must be a list/tuple
    literal of 2 or 4 numbers (the UI asks for ``[x1,y1,x2,y2]``); it is
    re-serialised as a plain list, e.g. ``"[409, 763, 756, 891]"``.

    NOTE(review): the GOT remote modelling code appears to ``eval()`` the box
    string, so untrusted textbox input is parsed safely here with
    ``ast.literal_eval`` instead of being forwarded verbatim - confirm against
    the model's ``chat()`` implementation.

    Raises:
        gr.Error: the text is not a list of 2 or 4 numbers.
    """
    import ast

    text = ocr_box.strip()
    if not text:
        return ""
    try:
        coords = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        coords = None
    is_valid = (
        isinstance(coords, (list, tuple))
        and len(coords) in (2, 4)
        and all(
            isinstance(c, (int, float)) and not isinstance(c, bool)
            for c in coords
        )
    )
    if not is_valid:
        raise gr.Error(f"Invalid box {ocr_box!r}; expected [x1,y1,x2,y2].")
    return str(list(coords))


@spaces.GPU
def run_GOT(image_array, got_mode, ocr_box="", ocr_color=""):
    """Run GOT-OCR-2.0 on one image in the selected mode.

    Args:
        image_array: the uploaded image. The ``gr.Image`` input feeding this
            function uses ``type="filepath"``, so despite the name this is a
            file path, or None when nothing was uploaded.
        got_mode: one of the mode strings offered by the task dropdown.
        ocr_box: box text for fine-grained OCR, e.g. "[0,0,100,100]".
        ocr_color: colour name for fine-grained OCR (e.g. "red").

    Returns:
        ``(text, html)``: the model output and, for "format" modes, the
        contents of the rendered ``./demo.html``; ``html`` is None otherwise.

    Raises:
        gr.Error: no image was uploaded, or the box text is malformed.
        ValueError: ``got_mode`` is not a supported mode.
    """
    print("image_array: ", image_array)
    print(got_mode, ' ', ocr_box, ' ', ocr_color)
    if image_array is None:
        raise gr.Error("Please upload an image first.")
    image = image_array
    # An untouched Textbox/Dropdown may deliver None; the model expects strings.
    ocr_box = ocr_box or ""
    ocr_color = ocr_color or ""

    render_file = './demo.html'
    render_kwargs = {'render': True, 'save_render_file': render_file}

    if got_mode == "plain texts OCR":
        res = model.chat(tokenizer, image, ocr_type='ocr')
    elif got_mode == "format texts OCR":
        res = model.chat(tokenizer, image, ocr_type='format', **render_kwargs)
    elif got_mode == "plain multi-crop OCR":
        res = model.chat_crop(tokenizer, image, ocr_type='ocr')
    elif got_mode == "format multi-crop OCR":
        res = model.chat_crop(tokenizer, image, ocr_type='format', **render_kwargs)
    elif got_mode == "plain fine-grained OCR":
        res = model.chat(
            tokenizer, image, ocr_type='ocr',
            ocr_box=_normalize_ocr_box(ocr_box), ocr_color=ocr_color,
        )
    elif got_mode == "format fine-grained OCR":
        res = model.chat(
            tokenizer, image, ocr_type='format',
            ocr_box=_normalize_ocr_box(ocr_box), ocr_color=ocr_color,
            **render_kwargs,
        )
    else:
        # Without this branch an unknown mode left `res` unbound and failed
        # later with an UnboundLocalError.
        raise ValueError(f"Unsupported GOT mode: {got_mode!r}")

    if "format" in got_mode:
        # "format" modes asked the model to render its output to render_file.
        with open(render_file, 'r') as f:
            demo_html = f.read()
        return res, demo_html
    return res, None

def task_update(task):
    """Update fine-grained control visibility after the OCR mode changes.

    Returns three updates for ``[fine_grained_dropdown, color_dropdown,
    box_input]``. The fine-grained type selector is shown only for modes
    containing "fine-grained"; the colour list and box textbox are always
    hidden here (``fine_grained_update`` reveals one of them once a type is
    chosen).
    """
    show_type_selector = "fine-grained" in task
    return [
        gr.update(visible=show_type_selector),
        gr.update(visible=False),
        gr.update(visible=False),
    ]

def fine_grained_update(task):
    """Show the input that matches the selected fine-grained OCR type.

    Args:
        task: value of the "fine-grained type" dropdown - "box", "color", or
            an empty/None value when nothing is selected.

    Returns:
        Two updates for ``[color_dropdown, box_input]``. The input that is
        being hidden is also cleared so a stale value is not submitted.
    """
    if task == "box":
        return [
            gr.update(visible=False, value=""),
            gr.update(visible=True),
        ]
    if task == 'color':
        return [
            gr.update(visible=True),
            gr.update(visible=False, value=""),
        ]
    # Any other value (e.g. the "" the non-fine-grained example rows load into
    # this dropdown, or a cleared selection) used to fall through and return
    # None, which cannot fill the two outputs this handler is wired to.
    return [
        gr.update(visible=False, value=""),
        gr.update(visible=False, value=""),
    ]


with gr.Blocks() as demo:
    # Header shown above the demo.
    gr.Markdown("""
    # "General OCR Theory: Towards OCR-2.0 via a Unified End-to-end Model"
    
    "🔥🔥🔥This is the official online demo of GOT-OCR-2.0 model!!!"
    
    ### Repo
    - **Hugging Face**: [ucaslcl/GOT-OCR2_0](https://huggingface.co/ucaslcl/GOT-OCR2_0)
    - **GitHub**: [Ucas-HaoranWei/GOT-OCR2.0](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/)
    - **Paper**: [arXiv](https://arxiv.org/abs/2409.01704)
    """)
    
    with gr.Row():
        # Left column: image upload, mode selection and fine-grained options.
        with gr.Column():
            # type="filepath" means run_GOT receives a path string, not pixels.
            image_input = gr.Image(type="filepath", label="upload your image")
            task_dropdown = gr.Dropdown(
                choices=[
                    "plain texts OCR",
                    "format texts OCR",
                    "plain multi-crop OCR",
                    "format multi-crop OCR",
                    "plain fine-grained OCR",
                    "format fine-grained OCR",
                ],
                label="Choose one mode of GOT",
                value="plain texts OCR"
            )
            # The three controls below start hidden; task_update and
            # fine_grained_update toggle them.
            fine_grained_dropdown = gr.Dropdown(
                choices=["box", "color"],
                label="fine-grained type",
                visible=False
            )
            color_dropdown = gr.Dropdown(
                choices=["red", "green", "blue"],
                label="color list",
                visible=False
            )
            box_input = gr.Textbox(
                label="input box: [x1,y1,x2,y2]",
                placeholder="e.g., [0,0,100,100]",
                visible=False
            )
            submit_button = gr.Button("Submit")
        
        # Right column: raw model text and, for "format" modes, rendered HTML.
        with gr.Column():
            ocr_result = gr.Textbox(label="GOT output")
            html_result = gr.HTML(label="rendered html")
    
    # Each row is [image, mode, fine-grained type, colour, box], matching
    # `inputs`. No fn is given, so picking an example only fills the inputs.
    gr.Examples(
        examples=[
            ["assets/coco.jpg", "plain texts OCR", "", "", ""],
            ["assets/en2.png", "plain texts OCR", "", "", ""],
            ["assets/eq.jpg", "format texts OCR", "", "", ""],
            ["assets/table.jpg", "format texts OCR", "", "", ""],
            ["assets/aff2.png", "plain fine-grained OCR", "box", "", "[409,763,756,891]"],
            ["assets/color.png", "plain fine-grained OCR", "color", "red", ""],
        ],
        inputs=[image_input, task_dropdown, fine_grained_dropdown, color_dropdown, box_input],
        outputs=[ocr_result, html_result],
        label="examples",
    )

    # Changing the mode shows/hides the fine-grained type selector.
    task_dropdown.change(
        task_update,
        inputs=[task_dropdown],
        outputs=[fine_grained_dropdown, color_dropdown, box_input]
    )
    # Choosing "box" or "color" reveals the matching input.
    fine_grained_dropdown.change(
        fine_grained_update,
        inputs=[fine_grained_dropdown],
        outputs=[color_dropdown, box_input]
    )
    
    # Input order matches run_GOT(image_array, got_mode, ocr_box, ocr_color).
    submit_button.click(
        run_GOT,
        inputs=[image_input, task_dropdown, box_input, color_dropdown],
        outputs=[ocr_result, html_result]
    )

# share=True asks Gradio for a public share link when run locally.
demo.launch(share=True)