import os
import pickle
import random
import torch
import torch.nn as nn
import numpy as np

from torch.utils.data import Dataset
from utilities.constants import *
from utilities.device import cpu_device
from utilities.device import get_device

import json

# NOTE(review): not referenced anywhere in the visible part of this module;
# presumably the start index of a sequence — confirm it is still needed.
SEQUENCE_START = 0

class VevoDataset(Dataset):
    """Dataset of chord sequences aligned with per-video features.

    Every sample id listed in ``vevo_meta/split/<split_ver>/<split>.txt`` is
    kept only if all of its feature files exist under ``dataset_root``.
    Indexing returns a dict of tensors with the chord sequence (as shifted
    input/target pairs), its root/attribute decomposition, the key, and the
    scene-offset, motion, emotion, note-density, loudness and (optionally)
    semantic video features.

    Args:
        dataset_root: Root directory of the dataset.
        split: Name of the split file (without ``.txt``).
        split_ver: Sub-directory of ``vevo_meta/split`` holding the split file.
        vis_models: Space-separated ``"<group>/<model>"`` names; one semantic
            feature directory is used per entry.
        emo_model: Emotion feature directory name. Names starting with
            ``"6c"`` are read as six emotion columns, all others as five.
        max_seq_chord: Length the chord buffers are padded/truncated to.
        max_seq_video: Length the video-side buffers are padded/truncated to.
            NOTE(review): ``tgt_emotion`` is aligned with ``tgt`` only when
            this equals ``max_seq_chord`` (both default to 300).
        random_seq: Stored on the instance; not otherwise used by this class.
        is_video: Whether semantic video features are located and loaded.
    """

    # Chord-quality masks per emotion class. Column order:
    #   maj dim sus4 min7 min sus2 aug dim7 maj6 hdim7 7 min6 maj7
    _EMOTION_QUALITY_MASKS = (
        (1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0),  # 0. exciting
        (0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0),  # 1. fearful
        (0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0),  # 2. tense
        (0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0),  # 3. sad
        (1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1),  # 4. relaxing
        (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),  # 5. neutral
    )
    # Rows of the lookup table that follow the six emotion rows.
    _END_ROW = 6
    _PAD_ROW = 7

    def __init__(self, dataset_root = "./dataset/", split="train", split_ver="v1", vis_models="2d/clip_l14p", emo_model="6c_l14p", max_seq_chord=300, max_seq_video=300, random_seq=True, is_video = True):
        self.dataset_root = dataset_root

        self.vevo_chord_root = os.path.join(dataset_root, "vevo_chord", "lab_v2_norm", "all")
        self.vevo_emotion_root = os.path.join(dataset_root, "vevo_emotion", emo_model, "all")
        self.vevo_motion_root = os.path.join(dataset_root, "vevo_motion", "all")
        self.vevo_scene_offset_root = os.path.join(dataset_root, "vevo_scene_offset", "all")
        self.vevo_meta_split_path = os.path.join(dataset_root, "vevo_meta", "split", split_ver, split + ".txt")

        self.vevo_loudness_root = os.path.join(dataset_root, "vevo_loudness", "all")
        self.vevo_note_density_root = os.path.join(dataset_root, "vevo_note_density", "all")

        self.max_seq_video = max_seq_video
        self.max_seq_chord = max_seq_chord
        self.random_seq = random_seq
        self.is_video = is_video
        self.emo_model = emo_model

        self.vis_models_arr = vis_models.split(" ")
        self.vevo_semantic_root_list = []
        # Gate on the constructor argument (not a module-level flag) so that
        # __init__ and __getitem__ always agree on whether semantic features
        # are in play.
        if self.is_video:
            for vis_model in self.vis_models_arr:
                parts = vis_model.split("/")
                self.vevo_semantic_root_list.append(
                    os.path.join(dataset_root, "vevo_semantic", "all", parts[0], parts[1]))

        with open(self.vevo_meta_split_path) as f:
            self.id_list = [line.strip() for line in f]

        self.data_files_chord = []
        self.data_files_emotion = []
        self.data_files_motion = []
        self.data_files_scene_offset = []
        self.data_files_loudness = []
        self.data_files_note_density = []
        # One list of .npy paths per visual model (left empty when
        # is_video is False).
        self.data_files_semantic_list = [[] for _ in self.vis_models_arr]

        for fid in self.id_list:
            fpath_chord = os.path.join(self.vevo_chord_root, fid + ".lab")
            fpath_emotion = os.path.join(self.vevo_emotion_root, fid + ".lab")
            fpath_motion = os.path.join(self.vevo_motion_root, fid + ".lab")
            fpath_scene_offset = os.path.join(self.vevo_scene_offset_root, fid + ".lab")
            fpath_loudness = os.path.join(self.vevo_loudness_root, fid + ".lab")
            fpath_note_density = os.path.join(self.vevo_note_density_root, fid + ".lab")
            fpath_semantic_list = [
                os.path.join(root, fid + ".npy") for root in self.vevo_semantic_root_list
            ]

            required = [fpath_chord, fpath_emotion, fpath_motion, fpath_scene_offset,
                        fpath_loudness, fpath_note_density] + fpath_semantic_list
            # Drop the sample entirely if any of its feature files is missing.
            if not all(os.path.exists(path) for path in required):
                continue

            self.data_files_chord.append(fpath_chord)
            self.data_files_emotion.append(fpath_emotion)
            self.data_files_motion.append(fpath_motion)
            self.data_files_scene_offset.append(fpath_scene_offset)
            self.data_files_loudness.append(fpath_loudness)
            self.data_files_note_density.append(fpath_note_density)
            for files, fpath_semantic in zip(self.data_files_semantic_list, fpath_semantic_list):
                files.append(fpath_semantic)

        self.chordDic = self._load_json(os.path.join(dataset_root, "vevo_meta/chord.json"))
        self.chordRootDic = self._load_json(os.path.join(dataset_root, "vevo_meta/chord_root.json"))
        self.chordAttrDic = self._load_json(os.path.join(dataset_root, "vevo_meta/chord_attr.json"))

        # Built once here instead of on every __getitem__ call.
        self._emotion_table = self._build_emotion_table()

    @staticmethod
    def _load_json(path):
        """Return the parsed content of the JSON file at *path*."""
        with open(path) as json_file:
            return json.load(json_file)

    @classmethod
    def _build_emotion_table(cls):
        """Return the (8, 159) float32 emotion -> chord multi-hot table.

        Rows 0-5 are the emotion classes; each 13-wide quality mask is
        repeated 12 times between a leading zero column and two trailing
        flag columns. Row 6 sets only the first flag column (END) and row 7
        only the second (PAD).
        """
        repeats = 12
        blank = [0] * 13
        rows = [[0] + list(mask) * repeats + [0, 0] for mask in cls._EMOTION_QUALITY_MASKS]
        rows.append([0] + blank * repeats + [1, 0])
        rows.append([0] + blank * repeats + [0, 1])
        return torch.tensor(rows, dtype=torch.float32)

    @staticmethod
    def _parse_scalar_track(lines, length, pad_value, convert=float):
        """Parse ``"<time> <value>"`` lines into a padded float32 tensor.

        Reading stops at the first line whose time index is >= *length*;
        positions that never appear keep *pad_value*.
        """
        track = np.full(length, pad_value, dtype=np.float64)
        for line in lines:
            fields = line.strip().split(" ")
            step = int(fields[0])
            if step >= length:
                break
            track[step] = convert(fields[1])
        return torch.from_numpy(track).to(torch.float32)

    def _load_scalar_track(self, path, pad_value, convert=float):
        """Read one per-step scalar feature file into a video-length tensor."""
        with open(path, encoding='utf-8') as f:
            # Bound by the buffer's own length (max_seq_video) so an index
            # can never fall outside the array being filled.
            return self._parse_scalar_track(f, self.max_seq_video, pad_value, convert)

    def _load_chord_features(self, path):
        """Parse a chord ``.lab`` file.

        Returns:
            ``(chord, root, attr, key, last_time)`` where the first three are
            int64 arrays of length ``max_seq_chord`` pre-filled with their PAD
            ids, ``key`` is the text after a ``key`` line ("" if absent) and
            ``last_time`` is the last time index read (``None`` if the file
            has no chord line; >= ``max_seq_chord`` if reading was cut off).
        """
        length = self.max_seq_chord
        chord_ids = np.full(length, CHORD_PAD, dtype=np.int64)
        root_ids = np.full(length, CHORD_ROOT_PAD, dtype=np.int64)
        attr_ids = np.full(length, CHORD_ATTR_PAD, dtype=np.int64)

        key = ""
        last_time = None
        with open(path, encoding='utf-8') as f:
            for line in f:
                fields = line.strip().split(" ")
                if fields[0] == "key":
                    key = fields[1] + " " + fields[2]
                    continue
                last_time = int(fields[0])
                if last_time >= length:
                    break
                chord = fields[1]
                chord_ids[last_time] = self.chordDic[chord]

                parts = chord.split(":")
                if len(parts) == 1:
                    root_ids[last_time] = self.chordRootDic[parts[0]]
                    if parts[0] == "N":
                        attr_ids[last_time] = self.chordAttrDic["N"]
                    else:
                        # A bare root carries no attribute suffix; it gets
                        # attribute id 1 ("maj" per the 0:N, 1:maj layout).
                        attr_ids[last_time] = 1
                elif len(parts) == 2:
                    root_ids[last_time] = self.chordRootDic[parts[0]]
                    attr_ids[last_time] = self.chordAttrDic[parts[1]]
        return chord_ids, root_ids, attr_ids, key, last_time

    def _load_emotion(self, path):
        """Read the per-step emotion scores into a (max_seq_video, C) tensor.

        Raises:
            ValueError: If a data line does not hold exactly C scores.
        """
        width = 6 if self.emo_model.startswith("6c") else 5
        emotion = np.full((self.max_seq_video, width), EMOTION_PAD, dtype=np.float64)
        with open(path, encoding='utf-8') as f:
            for line in f:
                fields = line.strip().split(" ")
                if fields[0] == "time":  # header line
                    continue
                step = int(fields[0])
                if step >= self.max_seq_video:
                    break
                scores = [float(value) for value in fields[1:]]
                if len(scores) != width:
                    # Fail loudly instead of silently reusing the previous
                    # row's scores or broadcasting a short row.
                    raise ValueError(
                        "expected %d emotion scores at time %d in %s, got %d"
                        % (width, step, path, len(scores)))
                emotion[step] = scores
        return torch.from_numpy(emotion).to(torch.float32)

    def __len__(self):
        """Return the number of samples whose feature files all exist."""
        return len(self.data_files_chord)

    def __getitem__(self, idx):
        """Load every feature of sample *idx*.

        Returns:
            dict with keys ``x``/``tgt`` (chord ids, input and next-step
            target), ``x_root``/``tgt_root``, ``x_attr``/``tgt_attr``,
            ``semanticList`` (list of (max_seq_video, D) tensors, empty when
            ``is_video`` is False), ``key`` (``[0.]`` if the key text contains
            "major", else ``[1.]``), ``scene_offset``, ``motion``,
            ``emotion``, ``tgt_emotion`` (multi-hot rows, first step dropped),
            ``tgt_emotion_prob`` (max emotion score per step, first step
            dropped), ``note_density`` and ``loudness``.
        """
        #### ---- CHORD ----- ####
        chord_ids, root_ids, attr_ids, key, last_time = \
            self._load_chord_features(self.data_files_chord[idx])

        feature_key = torch.tensor([0 if "major" in key else 1]).float()

        feature_chord = torch.from_numpy(chord_ids).to(torch.long)
        feature_chordRoot = torch.from_numpy(root_ids).to(torch.long)
        feature_chordAttr = torch.from_numpy(attr_ids).to(torch.long)

        # x/tgt are *views* of the same storage, shifted by one step.
        x = feature_chord[:self.max_seq_chord-1]
        tgt = feature_chord[1:self.max_seq_chord]
        x_root = feature_chordRoot[:self.max_seq_chord-1]
        tgt_root = feature_chordRoot[1:self.max_seq_chord]
        x_attr = feature_chordAttr[:self.max_seq_chord-1]
        tgt_attr = feature_chordAttr[1:self.max_seq_chord]

        # Mark the step right after the last chord as END. tgt holds only
        # max_seq_chord - 1 entries, hence the stricter bound. Because tgt is
        # a view, this also writes the END id into feature_chord (and x); the
        # emotion mapping below relies on finding it there.
        if last_time is not None and last_time < self.max_seq_chord - 1:
            tgt[last_time] = CHORD_END
            tgt_root[last_time] = CHORD_ROOT_END
            tgt_attr[last_time] = CHORD_ATTR_END

        #### ---- SCENE OFFSET / MOTION / NOTE_DENSITY / LOUDNESS ----- ####
        # Scene ids are stored shifted by +1.
        feature_scene_offset = self._load_scalar_track(
            self.data_files_scene_offset[idx], SCENE_OFFSET_PAD, lambda value: int(value) + 1)
        feature_motion = self._load_scalar_track(self.data_files_motion[idx], MOTION_PAD)
        feature_note_density = self._load_scalar_track(
            self.data_files_note_density[idx], NOTE_DENSITY_PAD)
        feature_loudness = self._load_scalar_track(self.data_files_loudness[idx], LOUDNESS_PAD)

        #### ---- EMOTION ----- ####
        feature_emotion = self._load_emotion(self.data_files_emotion[idx])
        feature_emotion_argmax = torch.argmax(feature_emotion, dim=1)
        max_prob_values, _ = torch.max(feature_emotion, dim=1)

        # -- emotion to chord: pick one lookup-table row per step. PAD/END
        # chord positions override the emotion class (PAD wins over END).
        n_video = self.max_seq_video
        chord_at_step = torch.full((n_video,), CHORD_PAD, dtype=torch.long)
        n_common = min(n_video, self.max_seq_chord)
        chord_at_step[:n_common] = feature_chord[:n_common]

        table_rows = feature_emotion_argmax.clone()
        table_rows[chord_at_step == CHORD_END] = self._END_ROW
        table_rows[chord_at_step == CHORD_PAD] = self._PAD_ROW
        mapped_tensor = self._emotion_table[table_rows]

        # With the default lengths:
        #   feature emotion  : [300, 6]
        #   tgt              : [299]
        #   tgt_emotion      : [299, 159]
        #   tgt_emotion_prob : [299]
        tgt_emotion = mapped_tensor[1:]
        tgt_emotion_prob = max_prob_values[1:]

        #### ---- SEMANTIC ----- ####
        feature_semantic_list = []
        if self.is_video:
            for files in self.data_files_semantic_list:
                video_feature = torch.from_numpy(np.load(files[idx]))
                feature_semantic = torch.full(
                    (self.max_seq_video, video_feature.shape[1]), SEMANTIC_PAD,
                    dtype=torch.float32, device=cpu_device())
                # Copy into the float32 buffer in both the pad and the
                # truncate case so the dtype does not depend on video length.
                n_frames = min(video_feature.shape[0], self.max_seq_video)
                feature_semantic[:n_frames] = video_feature[:n_frames]
                feature_semantic_list.append(feature_semantic)

        return { "x":x,
                "tgt":tgt,
                "x_root":x_root,
                "tgt_root":tgt_root,
                "x_attr":x_attr,
                "tgt_attr":tgt_attr,
                "semanticList": feature_semantic_list,
                "key": feature_key,
                "scene_offset": feature_scene_offset,
                "motion": feature_motion,
                "emotion": feature_emotion,
                "tgt_emotion" : tgt_emotion,
                "tgt_emotion_prob" : tgt_emotion_prob,
                "note_density" : feature_note_density,
                "loudness" : feature_loudness
                }

