import gradio as gr
import pandas as pd
from PIL import Image
from rdkit import RDLogger
from molecule_generation_helpers import *
from property_prediction_helpers import *

# Raise RDKit's log threshold to ERROR so lower-severity RDKit messages are not printed.
RDLogger.logger().setLevel(RDLogger.ERROR)

# Predefined dataset paths (adjust these to your local file paths).
# Each value is a single comma-separated string; this file only reads its
# first field (the train CSV path, in load_predefined_dataset). The remaining
# fields are presumably the test CSV path, the input (SMILES) column and the
# target column — confirm against the helper that consumes them.
# The " " entry provides a blank choice in the dataset dropdown and has no files.
predefined_datasets = {
    " ": " ",
    "BACE": "./data/bace/train.csv, ./data/bace/test.csv, smiles, Class",
    "ESOL": "./data/esol/train.csv, ./data/esol/test.csv, smiles, prop",
}

# Foundation models offered in the "Select Model" checkbox group.
models_enabled = ["SELFIES-TED", "MHG-GED", "MolFormer", "SMI-TED"]

# Fusion strategies offered in the "Fusion Type" radio group.
fusion_available = ["Concat"]


# Function to load a predefined dataset from the local path
def load_predefined_dataset(dataset_name):
    """Preview a predefined dataset and offer its columns for selection.

    Args:
        dataset_name: Value of the dataset dropdown — a key of
            ``predefined_datasets``, "Custom Dataset", or None.

    Returns:
        A 4-tuple ``(preview, input_update, output_update, message)``: the
        first rows of the train CSV, two dropdown updates listing its
        columns, and the lower-cased dataset name. Selections without a train
        file (the blank " " entry, "Custom Dataset", None) yield an empty
        DataFrame, empty choices and "Dataset not found".
    """
    spec = predefined_datasets.get(dataset_name) or ""
    # The train CSV path is the first comma-separated field of the spec.
    train_path = spec.split(",")[0].strip()
    if not train_path:
        # The blank " " placeholder is truthy but names no file; without this
        # guard it would reach pd.read_csv(" ") and raise.
        return (
            pd.DataFrame(),
            gr.update(choices=[]),
            gr.update(choices=[]),
            "Dataset not found",
        )

    df = pd.read_csv(train_path)
    return (
        df.head(),
        gr.update(choices=list(df.columns)),
        gr.update(choices=list(df.columns)),
        dataset_name.lower(),
    )


# Function to handle dataset selection (predefined or custom)
def handle_dataset_selection(selected_dataset):
    """Toggle visibility of the dataset-settings widgets for a selection.

    The updates are positional and line up with the outputs wired to
    ``dataset_selector.change``: dataset name, train file, train preview,
    test file, test preview, predefined preview, input column selector,
    output column selector.

    The dataset-name box is always shown and the predefined preview is always
    hidden; every other widget is shown only for "Custom Dataset".
    """
    show_custom = selected_dataset == "Custom Dataset"
    # NOTE(review): the predefined preview is hidden for every selection, so
    # the head written into it by load_predefined_dataset is never visible —
    # confirm this is intended.
    visibility = (
        True,  # dataset name
        show_custom,  # train file upload
        show_custom,  # train preview
        show_custom,  # test file upload
        show_custom,  # test preview
        False,  # predefined preview
        show_custom,  # input column selector
        show_custom,  # output column selector
    )
    return tuple(gr.update(visible=flag) for flag in visibility)


# Dynamically show relevant hyperparameters based on selected model
def update_hyperparameters(model_name):
    """Return visibility updates for the downstream-model hyperparameter widgets.

    The updates are positional, in the order
    ``(max_depth, n_estimators, alpha, degree, kernel)``.

    Args:
        model_name: Value of the "Select Downstream Model" dropdown.

    Returns:
        A 5-tuple of ``gr.update(visible=...)``. "Linear Regression",
        "Default - Auto" and any unrecognized value (e.g. None when the
        dropdown is cleared) hide all five widgets.
    """
    # Flags per model: (max_depth, n_estimators, alpha, degree, kernel)
    visibility_by_model = {
        "XGBClassifier": (True, True, True, False, False),
        "SVR": (False, False, False, True, True),
        "Kernel Ridge": (False, False, True, True, True),
        "Linear Regression": (False, False, False, False, False),
        "Default - Auto": (False, False, False, False, False),
    }
    # Unknown values fall back to "hide everything" rather than implicitly
    # returning None for an event that expects five outputs.
    visibility = visibility_by_model.get(model_name, (False,) * 5)
    return tuple(gr.update(visible=flag) for flag in visibility)


