import csv
import copy
import torch
import einops
import numpy as np
from torch import nn
import torch.nn.functional as F



def get_activation_fn(activation_type):
    """Return the ``torch.nn.functional`` activation named ``activation_type``.

    Args:
        activation_type: one of ``"relu"``, ``"gelu"`` or ``"glu"``.

    Returns:
        The matching function object, e.g. ``F.relu``.

    Raises:
        RuntimeError: if ``activation_type`` is not one of the supported names.
    """
    supported = ("relu", "gelu", "glu")
    if activation_type not in supported:
        # Keep RuntimeError for existing callers; the message now lists every
        # accepted name (previously it omitted "glu").
        raise RuntimeError(
            f"activation function currently supports {'/'.join(supported)}, not {activation_type}"
        )
    return getattr(F, activation_type)

def get_mlp_head(input_size, hidden_size, output_size, dropout=0):
    """Build a two-layer MLP head as an ``nn.Sequential``.

    Layer order:
        Linear(input_size -> hidden_size) -> ReLU
        -> LayerNorm(hidden_size, eps=1e-12) -> Dropout(dropout)
        -> Linear(hidden_size -> output_size)
    """
    # Modules are created in the same order they appear in the stack, so the
    # two Linear layers draw their random initial weights in that order.
    hidden_proj = nn.Linear(input_size, hidden_size)
    activation = nn.ReLU()
    norm = nn.LayerNorm(hidden_size, eps=1e-12)
    drop = nn.Dropout(dropout)
    out_proj = nn.Linear(hidden_size, output_size)
    return nn.Sequential(hidden_proj, activation, norm, drop, out_proj)

def layer_repeat(module, N, share_layer=False):
    """Return an ``nn.ModuleList`` containing ``N`` layers built from ``module``.

    Args:
        module: the template layer.
        N: number of layers. ``N <= 0`` yields an empty list in both modes.
        share_layer: if True, every entry is the very same ``module`` object
            (shared parameters). Otherwise the first ``N - 1`` entries are deep
            copies and the last entry is ``module`` itself.

    Returns:
        ``nn.ModuleList`` of length ``max(N, 0)``.
    """
    if N <= 0:
        # Previously the non-shared branch returned [module] here
        # (range(N - 1) is empty, then `+ [module]`), disagreeing with the
        # shared branch, which returned an empty list.
        return nn.ModuleList()
    if share_layer:
        return nn.ModuleList([module] * N)
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N - 1)] + [module])