def create_vevo_datasets(dataset_root = "./dataset", max_seq_chord=300, max_seq_video=300, vis_models="2d/clip_l14p", emo_model="6c_l14p", split_ver="v1", random_seq=True, is_video=True):
    """Build the train, validation and test ``VevoDataset`` instances.

    All three datasets share every argument except ``split``, which is
    ``"train"``, ``"val"`` and ``"test"`` respectively.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset)``.
    """
    shared_kwargs = dict(
        dataset_root=dataset_root,
        split_ver=split_ver,
        vis_models=vis_models,
        emo_model=emo_model,
        max_seq_chord=max_seq_chord,
        max_seq_video=max_seq_video,
        random_seq=random_seq,
        is_video=is_video,
    )

    # Constructed in train -> val -> test order.
    train_dataset, val_dataset, test_dataset = (
        VevoDataset(split=split_name, **shared_kwargs)
        for split_name in ("train", "val", "test")
    )
    return train_dataset, val_dataset, test_dataset

def compute_vevo_accuracy(out, tgt):
    """Return the top-1 chord accuracy over non-padding targets.

    Args:
        out: Scores whose last dimension ranges over chord classes.
        tgt: Target chord ids; positions equal to ``CHORD_PAD`` are ignored.

    Returns:
        A ``TORCH_FLOAT`` tensor with the fraction of correct predictions, or
        the float ``1.0`` when no non-padding target remains.
    """
    predicted = torch.softmax(out, dim=-1).argmax(dim=-1).flatten()
    expected = tgt.flatten()

    # Keep only the positions that carry a real target.
    keep = expected != CHORD_PAD
    predicted, expected = predicted[keep], expected[keep]

    if len(expected) == 0:
        return 1.0

    n_correct = (predicted == expected).sum().type(TORCH_FLOAT)
    return n_correct / len(expected)

