from dataclasses import dataclass, field
from typing import Dict

import torch
import torch.nn.functional as F
from einops import rearrange, reduce

from ..utils import (
    BaseModule,
    chunk_batch,
    get_activation,
    rays_intersect_bbox,
    scale_tensor,
)


class TriplaneNeRFRenderer(BaseModule):
    """Volume renderer for a triplane feature representation.

    Sample points along each ray are projected onto three axis-aligned feature
    planes. The bilinearly sampled features are fused (concatenated or
    averaged) and fed to a decoder. The decoded densities and colors are then
    alpha-composited over a white background.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Half-extent of the cube [-radius, radius]^3 covered by the triplane.
        radius: float

        # How the three per-plane feature vectors are fused: "concat" or "mean".
        feature_reduction: str = "concat"
        # Activation (looked up via get_activation) applied to density + density_bias.
        density_activation: str = "trunc_exp"
        density_bias: float = -1.0
        # Activation applied to the decoder's "features" output to obtain colors.
        color_activation: str = "sigmoid"
        # Samples per ray, placed at the midpoints of equal bins in [t_near, t_far].
        num_samples_per_ray: int = 128
        # Copied onto self.randomized by train(); the sampling in _forward does
        # not currently read it (sample positions are always deterministic).
        randomized: bool = False

    cfg: Config

    def configure(self) -> None:
        """Validate the config and set default runtime state."""
        assert self.cfg.feature_reduction in ["concat", "mean"]
        # 0 means "no chunking" (see set_chunk_size).
        self.chunk_size = 0

    def set_chunk_size(self, chunk_size: int):
        """Set how many points query_triplane decodes per chunk (0 = all at once)."""
        assert (
            chunk_size >= 0
        ), "chunk_size must be a non-negative integer (0 for no chunking)."
        self.chunk_size = chunk_size

    def query_triplane(
        self,
        decoder: torch.nn.Module,
        positions: torch.Tensor,
        triplane: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Sample triplane features at 3D points and decode them.

        Args:
            decoder: Module mapping fused per-point features of shape (N, F) to
                a dict that must contain at least "density" and "features".
            positions: Points of shape (..., 3) in world units; coordinates
                in (-radius, radius) map onto the planes' extent.
            triplane: Feature planes of shape (3, Cp, Hp, Wp).

        Returns:
            The decoder outputs plus "density_act" (activated density) and
            "color" (activated features). Every value is reshaped to
            (*positions.shape[:-1], -1).
        """
        input_shape = positions.shape[:-1]
        # reshape rather than view: also accepts non-contiguous inputs.
        positions = positions.reshape(-1, 3)

        # Map world coordinates in (-radius, radius) to grid_sample's (-1, 1).
        positions = scale_tensor(
            positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
        )

        def _query_chunk(x):
            # 2D coordinates of each point on the (x, y), (x, z) and (y, z)
            # planes, stacked to shape (3, N, 2).
            indices2D: torch.Tensor = torch.stack(
                (x[..., [0, 1]], x[..., [0, 2]], x[..., [1, 2]]),
                dim=-3,
            )
            # A single grid_sample call covers all three planes:
            # input (3, Cp, Hp, Wp), grid (3, 1, N, 2) -> output (3, Cp, 1, N).
            # The first rearrange is an identity that only asserts 3 planes.
            out: torch.Tensor = F.grid_sample(
                rearrange(triplane, "Np Cp Hp Wp -> Np Cp Hp Wp", Np=3),
                rearrange(indices2D, "Np N Nd -> Np () N Nd", Np=3),
                align_corners=False,
                mode="bilinear",
            )
            if self.cfg.feature_reduction == "concat":
                # (N, 3 * Cp)
                out = rearrange(out, "Np Cp () N -> N (Np Cp)", Np=3)
            elif self.cfg.feature_reduction == "mean":
                # (N, Cp)
                out = reduce(out, "Np Cp () N -> N Cp", Np=3, reduction="mean")
            else:
                raise NotImplementedError

            net_out: Dict[str, torch.Tensor] = decoder(out)
            return net_out

        if self.chunk_size > 0:
            net_out = chunk_batch(_query_chunk, self.chunk_size, positions)
        else:
            net_out = _query_chunk(positions)

        net_out["density_act"] = get_activation(self.cfg.density_activation)(
            net_out["density"] + self.cfg.density_bias
        )
        net_out["color"] = get_activation(self.cfg.color_activation)(
            net_out["features"]
        )

        # reshape (not view) in case the decoder returns non-contiguous tensors.
        net_out = {k: v.reshape(*input_shape, -1) for k, v in net_out.items()}

        return net_out

    def _forward(
        self,
        decoder: torch.nn.Module,
        triplane: torch.Tensor,
        rays_o: torch.Tensor,
        rays_d: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Render a single triplane along a set of rays.

        Args:
            decoder: See query_triplane.
            triplane: Feature planes of shape (3, Cp, Hp, Wp).
            rays_o: Ray origins of shape (..., 3).
            rays_d: Ray directions with the same shape as rays_o.
            **kwargs: Accepted and ignored.

        Returns:
            Composited RGB of shape (*rays_o.shape[:-1], 3). Each ray's color
            is its accumulated color plus (1 - opacity) of white, so rays that
            miss the bounding cube are pure white.
        """
        rays_shape = rays_o.shape[:-1]
        # reshape rather than view: rays_o is often an expanded (non-contiguous)
        # camera origin, which view() rejects.
        rays_o = rays_o.reshape(-1, 3)
        rays_d = rays_d.reshape(-1, 3)
        n_rays = rays_o.shape[0]

        # NOTE(review): t_near/t_far are assumed to be (N_rays, 1) and
        # rays_valid a boolean (N_rays,) mask; confirm against rays_intersect_bbox.
        t_near, t_far, rays_valid = rays_intersect_bbox(rays_o, rays_d, self.cfg.radius)

        if not bool(rays_valid.any()):
            # No ray hits the cube: nothing to sample, only the white background.
            # (Sampling zero points would otherwise fail in query_triplane's
            # reshape(..., -1) on an empty tensor.)
            return torch.ones(
                (*rays_shape, 3), dtype=triplane.dtype, device=triplane.device
            )

        # Restrict everything to the rays that hit the cube so that all per-ray
        # tensors share the same leading (valid-ray) dimension.
        t_near, t_far = t_near[rays_valid], t_far[rays_valid]
        rays_o_valid = rays_o[rays_valid]
        rays_d_valid = rays_d[rays_valid]

        # Samples at the midpoints of num_samples_per_ray equal bins on [0, 1],
        # mapped linearly onto [t_near, t_far] for every ray.
        t_vals = torch.linspace(
            0, 1, self.cfg.num_samples_per_ray + 1, device=triplane.device
        )
        t_mid = (t_vals[:-1] + t_vals[1:]) / 2.0
        z_vals = t_near * (1 - t_mid[None]) + t_far * t_mid[None]  # (N_valid, N_samples)

        xyz = (
            rays_o_valid[:, None, :] + z_vals[..., None] * rays_d_valid[:, None, :]
        )  # (N_valid, N_samples, 3)

        mlp_out = self.query_triplane(
            decoder=decoder,
            positions=xyz,
            triplane=triplane,
        )

        eps = 1e-10
        # NOTE: the step size is the normalised bin width on [0, 1], not the
        # metric spacing z_vals[:, 1:] - z_vals[:, :-1]; every sample on every
        # ray gets the same delta. Presumably the density scale (density_bias,
        # decoder weights) was tuned with this convention — confirm before
        # switching, since it changes rendered opacity.
        deltas = t_vals[1:] - t_vals[:-1]  # (N_samples,)
        alpha = 1 - torch.exp(
            -deltas * mlp_out["density_act"][..., 0]
        )  # (N_valid, N_samples)
        # Transmittance up to (but excluding) each sample; eps keeps the
        # cumulative product from collapsing to exactly zero.
        accum_prod = torch.cat(
            [
                torch.ones_like(alpha[:, :1]),
                torch.cumprod(1 - alpha[:, :-1] + eps, dim=-1),
            ],
            dim=-1,
        )
        weights = alpha * accum_prod  # (N_valid, N_samples)
        comp_rgb_ = (weights[..., None] * mlp_out["color"]).sum(dim=-2)  # (N_valid, C)
        opacity_ = weights.sum(dim=-1)  # (N_valid,)

        # Scatter the valid-ray results back into full-size buffers; rays that
        # missed keep zero color and zero opacity. Assumes C == 3.
        comp_rgb = torch.zeros(
            n_rays, 3, dtype=comp_rgb_.dtype, device=comp_rgb_.device
        )
        opacity = torch.zeros(n_rays, dtype=opacity_.dtype, device=opacity_.device)
        comp_rgb[rays_valid] = comp_rgb_
        opacity[rays_valid] = opacity_

        # Composite over a white background.
        comp_rgb += 1 - opacity[..., None]
        comp_rgb = comp_rgb.reshape(*rays_shape, 3)

        return comp_rgb

    def forward(
        self,
        decoder: torch.nn.Module,
        triplane: torch.Tensor,
        rays_o: torch.Tensor,
        rays_d: torch.Tensor,
    ) -> torch.Tensor:
        """Render RGB for one triplane or a batch of triplanes.

        Args:
            decoder: See query_triplane.
            triplane: Either (3, Cp, Hp, Wp) for a single scene, or a batch
                (B, 3, Cp, Hp, Wp). In the batched case rays_o[i]/rays_d[i]
                are rendered against triplane[i].
            rays_o: Ray origins, (..., 3) or batched (B, ..., 3).
            rays_d: Ray directions matching rays_o.

        Returns:
            RGB tensor of shape (..., 3), or (B, ..., 3) when batched.
        """
        if triplane.ndim == 4:
            comp_rgb = self._forward(decoder, triplane, rays_o, rays_d)
        else:
            comp_rgb = torch.stack(
                [
                    self._forward(decoder, triplane[i], rays_o[i], rays_d[i])
                    for i in range(triplane.shape[0])
                ],
                dim=0,
            )

        return comp_rgb

    def train(self, mode=True):
        """Set training mode; randomized is enabled only in training with cfg.randomized."""
        self.randomized = mode and self.cfg.randomized
        return super().train(mode=mode)

    def eval(self):
        """Switch to evaluation mode and disable randomized."""
        self.randomized = False
        return super().eval()