# -*- coding: utf-8 -*-

import json
import logging
import os
import sys
import time

from transformers.trainer_callback import (ExportableState, TrainerCallback,
                                           TrainerControl, TrainerState)
from transformers.training_args import TrainingArguments


class LoggerHandler(logging.Handler):
    r"""
    Logger handler used in Web UI.
    """

    def __init__(self):
        super().__init__()
        self.log = ""

    def reset(self):
        self.log = ""

    def emit(self, record):
        if record.name == "httpx":
            return
        log_entry = self.format(record)
        self.log += log_entry
        self.log += "\n\n"


def get_logger(name: str) -> logging.Logger:
    r"""
    Gets a standard logger with a stream hander to stdout.
    """
    formatter = logging.Formatter(
        fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
    )
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(formatter)

    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.addHandler(handler)

    return logger


def reset_logging() -> None:
    r"""
    Removes basic config of root logger. (unused in script)
    """
    root = logging.getLogger()
    list(map(root.removeHandler, root.handlers))
    list(map(root.removeFilter, root.filters))


logger = get_logger(__name__)

LOG_FILE_NAME = "trainer_log.jsonl"


class LogCallback(TrainerCallback, ExportableState):
    def __init__(self, start_time: float = None, elapsed_time: float = None):

        self.start_time = time.time() if start_time is None else start_time
        self.elapsed_time = 0 if elapsed_time is None else elapsed_time
        self.last_time = self.start_time

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs
    ):
        r"""
        Event called at the beginning of training.
        """
        if state.is_local_process_zero:
            if not args.resume_from_checkpoint:
                self.start_time = time.time()
                self.elapsed_time = 0
            else:
                self.start_time = state.stateful_callbacks['LogCallback']['start_time']
                self.elapsed_time = state.stateful_callbacks['LogCallback']['elapsed_time']

        if args.save_on_each_node:
            if not state.is_local_process_zero:
                return
        else:
            if not state.is_world_process_zero:
                return

        self.last_time = time.time()
        if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
            logger.warning("Previous log file in this folder will be deleted.")
            os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs,
        **kwargs
    ):
        if args.save_on_each_node:
            if not state.is_local_process_zero:
                return
        else:
            if not state.is_world_process_zero:
                return

        self.elapsed_time += time.time() - self.last_time
        self.last_time = time.time()
        if 'num_input_tokens_seen' in logs:
            logs['num_tokens'] = logs.pop('num_input_tokens_seen')
            state.log_history[-1].pop('num_input_tokens_seen')
            throughput = logs['num_tokens'] / args.world_size / self.elapsed_time
            state.log_history[-1]['throughput'] = logs['throughput'] = throughput
        state.stateful_callbacks["LogCallback"] = self.state()

        logs = dict(
            current_steps=state.global_step,
            total_steps=state.max_steps,
            loss=state.log_history[-1].get("loss", None),
            eval_loss=state.log_history[-1].get("eval_loss", None),
            predict_loss=state.log_history[-1].get("predict_loss", None),
            learning_rate=state.log_history[-1].get("learning_rate", None),
            epoch=state.log_history[-1].get("epoch", None),
            percentage=round(state.global_step / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
        )

        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")

    def state(self) -> dict:
        return {
            'start_time': self.start_time,
            'elapsed_time': self.elapsed_time
        }

    @classmethod
    def from_state(cls, state):
        return cls(state['start_time'], state['elapsed_time'])