from functools import partial

import torch
from torch import nn
from torch.nn import functional as F
from transformers import PreTrainedModel, RobertaConfig, RobertaModel

from .configuration_hypernet import ZettHypernetConfig


class Rescaler(nn.Module):
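    """Elementwise affine rescaling, y = w * x + b, with frozen (non-trainable) parameters."""
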
    def __init__(self, dim: int):
        super().__init__()

        self.dim = dim

        # frozen affine parameters (requires_grad=False); presumably overwritten when
        # pretrained hypernetwork weights are loaded
        self.w = nn.Parameter(torch.ones((1, self.dim)), requires_grad=False)
        self.b = nn.Parameter(torch.ones((1, self.dim)), requires_grad=False)

    def forward(self, x):
        return self.w * x + self.b


class ProjectorBlock(nn.Module):
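    """Residual MLP block: LayerNorm(GELU(Linear(GELU(Linear(x)))) + x), with tanh-approximate GELU.

    The residual connection requires input_dim == dim.
    """
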
    def __init__(self, input_dim: int, dim: int, intermediate_dim: int):
        super().__init__()

        self.input_dim = input_dim
        self.dim = dim
        self.intermediate_dim = intermediate_dim

        self.dense1 = nn.Linear(self.input_dim, self.intermediate_dim)
        self.dense2 = nn.Linear(self.intermediate_dim, self.dim)

        self.ln = nn.LayerNorm(self.dim, eps=1e-6)

    def forward(self, x):
        h = F.gelu(
            self.dense2(F.gelu(self.dense1(x), approximate="tanh")),
            approximate="tanh",
        )
        return self.ln(h + x)


class ZettHypernet(PreTrainedModel):
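    """Hypernetwork that predicts embeddings for the tokens of a target tokenizer.

    Each target token is represented by its surface form, i.e. the sequence of source-tokenizer
    ids it corresponds to. A small RoBERTa-style encoder reads the matching source embeddings
    (plus, optionally, a language embedding) and predicts the token's input embedding,
    optionally a separate output embedding, and optionally a bias term.
    """
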
    config_class = ZettHypernetConfig

    def __init__(self, config: ZettHypernetConfig):
        super().__init__(config)

        self.config = config
        self.has_separate_out_embeddings = getattr(
            self.config, "separate_out_embeddings", False
        )

        if self.config.hn_embed_lang_id:
            self.lang_embeddings = nn.Embedding(
                self.config.n_langs, self.config.hn_hidden_size
            )

        if self.has_separate_out_embeddings:
            n_in_embd = self.config.n_embd * 2
            n_out_embd = self.config.n_embd
        else:
            n_in_embd = self.config.n_embd
            n_out_embd = self.config.n_embd

        if self.config.hn_model_type == "roberta":
            # note: from here on, the local `config` refers to the encoder's config,
            # not the hypernet config (which stays available as self.config)
            config = RobertaConfig.from_pretrained(
                self.config.hn_model_name_or_path
            )
            config.num_hidden_layers = self.config.hn_n_layers
            config.hidden_size = self.config.hn_hidden_size
            config.intermediate_size = self.config.hn_intermediate_size
            if getattr(self.config, "hn_num_attention_heads", None) is None:
                self.config.hn_num_attention_heads = self.config.hn_hidden_size // 64
            config.num_attention_heads = self.config.hn_num_attention_heads
            self.embed_init_range = config.initializer_range
            module_class = partial(RobertaModel, add_pooling_layer=False)
        elif self.config.hn_model_type == "t5":
            raise NotImplementedError()
        else:
            raise ValueError(f"Unsupported hn_model_type: {self.config.hn_model_type}")

        if self.config.hn_embed_using_source_embeddings:
            # no need to allocate a full embedding matrix since inputs_embeds is always passed;
            # keep just enough rows to cover the pad token id
            config.vocab_size = self.config.pad_token_id + 1

        if (
            self.config.hn_add_inter_token_attention
            or self.config.hn_embed_target_priors
        ):
            raise NotImplementedError()

        self.pad_token_id = self.config.pad_token_id
        assert self.pad_token_id is not None
        self.model = module_class(config)

        # fallback embeddings for surface-form ids beyond the source vocabulary
        # (need at least one row even if there are no extra tokens)
        self.fallback_embeddings = nn.Embedding(
            max(self.config.hn_n_extra_tokens, 1), n_in_embd
        )

        if self.config.hn_embed_using_source_embeddings:
            self.input_projection = nn.Sequential(
                *[
                    nn.Linear(n_in_embd, self.config.hn_hidden_size),
                    ProjectorBlock(
                        self.config.hn_hidden_size,
                        self.config.hn_hidden_size,
                        self.config.hn_intermediate_size,
                    ),
                ]
            )

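        # output head(s): with a single head, input (and, if separate, output) embeddings are
        # predicted as one concatenated vector and split in forward(); otherwise the output
        # embeddings get their own head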
        if self.config.hn_single_head:
            self.output_projection = nn.Sequential(
                *[
                    ProjectorBlock(
                        self.config.hn_hidden_size,
                        self.config.hn_hidden_size,
                        self.config.hn_intermediate_size,
                    ),
                    nn.Linear(self.config.hn_hidden_size, n_in_embd),
                ]
            )
        else:
            self.output_projection = nn.Sequential(
                *[
                    ProjectorBlock(
                        self.config.hn_hidden_size,
                        self.config.hn_hidden_size,
                        self.config.hn_intermediate_size,
                    ),
                    nn.Linear(self.config.hn_hidden_size, n_out_embd),
                ]
            )
            if self.has_separate_out_embeddings:
                self.output_projection_out = nn.Sequential(
                    *[
                        ProjectorBlock(
                            self.config.hn_hidden_size,
                            self.config.hn_hidden_size,
                            self.config.hn_intermediate_size,
                        ),
                        nn.Linear(self.config.hn_hidden_size, self.config.n_embd),
                    ]
                )

        if self.config.hn_rescale_embeddings:
            self.in_scaler = Rescaler(n_in_embd)
            self.scaler = Rescaler(n_out_embd)

            if self.has_separate_out_embeddings:
                self.out_scaler = Rescaler(self.config.n_embd)

        if getattr(self.config, "hn_predict_bias", False):
            self.bias_projection = nn.Linear(self.config.hn_hidden_size, 1)

    def forward(
        self,
        target_surface_forms,
        target_priors=None,
        source_embeddings=None,
        lang_index=None,
        deterministic: bool = True,  # unused here
    ):
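        """Predict embeddings for a batch of target tokens.

        Args:
            target_surface_forms: int tensor of shape (num_target_tokens, seq_len) with the
                source-tokenizer ids making up each target token, padded with pad_token_id;
                ids >= original_vocab_size select rows of the fallback embeddings instead.
            target_priors: not supported in this implementation.
            source_embeddings: float tensor with at least original_vocab_size rows and
                n_in_embd columns, holding the source embeddings to condition on.
            lang_index: index of the language embedding to append (only used when
                hn_embed_lang_id is set); a single index per call.
            deterministic: unused in the PyTorch version.

        Returns:
            A tuple (predicted_embeddings_in, predicted_embeddings_out, predicted_bias);
            predicted_embeddings_out is None unless separate output embeddings are predicted,
            and predicted_bias is all zeros unless hn_predict_bias is set.
        """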
        if target_priors is not None:
            raise NotImplementedError()

        if not self.config.hn_embed_using_source_embeddings:
            raise NotImplementedError()

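        # surface-form ids below original_vocab_size index the source embedding matrix;
        # ids at or above it index the learned fallback embeddings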
        use_fallback = target_surface_forms >= self.config.original_vocab_size

        main_ids = target_surface_forms.clamp(max=self.config.original_vocab_size - 1)
        fallback_ids = (target_surface_forms - self.config.original_vocab_size).clamp(min=0)

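        # look up each surface-form id in the provided source embedding matrix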
        source_embeds = F.embedding(main_ids, weight=source_embeddings)

        if self.config.hn_rescale_embeddings:
            source_embeds = self.in_scaler(source_embeds)

        inputs_embeds = torch.where(
            use_fallback[..., None],
            self.fallback_embeddings(fallback_ids),
            source_embeds,
        )
        inputs_embeds = self.input_projection(inputs_embeds)
        attention_mask = target_surface_forms != self.pad_token_id

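        # optionally append a language embedding as one extra sequence position; a single
        # lang_index is broadcast across all target tokens in the batch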
        if self.config.hn_embed_lang_id:
            lang_embedding = self.lang_embeddings(lang_index).squeeze()
            # the PyTorch RobertaEmbeddings adds position and token type embeddings on top of
            # inputs_embeds (unlike the Flax version), so subtract them here to cancel that out
            lang_embedding -= self.model.embeddings.token_type_embeddings(
                torch.tensor(0, device=self.device)
            ) + self.model.embeddings.position_embeddings(
                torch.tensor(attention_mask.shape[1], device=self.device)
            )

            lang_embedding = lang_embedding[None, None, :].expand(
                inputs_embeds.shape[0], -1, -1
            )

            inputs_embeds = torch.cat(
                [
                    inputs_embeds,
                    lang_embedding,
                ],
                dim=1,
            )
            attention_mask = torch.cat(
                [
                    attention_mask,
                    torch.ones(lang_embedding.shape[:-1], dtype=torch.bool, device=self.device),
                ],
                dim=1,
            )

        position_ids = torch.broadcast_to(
            torch.arange(torch.atleast_2d(attention_mask).shape[-1], device=self.device),
            attention_mask.shape,
        )

        hidden_states = self.model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
        ).last_hidden_state

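        # pool the encoder output into one vector per target token: either the concatenation
        # of all positions or just the first position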
        if self.config.hn_concat_last_hidden_state:
            hidden_states = hidden_states.reshape(target_surface_forms.shape[0], -1)
        else:
            hidden_states = hidden_states[:, 0]

        predicted_embeddings = self.output_projection(hidden_states)

        if self.config.hn_single_head:
            predicted_embeddings_in = predicted_embeddings[..., : self.config.n_embd]

            if self.has_separate_out_embeddings:
                predicted_embeddings_out = predicted_embeddings[
                    ..., self.config.n_embd :
                ]
            else:
                predicted_embeddings_out = None
        else:
            predicted_embeddings_in = predicted_embeddings
            if self.has_separate_out_embeddings:
                predicted_embeddings_out = self.output_projection_out(hidden_states)
            else:
                predicted_embeddings_out = None

        if self.config.hn_rescale_embeddings:
            predicted_embeddings_in = self.scaler(predicted_embeddings_in)

            if predicted_embeddings_out is not None:
                predicted_embeddings_out = self.out_scaler(predicted_embeddings_out)

        if getattr(self.config, "hn_predict_bias", False):
            predicted_bias = self.bias_projection(hidden_states)[..., 0]
        else:
            predicted_bias = torch.zeros_like(
                target_surface_forms[..., 0], dtype=self.dtype
            )

        return predicted_embeddings_in, predicted_embeddings_out, predicted_bias
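

# --- minimal usage sketch (not part of the model code) --------------------------------------
# A hedged example of how the hypernetwork might be called. It assumes that ZettHypernetConfig
# forwards keyword arguments to PretrainedConfig so the attributes read above get set; all
# names and sizes below ("roberta-base", 768, 50265, ...) are hypothetical placeholders. Run
# as a module inside its package so the relative import above resolves.
if __name__ == "__main__":
    config = ZettHypernetConfig(
        hn_model_type="roberta",
        hn_model_name_or_path="roberta-base",  # hypothetical encoder config source
        hn_n_layers=3,
        hn_hidden_size=768,
        hn_intermediate_size=3072,
        hn_num_attention_heads=12,
        hn_embed_using_source_embeddings=True,
        hn_embed_lang_id=False,
        hn_add_inter_token_attention=False,
        hn_embed_target_priors=False,
        hn_single_head=False,
        hn_rescale_embeddings=False,
        hn_concat_last_hidden_state=False,
        hn_predict_bias=True,
        hn_n_extra_tokens=8,
        separate_out_embeddings=False,
        n_embd=768,
        original_vocab_size=50265,
        pad_token_id=1,
    )
    hypernet = ZettHypernet(config).eval()

    # 4 target tokens, each described by up to 5 source-tokenizer ids (padded with pad_token_id)
    target_surface_forms = torch.randint(0, config.original_vocab_size, (4, 5))
    # embedding matrix of the source model whose tokenizer is being transferred
    source_embeddings = torch.randn(config.original_vocab_size, config.n_embd)

    with torch.no_grad():
        emb_in, emb_out, bias = hypernet(
            target_surface_forms, source_embeddings=source_embeddings
        )
    print(emb_in.shape, emb_out, bias.shape)  # torch.Size([4, 768]) None torch.Size([4])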