def calc_pairwise_locs(obj_centers, obj_whls, eps=1e-10, pairwise_rel_type='center', spatial_dist_norm=True,
                       spatial_dim=5):
    """Compute pairwise spatial-relation features between objects.

    Entry ``[b, i, j]`` of the result describes object ``i`` relative to object
    ``j``, using the center offset ``d = obj_centers[b, i] - obj_centers[b, j]``.

    Args:
        obj_centers: tensor in einops layout ``'b l d'``. Coordinate components
            0/1/2 are used as x/y/z below.
        obj_whls: per-object size tensor in the same ``'b l d'`` layout. It is
            concatenated whole for ``'mlp'``, and only component 2 is read for
            ``'vertical_bottom'``. NOTE(review): component 2 is assumed to be
            the height -- confirm.
        eps: added inside every ``sqrt`` so zero-length offsets (e.g. the
            diagonal ``i == j``) never cause a division by zero.
        pairwise_rel_type: ``'mlp'``, ``'center'`` or ``'vertical_bottom'``.
            Other values are NOT rejected. They fall through and return the raw
            center offsets ``d`` (sliced with ``[..., 1:]`` if spatial_dim == 4).
        spatial_dist_norm: if True, divide distances by the largest pairwise
            distance of each batch element.
        spatial_dim: 1 -> only the (normalized) distance, shape (b, l, l, 1).
            4 -> the 5 relation features without the distance. Any other
            value -> all 5 features. Ignored for ``'mlp'``.

    Returns:
        Tensor of shape ``(b, l, l, k)`` with ``k`` as described above. For
        ``'mlp'``, ``k`` is twice the size of ``cat([obj_centers, obj_whls], 2)``'s
        last dim.
    """
    if pairwise_rel_type == 'mlp':
        # Raw pairwise input for a learned encoder:
        # entry [b, i, j] = concat(obj_locs[b, i], obj_locs[b, j]).
        obj_locs = torch.cat([obj_centers, obj_whls], 2)
        pairwise_locs = torch.cat(
            [einops.repeat(obj_locs, 'b l d -> b l x d', x=obj_locs.size(1)),
             einops.repeat(obj_locs, 'b l d -> b x l d', x=obj_locs.size(1))],
            dim=3
        )
        return pairwise_locs

    # Center offsets via broadcasting: pairwise_locs[b, i, j] = c[b, i] - c[b, j].
    pairwise_locs = einops.repeat(obj_centers, 'b l d -> b l 1 d') \
                    - einops.repeat(obj_centers, 'b l d -> b 1 l d')
    pairwise_dists = torch.sqrt(torch.sum(pairwise_locs ** 2, 3) + eps)  # (b, l, l)
    if spatial_dist_norm:
        # Largest distance over all l*l pairs of each batch element. Because
        # every distance is >= sqrt(eps) > 0, the normalized values lie in (0, 1].
        max_dists = torch.max(pairwise_dists.view(pairwise_dists.size(0), -1), dim=1)[0]
        norm_pairwise_dists = pairwise_dists / einops.repeat(max_dists, 'b -> b 1 1')
    else:
        norm_pairwise_dists = pairwise_dists

    if spatial_dim == 1:
        return norm_pairwise_dists.unsqueeze(3)

    # Horizontal (x, y) distance. eps keeps the azimuth ratios finite on the
    # diagonal and for pairs stacked exactly vertically.
    pairwise_dists_2d = torch.sqrt(torch.sum(pairwise_locs[..., :2] ** 2, 3) + eps)
    if pairwise_rel_type == 'center':
        # Features per pair (d = center offset; |d| here is the UNnormalized distance):
        #   0: normalized |d|
        #   1: d_z / |d|,     2: |d_xy| / |d|    -> sine / cosine of elevation
        #   3: d_y / |d_xy|,  4: d_x / |d_xy|    -> sine / cosine of azimuth
        pairwise_locs = torch.stack(
            [norm_pairwise_dists, pairwise_locs[..., 2] / pairwise_dists,
             pairwise_dists_2d / pairwise_dists, pairwise_locs[..., 1] / pairwise_dists_2d,
             pairwise_locs[..., 0] / pairwise_dists_2d],
            dim=3
        )
    elif pairwise_rel_type == 'vertical_bottom':
        # Same 5 features as 'center', except that the elevation terms (1 and 2)
        # use offsets between lowered reference points. clone() keeps the
        # in-place z edit from modifying the caller's obj_centers.
        # NOTE(review): z is lowered by the full obj_whls[..., 2], not half of it.
        # If that component is the full box height, the point lies below the
        # box bottom -- confirm intent.
        bottom_centers = torch.clone(obj_centers)
        bottom_centers[:, :, 2] -= obj_whls[:, :, 2]
        bottom_pairwise_locs = einops.repeat(bottom_centers, 'b l d -> b l 1 d') \
                               - einops.repeat(bottom_centers, 'b l d -> b 1 l d')
        bottom_pairwise_dists = torch.sqrt(torch.sum(bottom_pairwise_locs ** 2, 3) + eps)  # (b, l, l)
        bottom_pairwise_dists_2d = torch.sqrt(torch.sum(bottom_pairwise_locs[..., :2] ** 2, 3) + eps)
        pairwise_locs = torch.stack(
            [norm_pairwise_dists,
             bottom_pairwise_locs[..., 2] / bottom_pairwise_dists,
             bottom_pairwise_dists_2d / bottom_pairwise_dists,
             pairwise_locs[..., 1] / pairwise_dists_2d,
             pairwise_locs[..., 0] / pairwise_dists_2d],
            dim=3
        )

    # spatial_dim == 4 drops feature 0 (the distance). For an unrecognized
    # pairwise_rel_type neither branch above ran, so this slices and returns
    # the raw center offsets rather than relation features.
    if spatial_dim == 4:
        pairwise_locs = pairwise_locs[..., 1:]
    return pairwise_locs

def convert_pc_to_box(obj_pc):
    """Return the axis-aligned bounding box of a point cloud.

    Args:
        obj_pc: 2-D array indexed as ``obj_pc[:, k]``. Columns 0, 1 and 2 are
            read as x, y and z, and any further columns are ignored.

    Returns:
        ``(center, box_size)``: two 3-element lists holding the box midpoint
        and the per-axis extent (max - min).
    """
    lows = [np.min(obj_pc[:, axis]) for axis in range(3)]
    highs = [np.max(obj_pc[:, axis]) for axis in range(3)]
    center = [(lo + hi) / 2 for lo, hi in zip(lows, highs)]
    box_size = [hi - lo for lo, hi in zip(lows, highs)]
    return center, box_size