def compute_hits_k(out, tgt, k):
    """Return the Hits@k rate over non-padding targets.

    A position counts as a hit when its target id is among the *k* highest
    scoring classes at that position.

    Args:
        out: Scores whose last dimension ranges over chord classes; any
            leading dimensions (batch, time) are flattened.
        tgt: Target chord ids with one entry per position of ``out``;
            positions equal to ``CHORD_PAD`` are ignored.
        k: Number of top classes considered.

    Returns:
        A 0-dim float32 tensor with the hit rate, or the float ``1.0`` when
        there is no non-padding target (previously this divided by zero and
        yielded NaN).
    """
    probs = torch.softmax(out, dim=-1)
    _, topk_indices = torch.topk(probs, k, dim=-1)
    # One row of k candidates per position, whatever the batch size.
    topk_indices = topk_indices.reshape(-1, topk_indices.shape[-1])

    tgt = tgt.flatten().to(topk_indices.device)
    valid = tgt != CHORD_PAD
    n_valid = int(valid.sum().item())
    if n_valid == 0:
        return 1.0

    hits = (topk_indices == tgt.unsqueeze(-1)).any(dim=-1)
    n_hits = int((hits & valid).sum().item())

    return torch.tensor(n_hits, dtype=torch.float32) / n_valid

