from transformers import  AutoModel, AutoTokenizer
import pandas as pd
import torch
from torch.utils.data import Dataset
import logging
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch.nn.functional as F
import pickle
import string
from abc import abstractmethod
import json


class AbstractMoviesRanker:
    """Base class for rankers that score dataframe rows against a query.

    A concrete ranker supplies ``encode_query``; this class multiplies the
    encoded query with ``index_matrix`` and returns the best-scoring rows of
    ``df`` with the score attached as an extra column.

    Row ``i`` of ``index_matrix`` is paired with row ``i`` of ``df``
    (scores are matched to rows by position).
    """

    # Translation table deleting every character of ``string.punctuation``.
    # Built once instead of on every ``depunctuate`` call.
    _PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)

    def __init__(self, df, index_matrix, score_name = "score"):
        """Store the data to rank.

        Args:
            df: pandas DataFrame holding the items to rank.
            index_matrix: 2-D tensor with one row per row of ``df``.
            score_name: name of the column that receives the scores in the
                frames returned by ``get_top_ids``.
        """
        self.df = df
        self.ids = self.df.index.values
        self.index_matrix = index_matrix
        self.score_name = score_name

    @abstractmethod
    def encode_query(self, query):
        """Turn ``query`` into a ``(1, dim)`` tensor; subclasses must override.

        Raises:
            NotImplementedError: always, in this base class. The class does
                not use ``ABCMeta``, so ``@abstractmethod`` alone would not
                stop a caller from reaching this body; failing loudly here
                beats returning ``None`` and crashing later in ``torch.mm``.
        """
        raise NotImplementedError(
            "%s must implement encode_query" % type(self).__name__
        )

    def get_scores(self, encoded_query):
        """Return the dot product of the query with every index row, as a list.

        Only the first row of ``encoded_query`` is used.
        """
        return torch.mm(encoded_query, self.index_matrix.transpose(0, 1))[0].tolist()

    def get_top_ids(self, scores, topn=6):
        """Return the ``topn`` best-scoring rows of ``df``, best first.

        Args:
            scores: one score per row of ``df``, in row order.
            topn: maximum number of rows to return.

        Returns:
            A new DataFrame (``df`` itself is left untouched) with the
            selected rows and an extra ``score_name`` column.
        """
        # Like the original zip(), silently ignore any length mismatch.
        n_rows = min(len(self.ids), len(scores))
        # sorted() is stable, so rows with equal scores keep their df order.
        order = sorted(range(n_rows), key=scores.__getitem__, reverse=True)[:topn]
        # Select by position rather than by label: a label lookup would
        # return extra rows whenever the index contains duplicate labels.
        # The explicit copy guarantees the score column is never written
        # onto a view of ``self.df``.
        top_df = self.df.iloc[order].copy()
        top_df[self.score_name] = [scores[pos] for pos in order]
        return top_df

    def run_query(self, query, topn=6):
        """Encode ``query``, score every row and return the ``topn`` best rows."""
        encoded_query = self.encode_query(query)
        scores = self.get_scores(encoded_query)
        return self.get_top_ids(scores, topn)

    @staticmethod
    def depunctuate(x):
        """Return ``x`` with every ``string.punctuation`` character removed."""
        return x.translate(AbstractMoviesRanker._PUNCTUATION_TABLE)

class SparseTfIdfRanker(AbstractMoviesRanker):
    """Sparse ranking via TF-IDF.

    Queries are vectorized with a pickled vectorizer and compared to the
    rows of ``index_matrix`` by dot product.
    """
    def __init__(self, df, index_matrix, vectorizer_path):
        """Load the vectorizer and densify the index.

        Args:
            df: pandas DataFrame holding the items to rank.
            index_matrix: tensor with one row per row of ``df``; it is
                converted with ``to_dense()``.
            vectorizer_path: path to a pickled object exposing
                ``transform`` (presumably a fitted TF-IDF vectorizer --
                verify against the code that writes this file).
        """
        super().__init__(df, index_matrix, score_name = 'tfidf-score')
        # SECURITY: unpickling executes arbitrary code from the file, so
        # ``vectorizer_path`` must only ever point at a trusted file.
        # The ``with`` block closes the handle instead of leaking it.
        with open(vectorizer_path, 'rb') as fp:
            self.vectorizer = pickle.load(fp)
        self.index_matrix = self.index_matrix.to_dense() ##For dot products

    def encode_query(self, query):
        """Return the L2-normalized float32 vector of the depunctuated query."""
        vectorized = self.vectorizer.transform([self.depunctuate(query)])
        encoded_query = torch.tensor(vectorized.todense(), dtype = torch.float32)
        return F.normalize(encoded_query, p=2)