# Function to select input and output columns and display a message
def select_columns(input_column, output_column, train_data, test_data, dataset_name):
    """Build the dataset spec string passed on to the training step.

    Args:
        input_column: Selected input column name.
        output_column: Selected output (target) column name.
        train_data: Uploaded train file (exposes its path as ``.name``), or
            None if nothing has been uploaded.
        test_data: Uploaded test file, same convention as ``train_data``.
        dataset_name: Name entered for the custom dataset.

    Returns:
        ``"<train path>,<test path>,<input>,<output>,<name>"`` when both
        columns and both files are present; a prompt string when a column is
        missing; a no-op ``gr.update()`` when either file is missing.
    """
    if not (input_column and output_column):
        return "Please select both input and output columns."
    if train_data is None or test_data is None:
        # No path to report (``None.name`` would raise AttributeError).
        # Leave the target textbox unchanged: for predefined datasets it holds
        # the value written by load_predefined_dataset.
        return gr.update()
    return f"{train_data.name},{test_data.name},{input_column},{output_column},{dataset_name}"


# Function to set Dataset Name
def set_dataname(dataset_name, dataset_selector):
    """Return the effective dataset name for the current dropdown selection.

    For "Custom Dataset" the user-typed ``dataset_name`` is kept; for any
    other selection the dropdown value itself is used as the name.
    """
    if dataset_selector != "Custom Dataset":
        return dataset_selector
    return dataset_name


# Function to display the head of the uploaded CSV file
def display_csv_head(file):
    """Preview an uploaded CSV and offer its columns to both column selectors.

    Args:
        file: Uploaded file exposing its path as ``.name``, or None when the
            upload is cleared.

    Returns:
        ``(preview, input_update, output_update)``: the CSV's ``head()`` and
        two dropdown updates listing its columns. Returns an empty DataFrame
        and empty choices when ``file`` is None.
    """
    if file is None:
        return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[])
    frame = pd.read_csv(file.name)
    preview = frame.head()
    return (
        preview,
        gr.update(choices=list(frame.columns)),
        gr.update(choices=list(frame.columns)),
    )


# Sample molecules offered in the Molecule Generation tab. Each entry maps a
# display label to its SMILES string and the path of an image file for it
# (paths are relative, so they resolve against the working directory).
smiles_image_mapping = {
    "Mol 1": {
        "smiles": "C=C(C)CC(=O)NC[C@H](CO)NC(=O)C=Cc1ccc(C)c(Cl)c1",
        "image": "img/img1.png",
    },
    "Mol 2": {
        "smiles": "C=CC1(CC(=O)NC[C@@H](CCCC)NC(=O)c2cc(Cl)cc(Br)c2)CC1",
        "image": "img/img2.png",
    },
    "Mol 3": {
        "smiles": "C=C(C)C[C@H](NC(C)=O)C(=O)N1CC[C@H](NC(=O)[C@H]2C[C@@]2(C)Br)C(C)(C)C1",
        "image": "img/img3.png",
    },
    "Mol 4": {
        "smiles": "C=C1CC(CC(=O)N[C@H]2CCN(C(=O)c3ncccc3SC)C23CC3)C1",
        "image": "img/img4.png",
    },
    "Mol 5": {
        "smiles": "C=CCS[C@@H](C)CC(=O)OCC",
        "image": "img/img5.png",
    },
}


# Load images for selection
def load_image(path):
    """Open the image file registered for a sample molecule.

    Args:
        path: A key of ``smiles_image_mapping`` (e.g. "Mol 1"), or None when
            the selection is cleared.

    Returns:
        An in-memory PIL ``Image`` whose source file has been closed, or
        None when the key is unknown/None or the file cannot be read.
    """
    entry = smiles_image_mapping.get(path)
    if entry is None:
        return None
    try:
        with Image.open(entry["image"]) as img:
            # Read pixel data now so the file handle can be closed on exit.
            img.load()
            return img.copy()
    except OSError:
        # Missing or unreadable image (PIL's UnidentifiedImageError is an
        # OSError): show no image instead of failing the event.
        return None


