import contextlib
import os
import time
from functools import wraps
from io import StringIO
from zipfile import ZipFile

import streamlit as st
from PIL import Image

import evaluator
from yolo_dataset import YoloDataset
from yolo_model import YoloModel

# Models offered in the sidebar, one per radio option in main(). The two
# constructor arguments look like a Hugging Face repo id and a weights file
# name inside it — presumably; confirm against YoloModel.
# NOTE(review): `streamlit run` re-executes the main script on every widget
# interaction, so these three models are likely rebuilt on each rerun.
# Consider caching them (e.g. st.cache_resource), but first confirm that
# sharing one YoloModel across concurrent sessions is safe.
fire_and_smoke = YoloModel("SHOU-ISD/fire-and-smoke", "yolov8n.pt")
crack = YoloModel("SHOU-ISD/yolo-cracks", "best.pt")
coco = YoloModel("ultralyticsplus/yolov8s", "yolov8s.pt")


def main():
    """Build the page: configure it, let the user pick a model, render both tabs.

    The sidebar radio selects one of the module-level models; the selected
    model is then handed to the "Evaluate" and "Detect" tabs.
    """
    # Page config is set before any other st.* element, as Streamlit requires.
    st.set_page_config(
        page_title="Detection",
        layout="centered")

    # Single source of truth for the radio labels and the model each selects;
    # the radio can only return one of these keys, so the lookup cannot miss.
    models = {
        "Fire&Smoke": fire_and_smoke,
        "Crack": crack,
        "Coco": coco,
    }
    with st.sidebar:
        model_choice = st.radio("Select Model", list(models))
    model = models[model_choice]

    st.title(f"{model_choice} Detection:")

    detect_tab, evaluate_tab = st.tabs(["Detect", "Evaluate"])

    with evaluate_tab:
        evaluate(model)
    with detect_tab:
        detect(model)


def evaluate(model: YoloModel):
    """Render the "Evaluate" tab: score *model* on an uploaded YOLO dataset.

    The user uploads a ``.zip`` dataset and picks a confidence threshold
    (0-100, converted to 0.0-1.0). The dataset is passed to
    ``evaluator.evaluate`` with its stdout/stderr echoed into the page by
    ``capture_output``. The returned metrics' speed and result dict are shown
    as JSON, followed by every image file found in ``metrics.val.save_dir``.

    Args:
        model: Wrapper whose underlying ``.model`` is handed to the evaluator.
    """
    buffer = st.file_uploader("Upload your Yolo Dataset here", type=["zip"])
    if not buffer:
        return

    with st.spinner('Wait for it...'):
        # Explicit key: detect() renders an otherwise identical slider in the
        # other tab, and both tabs are rendered on every run. Without distinct
        # keys Streamlit raises DuplicateWidgetID once both files are uploaded.
        confidence = st.slider('Confidence Threshold', 0, 100, 30,
                               key="evaluate_confidence")
        yolo_dataset = YoloDataset.from_zip_file(ZipFile(buffer))
        metrics_res = capture_output(evaluator.evaluate)(model=model.model,
                                                         dataset=yolo_dataset,
                                                         confidence_threshold=confidence / 100.0)
        # capture_output has already reported the failure and returned None;
        # entering `with None` would only raise a second, confusing error.
        if metrics_res is None:
            return

        # NOTE(review): the evaluator's result is used as a context manager —
        # presumably it cleans up its output on exit; confirm in evaluator.
        with metrics_res as metrics:
            st.json(metrics.speed)
            st.json(metrics.result_dict)
            save_dir = metrics.val.save_dir
            image_suffixes = {".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp"}
            # Sorted for a stable display order. Entries that are not image
            # files (e.g. subdirectories or JSON/TXT artifacts, if any are
            # written) are skipped because st.image cannot render them.
            for pic in sorted(os.listdir(save_dir)):
                if os.path.splitext(pic)[1].lower() not in image_suffixes:
                    continue
                st.image(os.path.join(save_dir, pic), use_column_width=True)


def detect(model: YoloModel):
    """Render the "Detect" tab: run *model* on one uploaded image and show it.

    The user uploads a PNG/JPG/JPEG image and picks a confidence threshold
    (0-100, converted to 0.0-1.0). The annotated result from
    ``model.preview_detect`` is displayed, followed by the elapsed time for
    opening the image and running detection.

    Args:
        model: Wrapper exposing ``preview_detect(image, confidence)``.
    """
    # Accept PNG, JPG and JPEG uploads only.
    buffer = st.file_uploader("Upload your Image here", type=["jpg", "png", "jpeg"])
    if not buffer:
        return

    with st.spinner('Wait for it...'):
        # Explicit key: evaluate() renders an otherwise identical slider in the
        # other tab, and both tabs are rendered on every run. Without distinct
        # keys Streamlit raises DuplicateWidgetID once both files are uploaded.
        confidence = st.slider('Confidence Threshold', 0, 100, 30,
                               key="detect_confidence")

        # perf_counter is monotonic and high-resolution, so the measured
        # interval is unaffected by wall-clock adjustments during the run.
        t1 = time.perf_counter()
        im = Image.open(buffer)
        res_img = model.preview_detect(im, confidence / 100.0)
        t2 = time.perf_counter()

    # Display the annotated image.
    st.image(res_img, use_column_width=True)

    # Report how long opening + detection took.
    st.write("\n")
    st.write("Time taken: ", t2 - t1, "sec.")


def capture_output(func):
    """Wrap *func* so its stdout/stderr are captured and echoed via Streamlit.

    Calling the returned wrapper runs ``func(*args, **kwargs)`` with
    ``sys.stdout``/``sys.stderr`` redirected into in-memory buffers.

    Returns:
        A wrapper with ``func``'s metadata (via ``functools.wraps``). The wrapper
        returns ``func``'s result. If ``func`` raises an ``Exception``, a failure
        message is written to the page and the wrapper returns ``None``. In
        both cases, any non-empty captured stdout/stderr is then shown as a
        code block.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        out_buf = StringIO()
        err_buf = StringIO()
        try:
            with contextlib.redirect_stdout(out_buf):
                with contextlib.redirect_stderr(err_buf):
                    return func(*args, **kwargs)
        except Exception as exc:
            st.write(f"Failure while executing: {exc}")
        finally:
            # Echo stdout first, then stderr, skipping whichever is empty.
            streams = (("Execution stdout:", out_buf), ("Execution stderr:", err_buf))
            for heading, buf in streams:
                captured = buf.getvalue()
                if not captured:
                    continue
                st.write(heading)
                st.code(captured)

    return wrapper


if __name__ == "__main__":
    # Only launch the app when this file is executed directly (e.g. as the
    # script passed to `streamlit run`), not when it is imported.
    main()