def compute_hits_k_root_attr(out_root, out_attr, tgt, k):
    """Return Hits@k for a model that predicts chord root and attribute.

    Root and attribute probabilities are multiplied into a score for each of
    the 159 chord ids, then scored like ``compute_hits_k``.

    Chord-id layout used for the combination:
        id 0            -> root 0,  attr 0
        ids 1..156      -> root (id-1)//13 + 1, attr (id-1)%13 + 1
        id 157          -> root 13, attr 14
        id 158          -> root 14, attr 15
    (ids 157/158 are presumably the END and PAD tokens — confirm against
    the constants module.)

    Args:
        out_root: Root scores, last dimension over root classes.
        out_attr: Attribute scores, last dimension over attribute classes,
            with the same leading dimensions as ``out_root``.
        tgt: Target chord ids, one per position; ``CHORD_PAD`` is ignored.
        k: Number of top classes considered.

    Returns:
        A 0-dim float32 tensor with the hit rate, or the float ``1.0`` when
        there is no non-padding target.
    """
    root_probs = torch.softmax(out_root, dim=-1)
    attr_probs = torch.softmax(out_attr, dim=-1)

    n_roots = 12
    n_attrs = 13
    root_of_chord = [0] + [r for r in range(1, n_roots + 1) for _ in range(n_attrs)] + [13, 14]
    attr_of_chord = [0] + list(range(1, n_attrs + 1)) * n_roots + [14, 15]
    root_of_chord = torch.tensor(root_of_chord, dtype=torch.long, device=root_probs.device)
    attr_of_chord = torch.tensor(attr_of_chord, dtype=torch.long, device=attr_probs.device)

    # Gather along the class axis so any batch size / sequence length works
    # (the result is no longer tied to a fixed [1, 299, 159] buffer).
    joint = root_probs[..., root_of_chord] * attr_probs[..., attr_of_chord]

    # Monotonic re-normalisation, kept so scores match the earlier
    # implementation of this metric.
    joint = torch.softmax(joint, dim=-1)
    _, topk_indices = torch.topk(joint, k, dim=-1)
    topk_indices = topk_indices.reshape(-1, topk_indices.shape[-1])

    tgt = tgt.flatten().to(topk_indices.device)
    valid = tgt != CHORD_PAD
    n_valid = int(valid.sum().item())
    if n_valid == 0:
        return 1.0

    hits = (topk_indices == tgt.unsqueeze(-1)).any(dim=-1)
    n_hits = int((hits & valid).sum().item())

    return torch.tensor(n_hits, dtype=torch.float32) / n_valid