# Function to handle image selection
def handle_image_selection(image_key):
    """Return ``(smiles, rendered_image)`` for the chosen sample molecule.

    ``image_key`` is a key of ``smiles_image_mapping``. A falsy key (cleared
    selection) yields ``(None, None)``.
    """
    if image_key:
        chosen_smiles = smiles_image_mapping[image_key]["smiles"]
        return chosen_smiles, smiles_to_image(chosen_smiles)
    return None, None


# Introduction
# Read the introduction text first so the file is closed before the UI is
# built. UTF-8 is explicit so non-ASCII Markdown does not depend on the
# platform's default encoding.
with open("INTRODUCTION.md", encoding="utf-8") as f:
    introduction_text = f.read()

with gr.Blocks() as introduction:
    gr.Markdown(introduction_text)
    # NOTE(review): the section below is a debug check of local-file vs.
    # remote image rendering via HTML and Markdown — confirm it should ship.
    gr.Markdown("---\n# Debug")
    gr.HTML("HTML text: <img src='file/img/selfies-ted.png'>")
    gr.Markdown("Markdown text: ![selfies-ted](file/img/selfies-ted.png)")
    gr.HTML("HTML text: <img src='https://huggingface.co/front/assets/huggingface_logo-noborder.svg'>")
    gr.Markdown("Markdown text: ![Huggingface Logo](https://huggingface.co/front/assets/huggingface_logo-noborder.svg)")

