# """
# Author: Amir Hossein Kargaran
# Date: August, 2023

# Description: This code applies LIME (Local Interpretable Model-Agnostic Explanations) on fasttext language identification.

# MIT License

# Some parts of the code are adapted from: https://gist.github.com/ageitgey/60a8b556a9047a4ca91d6034376e5980
# """

import gradio as gr
from io import BytesIO
from fasttext.FastText import _FastText
import re
import lime.lime_text
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from selenium import webdriver
from selenium.common.exceptions import WebDriverException
import os

# Fetch the fastText language-identification weights from the Hugging Face Hub;
# hf_hub_download returns the local path of the downloaded file.
model_path = hf_hub_download(
    repo_id="facebook/fasttext-language-identification",
    filename="model.bin",
)

# Build the fastText classifier from the downloaded model file.
classifier = _FastText(model_path)

def remove_label_prefix(item):
    """Return *item* with every occurrence of the ``__label__`` marker removed.

    Args:
        item: A label string such as ``"__label__eng_Latn"``.

    Returns:
        The same string without the marker (``"eng_Latn"`` for the example).
    """
    marker = '__label__'
    return item.replace(marker, '')

def remove_label_prefix_list(input_list):
    """Strip the ``__label__`` marker from a flat or nested sequence of labels.

    The shape is decided by looking at the first element: if it is itself a
    list or tuple, the input is treated as a sequence of label sequences
    (one per text); otherwise it is treated as a flat sequence of labels.

    Args:
        input_list: A sequence of label strings, or a sequence of sequences
            of label strings. May be empty.

    Returns:
        A list (or list of lists) of labels with the marker removed. An empty
        input yields an empty list.
    """
    # An empty input has no first element to inspect; without this guard the
    # shape check below would raise IndexError. len() is used rather than
    # truthiness so array-like inputs keep working.
    if len(input_list) == 0:
        return []
    if isinstance(input_list[0], (list, tuple)):
        return [[remove_label_prefix(item) for item in inner_list] for inner_list in input_list]
    return [remove_label_prefix(item) for item in input_list]

# Label names known to the classifier, with the fastText marker stripped and
# sorted so that column i of the prediction matrix corresponds to class_names[i].
class_names = np.sort(remove_label_prefix_list(classifier.labels))
# Number of classes; also used as k when asking the classifier for predictions.
num_class = len(class_names)

def tokenize_string(string):
    """Split *string* into tokens on runs of whitespace.

    Leading and trailing whitespace produce no empty tokens.
    """
    tokens = string.split(None)
    return tokens

# LIME text explainer shared by all requests. Tokens come from
# tokenize_string (whitespace split). NOTE(review): per the LIME docs,
# bow=False explains in terms of word positions rather than a bag of
# words - confirm against the installed LIME version.
explainer = lime.lime_text.LimeTextExplainer(
    class_names=class_names,
    split_expression=tokenize_string,
    bow=False,
)

def fasttext_prediction_in_sklearn_format(classifier, texts):
    """Return class probabilities for *texts* as a 2-D array.

    The classifier is asked for ``num_class`` predictions per text. Each row
    of probabilities is then reordered by the alphabetical order of its
    (marker-stripped) labels, so columns line up with the sorted
    ``class_names``.

    Args:
        classifier: Object exposing ``predict(texts, k)`` returning
            ``(labels, probabilities)`` with one entry per text.
        texts: The texts to classify.

    Returns:
        ``np.ndarray`` built from one reordered probability row per text.
    """
    raw_labels, raw_probs = classifier.predict(texts, num_class)
    clean_labels = remove_label_prefix_list(raw_labels)
    rows = [
        # argsort of the label names gives the permutation that puts this
        # row's probabilities into sorted-label order.
        row_probs[np.argsort(np.array(row_labels))]
        for row_labels, row_probs, _ in zip(clean_labels, raw_probs, texts)
    ]
    return np.array(rows)

def generate_explanation_html(input_sentence):
    """Explain the classifier's prediction for a sentence and save it as HTML.

    Runs the module-level LIME explainer on *input_sentence* (top 2 labels,
    up to 20 features) and writes the result to ``explanation.html`` in the
    current working directory.

    Args:
        input_sentence: The text to explain.

    Returns:
        The filename of the written HTML report.
    """
    def predict_proba(texts):
        # Adapter: LIME calls the classifier function with texts only.
        return fasttext_prediction_in_sklearn_format(classifier, texts)

    explanation = explainer.explain_instance(
        input_sentence,
        classifier_fn=predict_proba,
        top_labels=2,
        num_features=20,
    )
    output_html_filename = "explanation.html"
    explanation.save_to_file(output_html_filename)
    return output_html_filename

def take_screenshot(local_html_path):
    """Render a local HTML file in headless Chrome and return a screenshot.

    Args:
        local_html_path: Path (relative or absolute) of the HTML file to render.

    Returns:
        A PIL image of the page rendered in a 1366x728 browser window, or a
        1x1 RGB placeholder image if the browser could not be started or
        raised a ``WebDriverException`` while rendering.
    """
    options = webdriver.ChromeOptions()
    options.add_argument('--headless')
    options.add_argument('--no-sandbox')
    options.add_argument('--disable-dev-shm-usage')

    # Bind the name before the try block: if webdriver.Chrome() itself raises,
    # the cleanup in `finally` must not hit an unbound variable (which would
    # replace the placeholder return with an UnboundLocalError).
    wd = None
    try:
        local_html_path = os.path.abspath(local_html_path)
        wd = webdriver.Chrome(options=options)
        wd.set_window_size(1366, 728)
        wd.get('file://' + local_html_path)
        # NOTE(review): an implicit wait only affects element lookups, so it
        # likely does not delay the screenshot - kept as in the original.
        wd.implicitly_wait(10)
        screenshot = wd.get_screenshot_as_png()
    except WebDriverException:
        # Best-effort: fall back to a tiny placeholder instead of failing.
        return Image.new('RGB', (1, 1))
    finally:
        # Always shut the browser down once it was actually started.
        if wd is not None:
            wd.quit()

    return Image.open(BytesIO(screenshot))

def merge(input_sentence):
    """Produce both outputs of the app for one input sentence.

    Newlines in the input are replaced by spaces before explanation
    (presumably because the classifier expects single-line input - verify).

    Args:
        input_sentence: Raw text entered by the user.

    Returns:
        A tuple ``(image, html_filename)``: the screenshot of the explanation
        page and the filename of the HTML report it was rendered from.
    """
    single_line = input_sentence.replace('\n', ' ')
    html_path = generate_explanation_html(single_line)
    screenshot = take_screenshot(html_path)
    return screenshot, html_path

# Input/output components. The top-level component classes are used instead of
# the deprecated gr.inputs / gr.outputs namespaces, matching the gr.Image
# component already used below.
input_sentence = gr.Textbox(label="Input Sentence")

output_explanation = gr.File(label="Explanation HTML")

# Single-function interface: text in, explanation screenshot + HTML file out.
iface = gr.Interface(
    fn=merge,
    inputs=input_sentence,
    outputs=[gr.Image(type="pil", height=364, width=683, label="Explanation Image"), output_explanation],
    title="LIME LID",
    description="This code applies LIME (Local Interpretable Model-Agnostic Explanations) on fasttext language identification.",
    allow_flagging='never'
)

iface.launch()