class LabelConverter(object):
    """Lookup tables between label names and ids, built from a tab-separated file.

    The first row of ``file_path`` is skipped as a header. Each following row
    ``row`` at 0-based position ``raw_id`` (counted from the first row after
    the header) contributes:

    * ``int(row[0])``: stored as the "scannet raw id"
    * ``row[1]``: the raw category name
    * ``int(row[4])``: the NYU40 id
    * ``row[7]``: the NYU40 category name

    Blank lines are skipped but still count toward ``raw_id``, so every other
    row keeps the id that matches its position in the file.

    NOTE(review): the column meanings above come from the variable names only;
    confirm them against the actual label file.
    """

    def __init__(self, file_path):
        # raw category name -> raw_id (a later row wins on duplicate names)
        self.raw_name_to_id = {}
        # NYU40 id -> raw_id of the LAST row carrying that NYU40 id
        self.nyu40id_to_id = {}
        # NYU40 name -> raw_id of the LAST row carrying that NYU40 name
        self.nyu40_name_to_id = {}
        # Fixed 18-way class vocabulary; 'others' catches every unlisted name.
        self.scannet_name_to_scannet_id = {'cabinet':0, 'bed':1, 'chair':2, 'sofa':3, 'table':4,
            'door':5, 'window':6,'bookshelf':7,'picture':8, 'counter':9, 'desk':10, 'curtain':11,
            'refrigerator':12, 'shower curtain':13, 'toilet':14, 'sink':15, 'bathtub':16, 'others':17}
        # raw_id -> class index in scannet_name_to_scannet_id, chosen by the row's NYU40 name
        self.id_to_scannetid = {}
        # scannet raw id (column 0) <-> raw category name (column 1)
        self.scannet_raw_id_to_raw_name = {}
        self.raw_name_to_scannet_raw_id = {}

        others_id = self.scannet_name_to_scannet_id['others']
        # newline='' is required by the csv module so that line endings, and
        # newlines inside quoted fields, are parsed correctly.
        with open(file_path, encoding='utf-8', newline='') as fd:
            reader = csv.reader(fd, delimiter="\t", quotechar='"')
            next(reader, None)  # skip the header row
            for raw_id, row in enumerate(reader):
                # A blank line parses as []; indexing it used to raise IndexError.
                if not row:
                    continue
                scannet_raw_id = int(row[0])
                raw_name = row[1]
                nyu40_id = int(row[4])
                nyu40_name = row[7]
                self.raw_name_to_id[raw_name] = raw_id
                self.scannet_raw_id_to_raw_name[scannet_raw_id] = raw_name
                self.raw_name_to_scannet_raw_id[raw_name] = scannet_raw_id
                self.nyu40id_to_id[nyu40_id] = raw_id
                self.nyu40_name_to_id[nyu40_name] = raw_id
                self.id_to_scannetid[raw_id] = self.scannet_name_to_scannet_id.get(nyu40_name, others_id)

def build_rotate_mat(split, rot_aug=True, rand_angle='axis'):
    """Sample a rotation about the z axis for data augmentation.

    An angle is ALWAYS drawn from NumPy's global RNG: uniform in [0, 2*pi)
    when ``rand_angle == 'random'``, otherwise one of 0, pi/2, pi, 3*pi/2.
    The RNG therefore advances on every call, even when ``None`` is returned.

    Returns:
        A float32 3x3 matrix that rotates about z by the sampled angle, or
        ``None`` when the angle is 0, ``split != 'train'``, or ``rot_aug``
        is falsy.
    """
    if rand_angle != 'random':
        axis_angles = [0, np.pi/2, np.pi, np.pi*3/2]
        angle = axis_angles[np.random.randint(len(axis_angles))]
    else:
        angle = np.random.rand() * np.pi * 2

    if not (angle != 0 and split == 'train' and rot_aug):
        return None

    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    return np.array([
        [cos_a, -sin_a, 0],
        [sin_a, cos_a, 0],
        [0, 0, 1],
    ], dtype=np.float32)