# Property Prediction
# Property Prediction tab:
# - choose a predefined or custom (uploaded) dataset;
# - choose models, task type, downstream model, hyperparameters and fusion type;
# - then train, plot and log results.
# The callbacks create_downstream_model, display_eval, display_plot and
# evaluate_and_log are not defined in this file. Presumably they come from the
# star-imported helper modules.
with gr.Blocks() as property_prediction:
    # Empty results log; it seeds both the session state and the log table.
    log_df = pd.DataFrame(
        {"": [], 'Selected Models': [], 'Dataset': [], 'Task': [], 'Result': []}
    )
    # Session state passed to display_eval, display_plot and evaluate_and_log;
    # the log DataFrame is stored under the "log_df" key.
    state = gr.State({"log_df": log_df})
    gr.HTML(
        '''
    <p style="text-align: center">
        Task : Property Prediction
        <br>
        Models are finetuned with different combination of modalities on the uploaded or selected built data set.
    </p>
    '''
    )
    with gr.Row():
        # Left column: dataset, model and hyperparameter inputs.
        with gr.Column():
            # Dropdown menu for predefined datasets including "Custom Dataset" option
            dataset_selector = gr.Dropdown(
                label="Select Dataset",
                choices=list(predefined_datasets.keys()) + ["Custom Dataset"],
            )
            # Hidden textbox carrying the dataset spec string that display_eval
            # receives; written by load_predefined_dataset and select_columns.
            selected_columns_message = gr.Textbox(
                label="Selected Columns Info", visible=False
            )

            with gr.Accordion("Dataset Settings", open=True):
                # File upload options for custom dataset (train and test).
                # Every widget in this accordion starts hidden;
                # handle_dataset_selection toggles their visibility.
                dataset_name = gr.Textbox(label="Dataset Name", visible=False)
                train_file = gr.File(
                    label="Upload Custom Train Dataset",
                    file_types=[".csv"],
                    visible=False,
                )
                train_display = gr.Dataframe(
                    label="Train Dataset Preview (First 5 Rows)",
                    visible=False,
                    interactive=False,
                )

                test_file = gr.File(
                    label="Upload Custom Test Dataset",
                    file_types=[".csv"],
                    visible=False,
                )
                test_display = gr.Dataframe(
                    label="Test Dataset Preview (First 5 Rows)",
                    visible=False,
                    interactive=False,
                )

                # Predefined dataset displays
                # NOTE(review): handle_dataset_selection hides this in both of
                # its branches, so it never becomes visible — confirm intent.
                predefined_display = gr.Dataframe(
                    label="Predefined Dataset Preview (First 5 Rows)",
                    visible=False,
                    interactive=False,
                )

                # Dropdowns for selecting input and output columns for the custom dataset
                input_column_selector = gr.Dropdown(
                    label="Select Input Column", choices=[], visible=False
                )
                output_column_selector = gr.Dropdown(
                    label="Select Output Column", choices=[], visible=False
                )

                # Three handlers are attached to dataset_selector.change in this
                # accordion: visibility, predefined preview and dataset name.
                # All of them run on every selection.
                # When a dataset is selected, show either file upload fields (for custom) or load predefined datasets
                dataset_selector.change(
                    handle_dataset_selection,
                    inputs=dataset_selector,
                    outputs=[
                        dataset_name,
                        train_file,
                        train_display,
                        test_file,
                        test_display,
                        predefined_display,
                        input_column_selector,
                        output_column_selector,
                    ],
                )

                # When a predefined dataset is selected, load its head and update column selectors
                dataset_selector.change(
                    load_predefined_dataset,
                    inputs=dataset_selector,
                    outputs=[
                        predefined_display,
                        input_column_selector,
                        output_column_selector,
                        selected_columns_message,
                    ],
                )

                # When a custom train file is uploaded, display its head and update column selectors
                train_file.change(
                    display_csv_head,
                    inputs=train_file,
                    outputs=[
                        train_display,
                        input_column_selector,
                        output_column_selector,
                    ],
                )

                # When a custom test file is uploaded, display its head
                # NOTE(review): this also replaces the column choices with the
                # test file's columns; whichever file changed last wins.
                test_file.change(
                    display_csv_head,
                    inputs=test_file,
                    outputs=[
                        test_display,
                        input_column_selector,
                        output_column_selector,
                    ],
                )

                # Use the dropdown value as the dataset name for predefined
                # datasets; for "Custom Dataset" the typed name is kept.
                dataset_selector.change(
                    set_dataname,
                    inputs=[dataset_name, dataset_selector],
                    outputs=dataset_name,
                )

                # Update the selected columns information when dropdown values are changed
                input_column_selector.change(
                    select_columns,
                    inputs=[
                        input_column_selector,
                        output_column_selector,
                        train_file,
                        test_file,
                        dataset_name,
                    ],
                    outputs=selected_columns_message,
                )

                output_column_selector.change(
                    select_columns,
                    inputs=[
                        input_column_selector,
                        output_column_selector,
                        train_file,
                        test_file,
                        dataset_name,
                    ],
                    outputs=selected_columns_message,
                )

            model_checkbox = gr.CheckboxGroup(
                choices=models_enabled, label="Select Model"
            )

            task_radiobutton = gr.Radio(
                choices=["Classification", "Regression"], label="Task Type"
            )

            ####### adding hyper parameter tuning ###########
            model_name = gr.Dropdown(
                [
                    "Default - Auto",
                    "XGBClassifier",
                    "SVR",
                    "Kernel Ridge",
                    "Linear Regression",
                ],
                label="Select Downstream Model",
            )
            with gr.Accordion("Downstream Hyperparameter Settings", open=True):
                # Create placeholders for hyperparameter components.
                # They start hidden; update_hyperparameters shows the ones
                # relevant to the chosen downstream model.
                max_depth = gr.Slider(1, 20, step=1, visible=False, label="max_depth")
                n_estimators = gr.Slider(
                    100, 5000, step=100, visible=False, label="n_estimators"
                )
                alpha = gr.Slider(0.1, 10.0, step=0.1, visible=False, label="alpha")
                degree = gr.Slider(1, 20, step=1, visible=False, label="degree")
                kernel = gr.Dropdown(
                    choices=["rbf", "poly", "linear"], visible=False, label="kernel"
                )

                # Output textbox: shows what create_downstream_model returns;
                # its text is also an input to display_eval.
                output = gr.Textbox(label="Loaded Parameters")

            # When model is selected, update which hyperparameters are visible
            model_name.change(
                update_hyperparameters,
                inputs=[model_name],
                outputs=[max_depth, n_estimators, alpha, degree, kernel],
            )

            # Submit button to create the model with selected hyperparameters
            submit_button = gr.Button("Create Downstream Model")

            # When the submit button is clicked, run create_downstream_model
            submit_button.click(
                create_downstream_model,
                inputs=[model_name, max_depth, n_estimators, alpha, degree, kernel],
                outputs=output,
            )
            ###### End of hyper param tuning #########

            fusion_radiobutton = gr.Radio(choices=fusion_available, label="Fusion Type")

            eval_button = gr.Button("Train downstream model")

        # Right Column
        with gr.Column():
            eval_output = gr.Textbox(label="Train downstream model")

            plot_radio = gr.Radio(
                choices=["ROC-AUC", "Parity Plot", "Latent Space"],
                label="Select Plot Type",
            )
            plot_output = gr.Plot(label="Visualization")

            create_log = gr.Button("Store log")

            log_table = gr.Dataframe(
                value=log_df, label="Log of Selections and Results", interactive=False
            )

            # Train with the current selections:
            # - selected_columns_message supplies the dataset spec string;
            # - output supplies the create_downstream_model result.
            eval_button.click(
                display_eval,
                inputs=[
                    model_checkbox,
                    selected_columns_message,
                    task_radiobutton,
                    output,
                    fusion_radiobutton,
                    state,
                ],
                outputs=eval_output,
            )

            # Render the chosen plot type; display_plot gets the choice and the session state.
            plot_radio.change(
                display_plot, inputs=[plot_radio, state], outputs=plot_output
            )

            # Presumably appends a row for the current run to the log and
            # returns the updated table — confirm in evaluate_and_log.
            create_log.click(
                evaluate_and_log,
                inputs=[
                    model_checkbox,
                    dataset_name,
                    task_radiobutton,
                    eval_output,
                    state,
                ],
                outputs=log_table,
            )