def compute_vevo_correspondence(out, tgt, tgt_emotion, tgt_emotion_prob, emotion_threshold):
    """Return how often the predicted chord quality fits the target emotion.

    A position is scored only if its ``tgt_emotion`` row is not flagged as
    padding (last column == 1), allows at least one chord quality (columns
    0..13 not all zero) and its emotion probability reaches
    ``emotion_threshold``. A scored position is correct when the predicted
    chord's quality index is set in that row.

    Args:
        out: Chord scores, last dimension over chord classes.
        tgt: Unused; accepted for signature compatibility.
        tgt_emotion: Multi-hot emotion->chord rows, one per position.
        tgt_emotion_prob: Emotion probability per position.
        emotion_threshold: Minimum probability for a position to be scored.

    Returns:
        The float ``1.0`` if ``tgt_emotion`` is empty, the int ``-1`` if no
        position qualifies for scoring, otherwise a 0-dim float32 tensor with
        the fraction of scored positions that are correct.
    """
    if tgt_emotion.numel() == 0:
        return 1.0

    # Flatten explicitly; a bare squeeze() would also drop a length-1
    # time axis.
    tgt_emotion = tgt_emotion.reshape(-1, tgt_emotion.shape[-1])
    tgt_emotion_prob = tgt_emotion_prob.reshape(-1)

    # Only the two vocabularies actually consulted below are loaded.
    dataset_root = "./dataset/"
    with open(os.path.join( dataset_root, "vevo_meta/chord_attr.json")) as json_file:
        chordAttrDic = json.load(json_file)
    with open(os.path.join( dataset_root, "vevo_meta/chord_inv.json")) as json_file:
        chordInvDic = json.load(json_file)

    predicted = torch.argmax(torch.softmax(out, dim=-1), dim=-1).flatten().tolist()

    tgt_emotion_quality = tgt_emotion[:, 0:14]
    quality_rows = tgt_emotion_quality.tolist()
    is_pad = (tgt_emotion[:, -1] == 1).tolist()
    no_quality = (tgt_emotion_quality == 0).all(dim=1).tolist()
    low_confidence = (tgt_emotion_prob < emotion_threshold).tolist()

    num_right = 0
    pt = 0
    for i, chord_id in enumerate(predicted):
        if is_pad[i] or no_quality[i] or low_confidence[i]:
            continue
        pt += 1
        if chord_id == CHORD_END or chord_id == CHORD_PAD:
            continue

        chord_arr = chordInvDic[str(chord_id)].split(":")
        if len(chord_arr) == 1:
            # NOTE(review): a generated "N" (no-chord) symbol also lands here
            # and is scored as quality 1 (maj); confirm this is intended.
            out_quality = 1
        elif len(chord_arr) == 2:
            out_quality = chordAttrDic[chord_arr[1]]  # 0:N, 1:maj ... 13:maj7
        else:
            # Unexpected chord spelling: count it as scored but not correct
            # rather than reusing the previous position's quality.
            continue

        if quality_rows[i][out_quality] == 1:
            num_right += 1

    if pt == 0:
        return -1

    return torch.tensor(num_right, dtype=torch.float32) / pt

