from transformers import set_seed
from tqdm.auto import trange
from PIL import Image
import numpy as np
import random
import utils
import torch


CONFIG_SPEC = [
    ("text", "A cloud at dawn", str),
    ("iterations", 5000, int),
    ("turns", 4, int),
    ("showoff", 5000, int),
    ("seed", 12, int),
    ("focal_length", 0.1, float),
    ("plane_width", 0.1, float),
    ("shade_strength", 0.25, float),
    ("gamma", 0.5, float),
    ("max_depth", 7, float),
    ("lr", 0.5, float),
    ("offset", 5, float),
    ("offset_random", 0.75, float),
    ("xyz_random", 0.25, float),
    ("altitude_range", 0.3, float),
    ("augments", 4, int),
    ("show_every", 50, int),
    ("epochs", 1, int),
    ("w", 224, int),
    ("h", 224, int),
    ("num_objects", 256, int),
    #@markdown CLIP loss type; switching it can improve results
    ("loss_type", "spherical", ("spherical", "cosine")),
    #@markdown CLIP loss weight
    ("clip_weight", 1.0, float),        #@param {type: "number"}
    #@markdown Number of dimensions: 0 renders point clouds (default), 1 strokes,
    #@markdown 2 planes, 3 little cubes
    ("ndim", 0, (0, 1, 2, 3)),  #@param {type: "integer"}

    #@markdown Opacity scale:
    ("min_opacity", 1e-4, float),       #@param {type: "number"}
    ("max_opacity", 1.0, float),        #@param {type: "number"}
    ("log_opacity", False, bool),      #@param {type: "boolean"}

    ("min_radius", 0.030, float),
    ("max_radius", 0.070, float),
    ("log_radius", False, bool),

    # TODO dynamically decide bezier_res
    #@markdown Bezier resolution: how many points per dimension a line/plane/cube is sampled with. Not applicable to points (ndim=0)
    ("bezier_res", 8, int),  #@param {type: "integer"}
    #@markdown Maximum scale of the position, velocity, and acceleration parameters
    ("pos_scale", 0.4, float),  #@param {type: "number"}
    ("vel_scale", 0.15, float),  #@param {type: "number"}
    ("acc_scale", 0.15, float),  #@param {type: "number"}

    #@markdown Scale of each individual 3D object. Master control for velocity and acceleration scale.
    ("scale", 1, float),  #@param {type: "number"}
]
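

# Convenience sketch (not part of the original API): collapse CONFIG_SPEC into
# a plain dict of defaults that can be tweaked and passed to PulsarCLIP.
def default_config():
    return {name: default for name, default, _ in CONFIG_SPEC}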


# TODO: one day separate the config into multiple parts and split this megaobject into multiple objects
class PulsarCLIP(object):
    def __init__(self, args):
        args = DotDict(**args)
        set_seed(args.seed)
        self.args = args
        self.device = args.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        # Deferred so that `pulsar_clip` can be imported before `pytorch3d` is installed
        import pytorch3d.renderer.points.pulsar as ps
        self.ndim = int(self.args.ndim)
        self.renderer = ps.Renderer(self.args.w, self.args.h,
                                    self.args.num_objects * (self.args.bezier_res ** self.ndim)).to(self.device)
        self.bezier_pos = torch.nn.Parameter(torch.randn((args.num_objects, 4)).to(self.device))
        self.bezier_vel = torch.nn.Parameter(torch.randn((args.num_objects, 3 * self.ndim)).to(self.device))
        self.bezier_acc = torch.nn.Parameter(torch.randn((args.num_objects, 3 * self.ndim)).to(self.device))
        self.bezier_col = torch.nn.Parameter(torch.randn((args.num_objects, 4 * (1 + self.ndim))).to(self.device))
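        # Per-parameter-group learning rates: colors adapt fastest (0.5x lr),
        # positions slower (0.1x), velocity/acceleration slowest (0.05x).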
        self.optimizer = torch.optim.Adam([dict(params=[self.bezier_col], lr=5e-1 * args.lr),
                                           dict(params=[self.bezier_pos], lr=1e-1 * args.lr),
                                           dict(params=[self.bezier_vel, self.bezier_acc], lr=5e-2 * args.lr),
                                           ])
        self.model_clip, self.preprocess_clip = utils.load_clip()
        self.model_clip.visual.requires_grad_(False)
        # Cosine annealing with warm restarts; the restart period is the number
        # of optimizer steps in one epoch (gradients are accumulated over
        # `augments` renders per step).
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer,
                                                                              int(self.args.iterations
                                                                                  / self.args.augments
                                                                                  / self.args.epochs))
        import clip  # deferred like `pytorch3d` above
        self.txt_emb = self.model_clip.encode_text(clip.tokenize([self.args.text]).to(self.device))[0].detach()
        self.txt_emb = torch.nn.functional.normalize(self.txt_emb, dim=-1)

    def get_points(self):
        """Expand the Bezier control parameters into per-point positions,
        colors, radii, and opacities for the pulsar renderer."""
        if self.ndim > 0:
            # Parameter grid with one t in [0, 1] per Bezier dimension, shaped
            # (ndim, num_objects, *([bezier_res] * ndim), 1) for broadcasting.
            bezier_ts = torch.stack(torch.meshgrid(
                (torch.linspace(0, 1, self.args.bezier_res, device=self.device),) * self.ndim,
                indexing="ij",  # matches the old default; silences the deprecation warning
            ), dim=0).unsqueeze(1).repeat((1, self.args.num_objects) + (1,) * self.ndim).unsqueeze(-1)

        def interpolate_3D(pos, vel=0.0, acc=0.0, pos_scale=None, vel_scale=None, acc_scale=None, scale=None):
            """Evaluate a quadratic curve pos + vel * t + acc * t ** 2 over the
            Bezier grid; for ndim == 0 (points) only the scaled pos is used."""
            pos_scale = self.args.pos_scale if pos_scale is None else pos_scale
            vel_scale = self.args.vel_scale if vel_scale is None else vel_scale
            acc_scale = self.args.acc_scale if acc_scale is None else acc_scale
            scale = self.args.scale if scale is None else scale
            if self.ndim == 0:
                return pos * pos_scale
            result = 0.0
            s = pos.shape[-1]
            assert s * self.ndim == vel.shape[-1] == acc.shape[-1]
            # O(ndim) sequential loop; TODO replace with a fused dimension operation
            for d, bezier_t in zip(range(self.ndim), bezier_ts):
                result = (result
                          + torch.tanh(vel[..., d * s:(d + 1) * s]).view(
                            (-1,) + (1,) * self.ndim + (s,)) * vel_scale * bezier_t
                          + torch.tanh(acc[..., d * s:(d + 1) * s]).view(
                            (-1,) + (1,) * self.ndim + (s,)) * acc_scale * bezier_t.pow(2))
            result = (result * scale
                      + torch.tanh(pos[..., :s]).view((-1,) + (1,) * self.ndim + (s,)) * pos_scale).view(-1, s)
            return result

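        # Positions come from the first three channels of `bezier_pos`; colors
        # use the first four channels of `bezier_col` as a base RGBA, with the
        # remaining channels shared as velocity/acceleration controls.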
        vert_pos = interpolate_3D(self.bezier_pos[..., :3], self.bezier_vel, self.bezier_acc)
        vert_col = interpolate_3D(self.bezier_col[..., :4],
                                  self.bezier_col[..., 4:4 + 4 * self.ndim],
                                  self.bezier_col[..., -4 * self.ndim:])

        def to_bezier(x):
            # Broadcast a per-object tensor to one value per Bezier sample point.
            return x.view((-1,) + (1,) * self.ndim + (x.shape[-1],)).repeat(
                (1,) + (self.args.bezier_res,) * self.ndim + (1,)).reshape(-1, x.shape[-1])

        def rescale(x, a, b, is_log=False):
            # Map x in (0, 1) to [a, b], linearly or geometrically:
            # a * (b / a) ** x == exp(x * log(b / a) + log(a)).
            return torch.exp(x * np.log(b / a) + np.log(a)) if is_log else x * (b - a) + a
        return (
            vert_pos,
            torch.sigmoid(vert_col[..., :3]),
            rescale(
                torch.sigmoid(to_bezier(self.bezier_pos[..., -1:])[..., 0]),
                self.args.min_radius, self.args.max_radius, is_log=self.args.log_radius
            ),
            rescale(torch.sigmoid(vert_col[..., -1]),
                    self.args.min_opacity, self.args.max_opacity, is_log=self.args.log_opacity))

    def camera(self, angle, altitude=0.0, offset=None, use_random=True, offset_random=None,
               xyz_random=None, focal_length=None, plane_width=None):
        """Build an 8-element pulsar camera vector for a viewpoint orbiting the
        scene at `angle` (radians) and `altitude`, optionally jittered."""
        if offset is None:
            offset = self.args.offset
        if xyz_random is None:
            xyz_random = self.args.xyz_random
        if focal_length is None:
            focal_length = self.args.focal_length
        if plane_width is None:
            plane_width = self.args.plane_width
        if offset_random is None:
            offset_random = self.args.offset_random
        device = self.device
        offset = offset + np.random.normal() * offset_random * int(use_random)
        position = torch.tensor([0, 0, -offset], dtype=torch.float)
        position = utils.rotate_axis(position, altitude, 0)
        position = utils.rotate_axis(position, angle, 1)
        position = position + torch.randn(3) * xyz_random * int(use_random)
        # Pulsar camera parameters: position (3), rotation (3),
        # focal length, and sensor (plane) width.
        return torch.tensor([position[0], position[1], position[2],
                             altitude, angle, 0,
                             focal_length, plane_width], dtype=torch.float, device=device)


    def render(self, cam_params=None):
        """Render the scene twice: once normally for RGB, and once with zeroed
        colors so the background contribution can serve as an alpha mask."""
        if cam_params is None:
            cam_params = self.camera(0, 0)
        vert_pos, vert_col, radius, opacity = self.get_points()

        rgb = self.renderer(vert_pos, vert_col, radius, cam_params,
                            self.args.gamma, self.args.max_depth, opacity=opacity)
        opacity = self.renderer(vert_pos, vert_col * 0, radius, cam_params,
                                self.args.gamma, self.args.max_depth, opacity=opacity)
        return rgb, opacity

    def random_view_render(self):
        """Render from a random viewpoint and composite over a random
        Perlin-noise background, one octave pattern per channel."""
        angle = random.uniform(0, np.pi * 2)
        altitude = random.uniform(-self.args.altitude_range / 2, self.args.altitude_range / 2)
        cam_params = self.camera(angle, altitude)
        result, alpha = self.render(cam_params)
        back = torch.zeros_like(result)
        s = back.shape
        for j in range(s[-1]):
            n = random.choice([7, 14, 28])
            back[..., j] = utils.rand_perlin_2d_octaves(s[:-1], (n, n)).clip(-0.5, 0.5) + 0.5
        result = result * (1 - alpha) + back * alpha
        return result


    def generate(self):
        """Optimize the scene against the CLIP text prompt, yielding a preview
        PIL image every `show_every` iterations."""
        self.optimizer.zero_grad()
        try:
            for i in trange(self.args.iterations + self.args.showoff):
                if i < self.args.iterations:
                    result = self.random_view_render()
                    img_emb = self.model_clip.encode_image(
                        self.preprocess_clip(result.permute(2, 0, 1)).unsqueeze(0).clamp(0., 1.))
                    img_emb = torch.nn.functional.normalize(img_emb, dim=-1)
                    if self.args.loss_type == "spherical":
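                        # For unit-normalized embeddings, ||a - b|| / 2 == sin(theta / 2),
                        # so this equals 2 * (theta / 2) ** 2: a squared spherical distance.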
                        clip_loss = (img_emb - self.txt_emb).norm(dim=-1).div(2).arcsin().pow(2).mul(2).mean()
                    elif self.args.loss_type == "cosine":
                        clip_loss = (1 - img_emb @ self.txt_emb.T).mean()
                    else:
                        raise NotImplementedError(f"CLIP loss type not supported: {self.args.loss_type}")
                    loss = clip_loss * self.args.clip_weight  # TODO add more loss types
                    loss.backward()
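                # Gradients accumulate across `augments` random views before
                # each optimizer step.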
                if i % self.args.augments == self.args.augments - 1:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    try:
                        self.scheduler.step()
                    except AttributeError:
                        # Tolerate a missing or disabled scheduler.
                        pass
                if i % self.args.show_every == 0:
                    cam_params = self.camera(i / self.args.iterations * np.pi * 2 * self.args.turns, use_random=False)
                    img_show, _ = self.render(cam_params)
                    # Clip to [0, 1] before the uint8 cast to avoid wraparound artifacts.
                    img = Image.fromarray((img_show.cpu().detach().numpy().clip(0, 1) * 255).astype(np.uint8))
                    yield img
        except KeyboardInterrupt:
            pass


class DotDict(dict):
    """A dict whose items are also readable as attributes (d.key == d["key"])."""
    def __getattr__(self, item):
        try:
            return self[item]
        except KeyError:
            # Raise AttributeError so hasattr() and copy protocols behave.
            raise AttributeError(item) from None
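

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original module):
    # start from the CONFIG_SPEC defaults, override the prompt, and save the
    # preview frames that `generate` yields.
    config = default_config()
    config["text"] = "A cloud at dawn"
    model = PulsarCLIP(config)
    for step, frame in enumerate(model.generate()):
        frame.save(f"preview_{step:04d}.png")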