# Molecule Generation
# Molecule Generation tab:
# - choose a sample molecule or type a SMILES string and preview it;
# - generate_canonical then fills the property comparison table, the
#   generated-output text and the generated-molecule image.
with gr.Blocks() as molecule_generation:
    gr.HTML(
        '''
    <p style="text-align: center">
        Task : Molecule Generation
        <br>
        Generate a new molecule similar to the initial molecule with better drug-likeness and synthetic accessibility.
    </p>
    '''
    )
    with gr.Row():
        # Left column: input SMILES, its preview, sample picker and buttons.
        with gr.Column():
            smiles_input = gr.Textbox(label="Input SMILES String")
            image_display = gr.Image(label="Molecule Image", height=250, width=250)
            # Show images for selection
            with gr.Accordion("Select from sample molecules", open=False):
                image_selector = gr.Radio(
                    choices=list(smiles_image_mapping.keys()),
                    label="Select from sample molecules",
                    value=None,
                )
                # NOTE(review): image_display has several writers:
                #   - this load_image handler;
                #   - handle_image_selection (same event, registered below);
                #   - smiles_to_image on smiles_input.change.
                # The image finally shown depends on which handler finishes last.
                image_selector.change(load_image, image_selector, image_display)
            clear_button = gr.Button("Clear")
            generate_button = gr.Button("Submit", variant="primary")

        # Right Column
        with gr.Column():
            gen_image_display = gr.Image(
                label="Generated Molecule Image", height=250, width=250
            )
            generated_output = gr.Textbox(label="Generated Output")
            property_table = gr.Dataframe(label="Molecular Properties Comparison")

            # Handle image selection: fill the SMILES box and its rendered image
            image_selector.change(
                handle_image_selection,
                inputs=image_selector,
                outputs=[smiles_input, image_display],
            )
            # Re-render the preview whenever the SMILES text changes.
            smiles_input.change(
                smiles_to_image, inputs=smiles_input, outputs=image_display
            )

            # Generate button to display canonical SMILES and molecule image
            generate_button.click(
                generate_canonical,
                inputs=smiles_input,
                outputs=[property_table, generated_output, gen_image_display],
            )
        # Reset the input, its preview, the sample selection and all outputs.
        clear_button.click(
            lambda: (None, None, None, None, None, None),
            outputs=[
                smiles_input,
                image_display,
                image_selector,
                gen_image_display,
                generated_output,
                property_table,
            ],
        )


# Render with tabs: assemble the three tabs into one app and start the server.
# NOTE(review): server_name="0.0.0.0" listens on all network interfaces.
# allowed_paths=["./"] lets Gradio serve files under the working directory,
# presumably for the file/img/... links in the introduction. Consider
# narrowing it to the directories the UI actually references.
demo = gr.TabbedInterface(
    [introduction, property_prediction, molecule_generation],
    ["Introduction", "Property Prediction", "Molecule Generation"],
)
demo.launch(server_name="0.0.0.0", allowed_paths=["./"])