class BertRanker(AbstractMoviesRanker):
    """Dense ranking against a matrix of precomputed embeddings.

    Queries are embedded with a transformer loaded from ``modelpath`` by
    mean-pooling the token embeddings, then L2-normalized.
    """
    def __init__(self, df, index_matrix, modelpath):
        """Load tokenizer and model.

        Args:
            df: pandas DataFrame holding the items to rank.
            index_matrix: tensor with one embedding row per row of ``df``.
            modelpath: name or path passed to ``from_pretrained`` for both
                the tokenizer and the model.
        """
        super().__init__(df, index_matrix, score_name = "bert-score")
        self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
        self.model = AutoModel.from_pretrained(modelpath)

    def encode_query(self, query):
        """Return the L2-normalized, mean-pooled embedding of ``query``.

        The query is padded/truncated to 128 tokens. The returned tensor
        carries no autograd history.
        """
        tok_q = self.tokenizer(query, return_tensors="pt", padding="max_length", max_length = 128, truncation=True)
        # Pure inference: skip building the autograd graph, which would
        # otherwise keep every intermediate activation alive for nothing.
        with torch.no_grad():
            o = self.model(**tok_q)
            encoded_query = self.mean_pooling(o, tok_q['attention_mask'])
            return F.normalize(encoded_query, p=2)

    @staticmethod
    def mean_pooling(model_output, attention_mask):
        """Average the token embeddings over the positions kept by the mask.

        Args:
            model_output: model output exposing ``last_hidden_state``.
            attention_mask: mask whose entries weight the tokens along
                dimension 1 (padding positions are expected to be 0).

        Returns:
            The masked sum over dimension 1 divided by the mask sum.
        """
        token_embeddings = model_output.last_hidden_state #First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # The clamp guards against a division by zero for an all-zero mask.
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)



class SparseDenseMoviesRanker():
    """Two-stage ranker: TF-IDF shortlist, then dense re-ranking.

    The TF-IDF engine selects the ``first_ranking`` best rows; a
    ``BertRanker`` restricted to those rows then picks the final ``topn``.
    """

    # Shortlist size used when neither the caller nor the config gives one.
    DEFAULT_FIRST_RANKING = 1000

    def __init__(self, df, modelpath, bert_index, sparse_index, vectorizer_path, first_ranking=DEFAULT_FIRST_RANKING):
        """Set up the sparse stage; the dense model is loaded on first query.

        Args:
            df: pandas DataFrame holding the items to rank.
            modelpath: name or path handed to ``BertRanker``.
            bert_index: tensor of dense embeddings, indexed in ``run_query``
                with the index labels of the shortlisted rows.
            sparse_index: index matrix handed to ``SparseTfIdfRanker``.
            vectorizer_path: pickled vectorizer path for the sparse stage.
            first_ranking: default shortlist size for ``run_query``.
        """
        self.df = df
        self.ids = self.df.index.values
        self.tfidf_engine = SparseTfIdfRanker(df, sparse_index, vectorizer_path)
        self.modelpath = modelpath
        self.bert_index = bert_index
        self.first_ranking = first_ranking
        # Created lazily by run_query and then reused across queries.
        self.bert_engine = None
        self._bert_engine_modelpath = None

    def run_query(self, query, topn=6, first_ranking=None):
        """Return the ``topn`` best rows for ``query``.

        Args:
            query: free-text query.
            topn: number of rows in the final result.
            first_ranking: size of the TF-IDF shortlist; ``None`` means
                use ``self.first_ranking``.

        Returns:
            DataFrame of the selected rows, as produced by the dense stage
            (it also carries the score column added by the sparse stage).
        """
        if first_ranking is None:
            first_ranking = self.first_ranking
        tfidf_sorted_frame = self.tfidf_engine.run_query(query, topn=first_ranking)
        # NOTE(review): bert_index is indexed with the dataframe's index
        # *labels*; this is only correct if those labels are row positions
        # of bert_index -- confirm against how the index was built.
        firstranking_index = self.bert_index[tfidf_sorted_frame.index.values]
        if self.bert_engine is None or self._bert_engine_modelpath != self.modelpath:
            # Loading tokenizer and model is expensive: do it once, not
            # once per query.
            self.bert_engine = BertRanker(tfidf_sorted_frame, firstranking_index, self.modelpath)
            self._bert_engine_modelpath = self.modelpath
        else:
            # Reuse the loaded model; only re-point the engine at this
            # query's shortlist.
            self.bert_engine.df = tfidf_sorted_frame
            self.bert_engine.ids = tfidf_sorted_frame.index.values
            self.bert_engine.index_matrix = firstranking_index
        bert_sorted_frame = self.bert_engine.run_query(query, topn=topn)
        return bert_sorted_frame

    @classmethod
    def from_json_config(cls, jsonfile):
        """Build a ranker from a JSON configuration file.

        Required keys: ``dataframe`` (pickled DataFrame), ``bert_index`` and
        ``sparse_index`` (files readable by ``torch.load``),
        ``vectorizer_path`` and ``modelpath``. Optional key:
        ``firstranking``, the default shortlist size.

        Raises:
            KeyError: if a required key is missing.
        """
        with open(jsonfile) as fp:
            conf = json.load(fp)

        ##Load data for ranking
        df = pd.read_pickle(conf['dataframe'])

        ##Load indices, e.g. embeddings and encoding utilities
        bert_index = torch.load(conf['bert_index'])
        sparse_index = torch.load(conf['sparse_index'])
        vectorizer_path = conf['vectorizer_path']
        modelpath = conf['modelpath']

        ranker = cls(df, modelpath, bert_index, sparse_index, vectorizer_path)
        ##Conf for first ranking
        # This value used to be read and then dropped. It is set after
        # construction so subclasses that override __init__ with the
        # original five-argument signature keep working. Without the key,
        # the constructor default applies.
        if 'firstranking' in conf:
            ranker.first_ranking = conf['firstranking']
        return ranker


if __name__ == '__main__':
    # Demo: build the two-stage ranker from conf.json, then print each
    # sample query followed by the head of its ranked result frame.
    ranker = SparseDenseMoviesRanker.from_json_config('conf.json')

    demo_queries = (
        "une histoire de pirates et de chasse au trésor",
        "une histoire de gangsters avec de l'argent",
    )
    for text in demo_queries:
        print(text)
        print(ranker.run_query(text).head())