import os
import pdb
import pickle as pkl
from torch.utils.data import Dataset, DataLoader
# from omegaconf import DictConfig, OmegaConf
from scipy.spatial.transform import Rotation as R
from models.joints_to_smplx import joints_to_smpl, JointsToSMPLX
from utils import *
from constants import *
from datasets.trumans import TrumansDataset
from models.synhsi import Unet
# from hydra import compose, initialize
import yaml


ACT_TYPE = ['scene', 'grasp', 'artic', 'none']


def convert_trajectory(trajectory):
    trajectory_new = [[t['x'], t['y'], t['z']] for t in trajectory]
    trajectory_new = np.array(trajectory_new)
    return trajectory_new


def get_base_speed(cfg, trajectory, is_zup=True):
    trajectory_layer = trajectory
    if is_zup:
        trajectory_layer = zup_to_yup(trajectory_layer)
    trajectory2D = trajectory_layer[:, [0, 2]]
    distance = np.sum(np.linalg.norm(trajectory2D[1:] - trajectory2D[:-1], axis=1))
    speed = trajectory_layer.shape[0] // distance
    print('Base Speed:', speed, flush=True)

    return speed


def get_guidance(cfg, trajectory, samplers, act_type='none', speed=35):
    trajectory_layer = trajectory
    # trajectory_layer = zup_to_yup(trajectory_layer)
    print(trajectory_layer.shape)
    if cfg.action_type != 'pure_inter':
        #TODO
        midpoints = trajectory_layer[[0] * cfg.len_pre + list(range(0, len(trajectory_layer), speed)) + [-1] * (cfg.len_act + (1 if cfg.stay_and_act else 0))]
        # midpoints[0, 0] = 1.4164
        # midpoints[0, 2] = 2.2544
        # midpoints = trajectory_layer[[0] * cfg.len_pre + [25, 50] + [50] + [70, 90]]

    else:
        midpoints = trajectory_layer[[0] * cfg.len_pre + [0] + [0] * (cfg.len_act + 1 if cfg.stay_and_act else 0)]
    midpoints = torch.tensor(midpoints).float().to(cfg.device)
    max_step = midpoints.shape[0] - 1

    mat_init = cfg.batch_size * [np.eye(4)]
    mat_init = torch.from_numpy(np.stack(mat_init, axis=0)).float().to(cfg.device)
    print(midpoints)
    mat_init[:, 0, 3] = midpoints[0, 0]
    mat_init[:, 2, 3] = midpoints[0, 2]

    dx = midpoints[cfg.len_pre + 1, 0] - midpoints[0, 0]
    dz = midpoints[cfg.len_pre + 1, 2] - midpoints[0, 2]

    print(-np.arctan2(dx.item(), dz.item()), dx, dz)

    mat_rot_y = R.from_rotvec(np.array([0, np.arctan2(dx.item(), dz.item()), 0])).as_matrix()
    # mat_rot_y = R.from_rotvec(np.array([0, np.arctan2(-1, -2), 0])).as_matrix()
    mat_init[:, :3, :3] = torch.from_numpy(mat_rot_y).float().to(cfg.device)


    # goal_list = torch.zeros((max_step, cfg.batch_size, cfg.dataset.seq_len, 3)).float().to(cfg.device)
    goal_list = []
    action_label_list = []
    # action_label_list = torch.zeros((max_step, cfg.batch_size, cfg.dataset.seq_len, cfg.dataset.nb_actions)).float().to(cfg.device)
    sampler_list = []

    if act_type == 'none':
        for s in range(max_step):
            goal = torch.zeros((cfg.batch_size, 1, 3)).float().to(cfg.device)
            goal[:, :] = midpoints[s + 1]
            goal_list.append(goal)
            if cfg.dataset.nb_actions > 0:
                action_label = torch.zeros((cfg.batch_size, cfg.dataset.seq_len, cfg.dataset.nb_actions)).float().to(cfg.device)
                action_label_list.append(action_label)
            else:
                action_label_list.append(None)
            sampler_list.append(samplers['body'])
    elif act_type == 'write':
        midpoints = torch.from_numpy(trajectory_layer[::4]).float().to(cfg.device)
        for s in range(midpoints.shape[0] // 16):
            goal = torch.zeros((cfg.batch_size, 16, 3)).float().to(cfg.device)
            goal[:, :] = midpoints[s * 16: (s + 1) * 16]
            goal_list.append(goal)
            if cfg.dataset.nb_actions > 0:
                action_label = torch.zeros((cfg.batch_size, cfg.dataset.seq_len, cfg.dataset.nb_actions)).float().to(cfg.device)
                action_label_list.append(action_label)
            else:
                action_label_list.append(None)
            sampler_list.append(samplers['hand'])
    elif act_type == 'scene':
        for s in range(max_step):
            goal = torch.zeros((cfg.batch_size, 1, 3)).float().to(cfg.device)
            goal[:, :] = midpoints[s + 1]
            goal_list.append(goal)
            sampler_list.append(samplers['body'])
            action_label = torch.zeros((cfg.batch_size, cfg.dataset.seq_len, cfg.dataset.nb_actions)).float().to(cfg.device)
            if s > max_step - cfg.len_act:
                action_label[:, :, cfg.action_id] = 1.
            action_label_list.append(action_label)
    elif act_type == 'grasp':
        for s in range(max_step):
            if s != max_step - cfg.len_act:
                goal = torch.zeros((cfg.batch_size, 1, 3)).float().to(cfg.device)
                goal[:, :] = midpoints[s + 1]
                goal_list.append(goal)
            else:
                # grasp_goal = zup_to_yup(np.array([[-0.32, -3.36, 0.395],
                #                                   [-0.3, -3.2, 0.395],
                #                                   [-0.25, -3, 0.394],
                #                                   [-0.147, -3.0, 0.395]])).reshape((cfg.batch_size, -1, 3))

                grasp_goal = zup_to_yup(np.array(trajectory['Object'])).reshape((cfg.batch_size, 1, 3))
                goal = torch.zeros((cfg.batch_size, 3, 3)).float().to(cfg.device)
                goal[:, :] = torch.from_numpy(grasp_goal).float().to(cfg.device)
                goal_list.append(goal)

            action_label = torch.zeros((cfg.batch_size, cfg.dataset.seq_len, cfg.dataset.nb_actions)).float().to(cfg.device)
            if s < max_step - cfg.len_act:
                sampler_list.append(samplers['body'])
            elif s == max_step - cfg.len_act:
                sampler_list.append(samplers['hand'])
            else:
                sampler_list.append(samplers['body'])
                if cfg.action_id != -1:
                    action_label[:, :, cfg.action_id] = 1.
            action_label_list.append(action_label)
    elif act_type == 'pure_inter':
        sampler_list += [samplers['body']] * (cfg.len_pre + cfg.len_act)
        goal = torch.zeros((cfg.batch_size, 1, 3)).float().to(cfg.device)
        goal[:, :] = midpoints[0]
        goal_list += [goal] * (cfg.len_pre + + cfg.len_act)
        action_label = torch.zeros((cfg.batch_size, cfg.dataset.seq_len, cfg.dataset.nb_actions)).float().to(cfg.device)
        action_label_list += [action_label.clone()] * cfg.len_pre
        action_label[:, :, cfg.action_id] = 2.
        action_label_list += [action_label.clone()] * cfg.len_act


    return mat_init, goal_list, action_label_list, sampler_list


def sample_step(cfg, mat, obj_locs, goal_list, action_label_list, sampler_list):
    max_step = len(goal_list)
    fixed_points = None
    fixed_frame = 2
    points_all = []
    cnt_fixed_frame = 0
    cnt_seq_len = 0


    for s in range(max_step):
        print('step', s)

        sampler = sampler_list[s]
        if s != 0:
            fixed_points = sampler.dataset.normalize_torch(transform_points(fixed_points, torch.inverse(mat)))
        else:
            if cfg.continue_last:
                method_id = cfg.method_name.split('_')[-1]
                method_name_last = cfg.method_name[:-1] + str(int(method_id) - 1)
                mat = torch.from_numpy(np.load(os.path.join(cfg.exp_dir, f'{method_name_last}_mat.npy'))).to(sampler.device)
                fixed_points = torch.from_numpy(np.load(os.path.join(cfg.exp_dir, f'{method_name_last}_fixed_points.npy'))).to(sampler.device)
                fixed_points = sampler.dataset.normalize_torch(transform_points(fixed_points, torch.inverse(mat)))

        samples, occs = sampler.p_sample_loop(fixed_points, obj_locs, mat, cfg.scene_name, goal_list[s], action_label_list[s])

        if 0 <= s < cfg.len_pre:
            cnt_fixed_frame += sampler.fixed_frame
        if 0 <= s < cfg.len_pre:
            cnt_seq_len += cfg.dataset.seq_len

        points_gene = samples[-1]
        points_gene_np = points_gene.reshape(cfg.batch_size, cfg.dataset.seq_len, -1, 3).cpu().numpy()

        if s == 0 or fixed_frame == 0:
            #TODO
            points_all.append(points_gene_np[:, fixed_frame - 1:])
        elif fixed_frame > 0:
            points_all.append(points_gene_np[:, fixed_frame:])

        # fixed_frame = 0 if s == max_step - 1 else sampler_list[s + 1].fixed_frame
        fixed_frame = sampler_list[s].fixed_frame if s == max_step - 1 else sampler_list[s + 1].fixed_frame

        pelvis_new = points_gene[:, -fixed_frame, :9].cpu().numpy().reshape(cfg.batch_size, 3, 3)
        trans_mats = np.repeat(np.eye(4)[np.newaxis, :, :], cfg.batch_size, axis=0)
        for ip, pn in enumerate(pelvis_new):
            _, ret_R, ret_t = rigid_transform_3D(np.matrix(pn), rest_pelvis, False)
            ret_t[1] = 0.0
            rot_euler = R.from_matrix(ret_R).as_euler('zxy')
            shift_euler = np.array([0, 0, rot_euler[2]])
            shift_rot_matrix2 = R.from_euler('zxy', shift_euler).as_matrix()
            trans_mats[ip, :3, :3] = shift_rot_matrix2
            trans_mats[ip, :3, 3] = ret_t.reshape(-1)
        mat = torch.from_numpy(trans_mats).to(device=cfg.device, dtype=torch.float32)

        if fixed_frame > 0:
            fixed_points = points_gene[:, -fixed_frame:]
        if s == max_step - 1:
            print('Saved Mat and Fixed Points', flush=True)
            # np.save(os.path.join(cfg.exp_dir, f'{cfg.method_name}_mat.npy'), mat.cpu().numpy())
            # np.save(os.path.join(cfg.exp_dir, f'{cfg.method_name}_fixed_points.npy'), fixed_points.cpu().numpy())

    points_all = np.concatenate(points_all, axis=1)
    points_all = points_all[:, cnt_seq_len - cnt_fixed_frame:]

    return points_all


def sample_wrapper(trajectory, obj_locs):
    trajectory = convert_trajectory(trajectory)
    # obj_locs = {key: [data[key]['x'], data['key']['z']] for key in data.keys() if 'trajectory' not in key}

    # cfg = compose(config_name="config_sample_synhsi")
    with open('config/config_sample_synhsi.yaml') as f:
        cfg = yaml.safe_load(f)
    cfg = dotDict(cfg)


    # @hydra.main(version_base=None, config_path="../config", config_name="config_sample_synhsi")
    # def sample(cfg) -> None:
    print(cfg)

    # seed_everything(100)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # model_joints_to_smplx = init_model(cfg.model.model_smplx, device=device, eval=True)
    model_joints_to_smplx = JointsToSMPLX(**cfg.model.model_smplx)
    model_joints_to_smplx.load_state_dict(torch.load(cfg.model.model_smplx.ckpt))
    model_joints_to_smplx.to(device)
    model_joints_to_smplx.eval()

    # model_body = init_model(cfg.model.synhsi_body, device=device, eval=True)
    model_body = Unet(**cfg.model.synhsi_body)
    model_body.load_state_dict(torch.load(cfg.model.synhsi_body.ckpt))
    model_body.to(device)
    model_body.eval()
    # model_hand = init_model(cfg.model.synhsi_hand, device=device, eval=True)

    # synhsi_dataset = hydra.utils.instantiate(cfg.dataset)
    synhsi_dataset = TrumansDataset(**cfg.dataset)

    sampler_body = hydra.utils.instantiate(cfg.sampler.pelvis)
    # sampler_hand = hydra.utils.instantiate(cfg.sampler.right_hand)
    sampler_body.set_dataset_and_model(synhsi_dataset, model_body)
    # sampler_hand.set_dataset_and_model(None, model_hand)

    samplers = {'body': sampler_body, 'hand': None}

    # for scene_name in ['N3OpenArea']:

    # trajectory = np.load(os.path.join(cfg.test_dir, cfg.exp_name, f'trajectories.npy'), allow_pickle=True).item()
    # cfg.scene_name = scene_name
    # cfg.action_type = trajectory['action_type']
    # if 'action_id' in trajectory.keys():
    #     cfg.action_id = trajectory['action_id']

    # GP_LAYERS = ['GP_Layer']
    method_name = cfg.method_name
    lid = 0

    base_speed = get_base_speed(cfg, trajectory, is_zup=False)

    mat, goal_list, action_label_list, sampler_list = get_guidance(cfg, trajectory, samplers, act_type=cfg.action_type,
                                                                   speed=int(0.6 * base_speed))
    points_all = sample_step(cfg, mat, obj_locs, goal_list, action_label_list, sampler_list)

    # os.makedirs(cfg.exp_dir, exist_ok=True)
    vertices = None
    for i in range(cfg.batch_size):
        keypoint_gene_torch = torch.from_numpy(points_all[i]).reshape(-1, cfg.dataset.nb_joints * 3).to(device)
        pose, transl, left_hand, right_hand, vertices = joints_to_smpl(model_joints_to_smplx, keypoint_gene_torch, cfg.dataset.joints_ind, cfg.interp_s)
        # output_data = {'transl': transl, 'body_pose': pose[:, 3:], 'global_orient': pose[:, :3],
        #                'id': 0}
        # print(output_data)
        # with open(os.path.join(cfg.exp_dir, f'{method_name}_{lid}_{i}.pkl'), 'wb') as f:
        #     pkl.dump(output_data, f)

    # vertices = np.load('/home/jiangnan/SyntheticHSI/Gradio_demo/vertices.npy', allow_pickle=True)
    # np.save('/home/jiangnan/SyntheticHSI/Gradio_demo/vertices.npy', vertices)

    return vertices.tolist()



    # v = sample()
    #
    #
    # return v




# def load_dataset_meta(cfg):
#     metas = np.load(cfg.)

# if __name__ == '__main__':
    # OmegaConf.register_resolver("times_three", times_three)
    # OmegaConf.register_new_resolver("times", lambda x, y: int(x) * int(y))
    # sample()