def compute_vevo_correspondence_root_attr(y_root, y_attr, tgt, tgt_emotion, tgt_emotion_prob, emotion_threshold):
    """Emotion/chord-quality correspondence for root + attribute predictions.

    The arg-max root and attribute at each position are first combined into
    a chord id, then scored exactly like ``compute_vevo_correspondence``: a
    position counts only if its ``tgt_emotion`` row is not flagged as padding
    (last column == 1), allows at least one quality (columns 0..13 not all
    zero) and its probability reaches ``emotion_threshold``.

    Args:
        y_root: Root scores, last dimension over root classes.
        y_attr: Attribute scores, last dimension over attribute classes.
        tgt: Target chord ids; only its number of elements is used.
        tgt_emotion: Multi-hot emotion->chord rows, one per position.
        tgt_emotion_prob: Emotion probability per position.
        emotion_threshold: Minimum probability for a position to be scored.

    Returns:
        The float ``1.0`` if ``tgt_emotion`` is empty, the int ``-1`` if no
        position qualifies for scoring, otherwise a 0-dim float32 tensor with
        the fraction of scored positions that are correct.
    """
    if tgt_emotion.numel() == 0:
        return 1.0

    # Flatten explicitly; a bare squeeze() would also drop a length-1
    # time axis.
    tgt_emotion = tgt_emotion.reshape(-1, tgt_emotion.shape[-1])
    tgt_emotion_prob = tgt_emotion_prob.reshape(-1)

    dataset_root = "./dataset/"

    def load_vocab(relative_path):
        with open(os.path.join( dataset_root, relative_path)) as json_file:
            return json.load(json_file)

    chordRootInvDic = load_vocab("vevo_meta/chord_root_inv.json")
    chordAttrDic = load_vocab("vevo_meta/chord_attr.json")
    chordAttrInvDic = load_vocab("vevo_meta/chord_attr_inv.json")
    chordDic = load_vocab("vevo_meta/chord.json")
    chordInvDic = load_vocab("vevo_meta/chord_inv.json")

    root_ids = torch.argmax(torch.softmax(y_root, dim=-1), dim=-1).flatten().tolist()
    attr_ids = torch.argmax(torch.softmax(y_attr, dim=-1), dim=-1).flatten().tolist()

    # Combine (root, attr) into chord ids. Plain Python ints are enough here:
    # the values are only read back one by one below, so no tensor/device
    # round trip is needed.
    predicted = []
    for i in range(tgt.numel()):
        root_id, attr_id = root_ids[i], attr_ids[i]
        if root_id == CHORD_ROOT_PAD or attr_id == CHORD_ATTR_PAD:
            predicted.append(CHORD_PAD)
        elif root_id == CHORD_ROOT_END or attr_id == CHORD_ATTR_END:
            predicted.append(CHORD_END)
        else:
            chordRoot = chordRootInvDic[str(root_id)]
            chordAttr = chordAttrInvDic[str(attr_id)]
            if chordRoot == "N":
                predicted.append(0)
            elif chordAttr == "N" or chordAttr == "maj":
                # Major chords are spelled by their bare root.
                predicted.append(chordDic[chordRoot])
            else:
                predicted.append(chordDic[chordRoot + ":" + chordAttr])

    tgt_emotion_quality = tgt_emotion[:, 0:14]
    quality_rows = tgt_emotion_quality.tolist()
    is_pad = (tgt_emotion[:, -1] == 1).tolist()
    no_quality = (tgt_emotion_quality == 0).all(dim=1).tolist()
    low_confidence = (tgt_emotion_prob < emotion_threshold).tolist()

    num_right = 0
    pt = 0
    for i, chord_id in enumerate(predicted):
        if is_pad[i] or no_quality[i] or low_confidence[i]:
            continue
        pt += 1
        if chord_id == CHORD_END or chord_id == CHORD_PAD:
            continue

        chord_arr = chordInvDic[str(chord_id)].split(":")
        if len(chord_arr) == 1:
            # NOTE(review): a generated "N" (no-chord) symbol also lands here
            # and is scored as quality 1 (maj); confirm this is intended.
            y_quality = 1
        elif len(chord_arr) == 2:
            y_quality = chordAttrDic[chord_arr[1]]  # 0:N, 1:maj ... 13:maj7
        else:
            # Unexpected chord spelling: count it as scored but not correct
            # rather than reusing the previous position's quality.
            continue

        if quality_rows[i][y_quality] == 1:
            num_right += 1

    if pt == 0:
        return -1

    return torch.tensor(num_right, dtype=torch.float32) / pt

