import gradio as gr
import os
import shutil
from omegaconf import OmegaConf, ListConfig
import spaces


from train import main as train_main
from inference import inference as inference_main

import transformers
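# Migrate any old-format Hugging Face cache to the current layout before loading models.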
transformers.utils.move_cache()


@spaces.GPU()
def inference_app(
        embedding_dir,
        prompt, 
        video_round,
        save_dir,
        motion_type,
        seed,
        inference_steps):
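    """Run one inference pass on ZeroGPU and return the path of the generated video."""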
    
    print('inference info:')
    print('ref video:',embedding_dir)
    print('prompt:',prompt)
    print('motion type:',motion_type)
    print('infer steps:',inference_steps)

    return inference_main(
        embedding_dir=embedding_dir,
        prompt=prompt, 
        video_round=video_round,
        save_dir=save_dir,
        motion_type=motion_type,
        seed=seed,
        inference_steps=inference_steps
        )


def train_model(video, config):
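    """Train motion embeddings on the uploaded video and return the checkpoint directory."""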
    output_dir = 'results'
    cur_save_dir = os.path.join(output_dir, 'custom')
    os.makedirs(cur_save_dir, exist_ok=True)

    config.dataset.single_video_path = video
    config.train.output_dir = cur_save_dir

    # Copy the reference video into the save directory so it can be previewed later.
    video_name = 'source.mp4'
    video_path = os.path.join(cur_save_dir, video_name)
    shutil.copy(video, video_path)

    train_main(config)
    return cur_save_dir


def inference_model(text, checkpoint, inference_steps, video_type, seed):
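    """Generate a video from `text` using the motion embedding stored in the selected checkpoint."""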
    
    checkpoint = os.path.join('results', checkpoint)

    embedding_dir = os.path.dirname(checkpoint)
    video_round = os.path.basename(checkpoint)

    video_path = inference_app(
        embedding_dir=embedding_dir,
        prompt=text, 
        video_round=video_round,
        save_dir=os.path.join('outputs', os.path.basename(embedding_dir)),
        motion_type=video_type,
        seed=seed,
        inference_steps=inference_steps
        )

    return video_path


def get_checkpoints(checkpoint_dir):
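    """Return checkpoints under `checkpoint_dir` as relative paths such as 'pan_up/checkpoint'."""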
    
    checkpoints = []
    for root, dirs, files in os.walk(checkpoint_dir):
        for file in files:
            if file == 'motion_embed.pt':
                checkpoints.append('/'.join(root.split('/')[-2:]))
    return checkpoints


def extract_combinations(motion_embeddings_combinations):
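    """Parse entries like 'down 1280' into [name, resolution] pairs for the config."""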
    assert len(motion_embeddings_combinations) > 0, "At least one motion embedding combination is required"
    combinations = []
    for combination in motion_embeddings_combinations:
        name, resolution = combination.split(" ")
        combinations.append([name, int(resolution)])
    return combinations


def generate_config_train(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
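    """Load configs/config.yaml and override the motion-embedding combinations, UNet, and training schedule."""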

    default_config = OmegaConf.load('configs/config.yaml')

    default_config.model.motion_embeddings.combinations = ListConfig(extract_combinations(motion_embeddings_combinations))
    default_config.model.unet = unet
    default_config.train.checkpointing_steps = checkpointing_steps
    default_config.train.max_train_steps = max_train_steps

    return default_config


def generate_config_inference(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
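    """Build the config used for inference; currently applies the same overrides as generate_config_train."""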

    default_config = OmegaConf.load('configs/config.yaml')

    default_config.model.motion_embeddings.combinations = ListConfig(extract_combinations(motion_embeddings_combinations))
    default_config.model.unet = unet
    default_config.train.checkpointing_steps = checkpointing_steps
    default_config.train.max_train_steps = max_train_steps

    return default_config


def update_preview_video(checkpoint_dir):
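    """Show the source video stored alongside the selected checkpoint."""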
    # The source video lives in the checkpoint's parent directory, e.g. results/pan_up/source.mp4.
    parent_dir = os.path.dirname(checkpoint_dir)
    return gr.update(value=os.path.join('results', parent_dir, 'source.mp4'))


def update_generated_prompt(text):
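    """Echo the prompt that produced the current output video."""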
    return gr.update(value=text)


if __name__ == "__main__":

    # Clear results and outputs left over from previous runs.
    if os.path.exists('results/custom'):
        shutil.rmtree('results/custom')
    if os.path.exists('outputs'):
        shutil.rmtree('outputs')

    inject_motion_embeddings_combinations = ['down 1280','up 1280','down 640','up 640']
    default_motion_embeddings_combinations = ['down 1280','up 1280']


    examples_inference = [
        ['results/pan_up/source.mp4', 'A flora garden.', 'camera', 'pan_up/checkpoint'],
        ['results/dolly_zoom/source.mp4','A firefighter standing in front of a burning forest captured with a dolly zoom.','camera','dolly_zoom/checkpoint'],
        ['results/orbit_shot/source.mp4','A micro garden with an orbit shot','camera','orbit_shot/checkpoint'],

        ['results/walk/source.mp4', 'An elephant walking in the desert', 'object', 'walk/checkpoint'],
        ['results/santa_dance/source.mp4','A skeleton in a suit is dancing with his hands','object','santa_dance/checkpoint'],
        ['results/car_turn/source.mp4','A toy train chugs around a roundabout tree','object','car_turn/checkpoint'],
        ['results/train_ride/source.mp4','A motorbike driving in a forest','object','train_ride/checkpoint'], 
    ]

    gradio_theme = gr.themes.Default()
    with gr.Blocks(
        theme=gradio_theme,
        title="Motion Inversion",
        css="""
            #download {
                height: 118px;
            }
            .slider .inner {
                width: 5px;
                background: #FFF;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
            .tabs button.selected {
                font-size: 20px !important;
                color: crimson !important;
            }
            h1 {
                text-align: center;
                display: block;
            }
            h2 {
                text-align: center;
                display: block;
            }
            h3 {
                text-align: center;
                display: block;
            }
            .md_feedback li {
                margin-bottom: 0px !important;
            }
        """,
        head="""
            <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
            <script>
                window.dataLayer = window.dataLayer || [];
                function gtag() {dataLayer.push(arguments);}
                gtag('js', new Date());
                gtag('config', 'G-1FWSVCGZTG');
            </script>
        """,
    ) as demo:
        
        gr.Markdown(
            """
# Motion Inversion for Video Customization
<p align="center">
<a href="https://arxiv.org/abs/2403.20193"><img src='https://img.shields.io/badge/arXiv-2403.20193-b31b1b.svg'></a>
<a href=''><img src='https://img.shields.io/badge/Project_Page-MotionInversion(Coming soon)-blue'></a>
<a href='https://github.com/EnVision-Research/MotionInversion'><img src='https://img.shields.io/github/stars/EnVision-Research/MotionInversion?label=GitHub%20%E2%98%85&logo=github&color=C8C'></a>
<br>
<strong>Please consider starring <span style="color: orange">&#9733;</span> the <a href="https://github.com/EnVision-Research/MotionInversion" target="_blank" rel="noopener noreferrer">GitHub Repo</a> if you find this useful!</strong>
</p>
        """
        )
        with gr.Tabs(elem_classes=["tabs"]):
            with gr.Row():
                with gr.Column():
                    preview_video = gr.Video(label="Preview Video")
                    text_input = gr.Textbox(label="Input Text")
                    checkpoint_dropdown = gr.Dropdown(label="Select Checkpoint", choices=get_checkpoints('results'))
                    seed = gr.Number(label="Seed", value=0)
                    inference_button = gr.Button("Generate Video")
                
                with gr.Column():
                    
                    output_video = gr.Video(label="Output Video")
                    generated_prompt = gr.Textbox(label="Generated Prompt")

                    with gr.Accordion('Encountering Errors?', open=False):
                        gr.Markdown('''
                                    <strong>Inference for one video generally takes 45~50s on ZeroGPU</strong>.

                                    <br>
                                    <strong>You have exceeded your GPU quota</strong>: a limit set by HF. Retry in an hour.
                                    <br>
                                    <strong>GPU task aborted</strong>: likely caused by too many people using ZeroGPU at once, so the inference exceeds the time limit. Try again later, or clone the repo and run it locally.
                                    <br>

                                    If any other issue occurs, please feel free to contact us through the community or by email (ziyangmai06@gmail.com). We will try our best to help you :)

                                    ''')


            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    inference_steps = gr.Number(label="Inference Steps", value=30)
                    motion_type = gr.Dropdown(label="Motion Type", choices=["camera", "object"], value="object")

        gr.Examples(examples=examples_inference, inputs=[preview_video, text_input, motion_type, checkpoint_dropdown])

        checkpoint_dropdown.change(fn=update_preview_video, inputs=checkpoint_dropdown, outputs=preview_video)
        inference_button.click(inference_model, inputs=[text_input, checkpoint_dropdown, inference_steps, motion_type, seed], outputs=output_video)
        output_video.change(fn=update_generated_prompt, inputs=[text_input], outputs=generated_prompt)
        
        demo.queue(
            api_open=False,
        ).launch(
            server_name="0.0.0.0",
            server_port=7860,
        )