def obj_processing_post(obj_pcds, rot_aug=True):
    """Turn a batch of object point clouds into normalized tensors plus box stats.

    Args:
        obj_pcds: numpy array indexed as ``obj_pcds[:, :, :3]``. The first three
            channels of the last axis are treated as xyz, and any remaining
            channels pass through unchanged. Presumably (num_objs, num_points, C)
            -- TODO confirm with callers. The input array is NOT modified.
        rot_aug: forwarded to ``build_rotate_mat``. The split passed there is
            hard-coded to ``'val'``, and with this module's ``build_rotate_mat``
            that never yields a matrix. So no rotation is applied, and
            ``rot_aug`` currently has no effect.

    Returns:
        obj_pcds: tensor copy of the input. Its xyz is centered on each
            object's mean point and scaled so the farthest point has norm 1
            (the divisor is clamped to >= 1e-6).
        obj_locs: ``cat([mean_xyz, size], dim=1)``, computed before normalization.
        obj_boxes: ``cat([box_center, size], dim=1)``, where box_center is the
            midpoint of the per-axis min/max and size = max - min.
        rot_matrix: ``None``, or the transposed rotation as a tensor.
    """
    # torch.from_numpy shares memory with the numpy array. Clone it so the
    # in-place centering/normalization below does not silently rewrite the
    # caller's data.
    obj_pcds = torch.from_numpy(obj_pcds).clone()
    rot_matrix = build_rotate_mat('val', rot_aug)
    if rot_matrix is not None:
        rot_matrix = torch.from_numpy(rot_matrix.transpose())
        obj_pcds[:, :, :3] @= rot_matrix

    xyz = obj_pcds[:, :, :3]
    center = xyz.mean(1)
    xyz_min = xyz.min(1).values
    xyz_max = xyz.max(1).values
    box_center = (xyz_min + xyz_max) / 2
    size = xyz_max - xyz_min
    obj_locs = torch.cat([center, size], dim=1)
    obj_boxes = torch.cat([box_center, size], dim=1)

    # centering: subtract each object's mean point (in place on the xyz view)
    obj_pcds[:, :, :3].sub_(obj_pcds[:, :, :3].mean(1, keepdim=True))

    # normalization: farthest point of each object ends up at distance 1
    max_dist = (obj_pcds[:, :, :3]**2).sum(2).sqrt().max(1).values
    max_dist.clamp_(min=1e-6)  # guard against single-point / degenerate objects
    obj_pcds[:, :, :3].div_(max_dist[:, None, None])

    return obj_pcds, obj_locs, obj_boxes, rot_matrix


def pad_sequence(sequence_list, max_len=None, pad=0, return_mask=False):
    """Stack variable-length tensors into one padded batch tensor.

    Args:
        sequence_list: non-empty list of tensors that agree on every dimension
            except the first. The output dtype and device are taken from
            ``sequence_list[0]``.
        max_len: length of dimension 1 of the output. Defaults to the longest
            sequence. Sequences longer than ``max_len`` are truncated to it.
        pad: scalar fill value for padded positions, cast to the output dtype.
        return_mask: if True, also return a padding mask.

    Returns:
        A tensor of shape ``(len(sequence_list), max_len, *rest)``. If
        ``return_mask`` is True, returns a pair of that tensor and a bool tensor
        of shape ``(len(sequence_list), max_len)`` that is True at padded
        positions.
    """
    lens = [seq.shape[0] for seq in sequence_list]
    if max_len is None:
        max_len = max(lens)

    first = sequence_list[0]
    dtype = first.dtype
    device = first.device
    shape = [len(sequence_list), max_len, *first.shape[1:]]
    # torch.full writes `pad` directly in `dtype`. There is no float
    # intermediate (as `ones(...) * pad` made for non-integer pads) and no
    # extra re-cast.
    padded_sequence = torch.full(shape, pad, dtype=dtype, device=device)
    for i, seq in enumerate(sequence_list):
        # Truncate instead of crashing when a sequence exceeds max_len.
        n = min(seq.shape[0], max_len)
        padded_sequence[i, :n] = seq[:n]

    if not return_mask:
        return padded_sequence
    # Build both operands on the target device directly (no CPU round-trip).
    positions = torch.arange(max_len, device=device)
    lengths = torch.tensor(lens, device=device)
    mask = positions[None, :] >= lengths[:, None]  # True as masked.
    return padded_sequence, mask