def compute_vevo_accuracy_root_attr(y_root, y_attr, tgt):
    """Return top-1 chord accuracy for root + attribute predictions.

    The arg-max root and attribute at each position are combined into a
    chord id and compared with ``tgt`` on all non-padding positions.

    Args:
        y_root: Root scores, last dimension over root classes.
        y_attr: Attribute scores, last dimension over attribute classes.
        tgt: Target chord ids; positions equal to ``CHORD_PAD`` are ignored.

    Returns:
        A ``TORCH_FLOAT`` tensor with the fraction of correct predictions, or
        the float ``1.0`` when no non-padding target remains.
    """
    dataset_root = "./dataset/"

    def load_vocab(relative_path):
        with open(os.path.join( dataset_root, relative_path)) as json_file:
            return json.load(json_file)

    chordRootInvDic = load_vocab("vevo_meta/chord_root_inv.json")
    chordAttrInvDic = load_vocab("vevo_meta/chord_attr_inv.json")
    chordDic = load_vocab("vevo_meta/chord.json")

    root_ids = torch.argmax(torch.softmax(y_root, dim=-1), dim=-1).flatten().tolist()
    attr_ids = torch.argmax(torch.softmax(y_attr, dim=-1), dim=-1).flatten().tolist()

    tgt = tgt.flatten()
    mask = (tgt != CHORD_PAD)

    predicted = []
    for i in range(len(tgt)):
        root_id, attr_id = root_ids[i], attr_ids[i]
        if root_id == CHORD_ROOT_PAD or attr_id == CHORD_ATTR_PAD:
            predicted.append(CHORD_PAD)
        elif root_id == CHORD_ROOT_END or attr_id == CHORD_ATTR_END:
            predicted.append(CHORD_END)
        else:
            chordRoot = chordRootInvDic[str(root_id)]
            chordAttr = chordAttrInvDic[str(attr_id)]
            if chordRoot == "N":
                predicted.append(0)
            elif chordAttr == "N" or chordAttr == "maj":
                # Major chords are spelled by their bare root.
                predicted.append(chordDic[chordRoot])
            else:
                predicted.append(chordDic[chordRoot + ":" + chordAttr])

    # Place predictions on the targets' device: they are masked with and
    # compared against tgt, so that is the only device that is always right.
    y = torch.tensor(predicted, dtype=torch.long, device=tgt.device)

    y = y[mask]
    tgt = tgt[mask]

    # Empty
    if len(tgt) == 0:
        return 1.0

    num_right = torch.sum(y == tgt).type(TORCH_FLOAT)
    return num_right / len(tgt)