import os
import sys
import json
import pickle
import random
from glob import glob

import numpy as np
import torch
import torch.utils.data as data

sys.path.append(os.getcwd())

from data_utils.utils import *
from data_utils.consts import speaker_id
from data_utils.lower_body import count_part
from data_utils.rotation_conversion import axis_angle_to_matrix, matrix_to_rotation_6d

# Per-hand component bases used by to3d() to expand the 12-dim hand PCA coefficients.
with open('data_utils/hand_component.json') as file_obj:
    comp = json.load(file_obj)
    left_hand_c = np.asarray(comp['left'])
    right_hand_c = np.asarray(comp['right'])


def to3d(data):
    # Expand the 12-dim PCA hand coefficients (columns 75:87 and 87:99) into
    # full axis-angle hand poses and re-assemble the per-frame pose vector.
    left_hand_pose = np.einsum('bi,ij->bj', data[:, 75:87], left_hand_c[:12, :])
    right_hand_pose = np.einsum('bi,ij->bj', data[:, 87:99], right_hand_c[:12, :])
    data = np.concatenate((data[:, :75], left_hand_pose, right_hand_pose), axis=-1)
    return data
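
# ---------------------------------------------------------------------------
# Per-frame pose layout, as assembled in _load_npz_all()/_load_them_all() below:
#   raw vector (99 dims)   : jaw(3) + leye(3) + reye(3) + global_orient(3)
#                            + body_pose_axis(63) + left_hand_PCA(12) + right_hand_PCA(12)
#   after to3d() (165 dims): 55 joints x 3 axis-angle values (hands expanded from PCA)
#   with convert_to_6d     : 55 joints x 6 = 330 values (matrix_to_rotation_6d)
#   expression (optional)  : coefficients appended after the pose dims when self.expression is set
# ---------------------------------------------------------------------------
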
class SmplxDataset:
    '''
    Create a dataset for every segment and concatenate them.
    '''

    def __init__(self,
                 data_root,
                 speaker,
                 motion_fn,
                 audio_fn,
                 audio_sr,
                 fps,
                 feat_method='mel_spec',
                 audio_feat_dim=64,
                 audio_feat_win_size=None,
                 train=True,
                 load_all=False,
                 split_trans_zero=False,
                 limbscaling=False,
                 num_frames=25,
                 num_pre_frames=25,
                 num_generate_length=25,
                 context_info=False,
                 convert_to_6d=False,
                 expression=False,
                 config=None,
                 am=None,
                 am_sr=None,
                 whole_video=False
                 ):
        self.data_root = data_root
        self.speaker = speaker
        self.feat_method = feat_method
        self.audio_fn = audio_fn
        self.audio_sr = audio_sr
        self.fps = fps
        self.audio_feat_dim = audio_feat_dim
        self.audio_feat_win_size = audio_feat_win_size
        self.context_info = context_info  # for aud feat
        self.convert_to_6d = convert_to_6d
        self.expression = expression

        self.train = train
        self.load_all = load_all
        self.split_trans_zero = split_trans_zero
        self.limbscaling = limbscaling
        self.num_frames = num_frames
        self.num_pre_frames = num_pre_frames
        self.num_generate_length = num_generate_length
        # print('num_generate_length ', self.num_generate_length)

        self.config = config
        self.am_sr = am_sr
        self.whole_video = whole_video

        load_mode = self.config.dataset_load_mode

        if load_mode == 'pickle':
            raise NotImplementedError

        elif load_mode == 'csv':
            # data_root points at a single pickled sequence
            with open(data_root, 'rb') as f:
                u = pickle._Unpickler(f)
                data = u.load()
                self.data = data[0]
            if self.load_all:
                self._load_npz_all()

        elif load_mode == 'json':
            # data_root is a directory of per-segment .pkl annotations
            self.annotations = glob(data_root + '/*pkl')
            if len(self.annotations) == 0:
                raise FileNotFoundError(data_root + ' is empty')
            self.annotations = sorted(self.annotations)
            self.img_name_list = self.annotations
            if self.load_all:
                self._load_them_all(am, am_sr, motion_fn)

    def _load_npz_all(self):
        self.loaded_data = {}
        self.complete_data = []
        data = self.data
        shape = data['body_pose_axis'].shape[0]
        self.betas = data['betas']
        self.img_name_list = []
        for index in range(shape):
            img_name = f'{index:6d}'
            self.img_name_list.append(img_name)

            jaw_pose = data['jaw_pose'][index]
            leye_pose = data['leye_pose'][index]
            reye_pose = data['reye_pose'][index]
            global_orient = data['global_orient'][index]
            body_pose = data['body_pose_axis'][index]
            left_hand_pose = data['left_hand_pose'][index]
            right_hand_pose = data['right_hand_pose'][index]
            full_body = np.concatenate(
                (jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose))
            assert full_body.shape[0] == 99

            if self.convert_to_6d:
                full_body = to3d(full_body)
                full_body = torch.from_numpy(full_body)
                full_body = matrix_to_rotation_6d(axis_angle_to_matrix(full_body))
                full_body = np.asarray(full_body)
                if self.expression:
                    expression = data['expression'][index]
                    full_body = np.concatenate((full_body, expression))
                    # full_body = np.concatenate((full_body, non_zero))
            else:
                full_body = to3d(full_body)
                if self.expression:
                    expression = data['expression'][index]
                    full_body = np.concatenate((full_body, expression))

            self.loaded_data[img_name] = full_body.reshape(-1)
            self.complete_data.append(full_body.reshape(-1))

        self.complete_data = np.array(self.complete_data)

        if self.audio_feat_win_size is not None:
            self.audio_feat = get_mfcc_old(self.audio_fn).transpose(1, 0)
            # print(self.audio_feat.shape)
        else:
            if self.feat_method == 'mel_spec':
                self.audio_feat = get_melspec(self.audio_fn, fps=self.fps, sr=self.audio_sr,
                                              n_mels=self.audio_feat_dim)
            elif self.feat_method == 'mfcc':
                self.audio_feat = get_mfcc(self.audio_fn,
                                           smlpx=True,
                                           sr=self.audio_sr,
                                           n_mfcc=self.audio_feat_dim,
                                           win_size=self.audio_feat_win_size
                                           )

    def _load_them_all(self, am, am_sr, motion_fn):
        self.loaded_data = {}
        self.complete_data = []
        with open(motion_fn, 'rb') as f:
            data = pickle.load(f)

        self.betas = np.array(data['betas'])

        jaw_pose = np.array(data['jaw_pose'])
        leye_pose = np.array(data['leye_pose'])
        reye_pose = np.array(data['reye_pose'])
        global_orient = np.array(data['global_orient']).squeeze()
        body_pose = np.array(data['body_pose_axis'])
        left_hand_pose = np.array(data['left_hand_pose'])
        right_hand_pose = np.array(data['right_hand_pose'])
        full_body = np.concatenate(
            (jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose), axis=1)
        assert full_body.shape[1] == 99

        if self.convert_to_6d:
            full_body = to3d(full_body)
            full_body = torch.from_numpy(full_body)
            full_body = matrix_to_rotation_6d(axis_angle_to_matrix(full_body.reshape(-1, 55, 3))).reshape(-1, 330)
            full_body = np.asarray(full_body)
            if self.expression:
                expression = np.array(data['expression'])
                full_body = np.concatenate((full_body, expression), axis=1)
        else:
            full_body = to3d(full_body)
            expression = np.array(data['expression'])
            full_body = np.concatenate((full_body, expression), axis=1)

        self.complete_data = full_body
        self.complete_data = np.array(self.complete_data)

        if self.audio_feat_win_size is not None:
            self.audio_feat = get_mfcc_old(self.audio_fn).transpose(1, 0)
        else:
            # if self.feat_method == 'mel_spec':
            #     self.audio_feat = get_melspec(self.audio_fn, fps=self.fps, sr=self.audio_sr, n_mels=self.audio_feat_dim)
            # elif self.feat_method == 'mfcc':
            self.audio_feat = get_mfcc_ta(self.audio_fn,
                                          smlpx=True,
                                          fps=30,
                                          sr=self.audio_sr,
                                          n_mfcc=self.audio_feat_dim,
                                          win_size=self.audio_feat_win_size,
                                          type=self.feat_method,
                                          am=am,
                                          am_sr=am_sr,
                                          encoder_choice=self.config.Model.encoder_choice,
                                          )
            # with open(audio_file, 'w', encoding='utf-8') as file:
            #     file.write(json.dumps(self.audio_feat.__array__().tolist(), indent=0, ensure_ascii=False))
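
    # -----------------------------------------------------------------------
    # get_dataset() wraps this segment in a torch Dataset. For the 'train'
    # split, __getitem__ cuts windows of num_generate_length + num_pre_frames
    # frames, with window starts every 6 frames plus a small random jitter
    # (random.randrange(0, 5, 3) yields 0 or 3). For 'val'/'pre'/'test', or
    # when whole_video is set, the whole segment is returned as a single item.
    # -----------------------------------------------------------------------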
    def get_dataset(self, normalization=False, normalize_stats=None, split='train'):

        class __Worker__(data.Dataset):
            def __init__(child, index_list, normalization, normalize_stats, split='train') -> None:
                super().__init__()
                child.index_list = index_list
                child.normalization = normalization
                child.normalize_stats = normalize_stats
                child.split = split

            def __getitem__(child, index):
                num_generate_length = self.num_generate_length
                num_pre_frames = self.num_pre_frames
                seq_len = num_generate_length + num_pre_frames
                # print(num_generate_length)

                index = child.index_list[index]
                index_new = index + random.randrange(0, 5, 3)
                if index_new + seq_len > self.complete_data.shape[0]:
                    index_new = index
                index = index_new

                if child.split in ['val', 'pre', 'test'] or self.whole_video:
                    index = 0
                    seq_len = self.complete_data.shape[0]
                seq_data = []
                assert index + seq_len <= self.complete_data.shape[0]
                # print(seq_len)
                seq_data = self.complete_data[index:(index + seq_len), :]
                seq_data = np.array(seq_data)

                '''
                audio feature
                '''
                if not self.context_info:
                    if not self.whole_video:
                        audio_feat = self.audio_feat[index:index + seq_len, ...]
                        if audio_feat.shape[0] < seq_len:
                            audio_feat = np.pad(audio_feat, [[0, seq_len - audio_feat.shape[0]], [0, 0]],
                                                mode='reflect')
                        assert audio_feat.shape[0] == seq_len and audio_feat.shape[1] == self.audio_feat_dim
                    else:
                        audio_feat = self.audio_feat
                else:  # including feature and history
                    if self.audio_feat_win_size is None:
                        audio_feat = self.audio_feat[index:index + seq_len + num_pre_frames, ...]
                        if audio_feat.shape[0] < seq_len + num_pre_frames:
                            audio_feat = np.pad(audio_feat,
                                                [[0, seq_len + self.num_frames - audio_feat.shape[0]], [0, 0]],
                                                mode='constant')
                        assert audio_feat.shape[0] == self.num_frames + seq_len and audio_feat.shape[1] == self.audio_feat_dim

                if child.normalization:
                    data_mean = child.normalize_stats['mean'].reshape(1, -1)
                    data_std = child.normalize_stats['std'].reshape(1, -1)
                    seq_data[:, :330] = (seq_data[:, :330] - data_mean) / data_std

                if child.split in ['train', 'test']:
                    if self.convert_to_6d:
                        if self.expression:
                            data_sample = {
                                'poses': seq_data[:, :330].astype(np.float64).transpose(1, 0),
                                'expression': seq_data[:, 330:].astype(np.float64).transpose(1, 0),
                                # 'nzero': seq_data[:, 375:].astype(np.float64).transpose(1, 0),
                                'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
                                'speaker': speaker_id[self.speaker],
                                'betas': self.betas,
                                'aud_file': self.audio_fn,
                            }
                        else:
                            data_sample = {
                                'poses': seq_data[:, :330].astype(np.float64).transpose(1, 0),
                                'nzero': seq_data[:, 330:].astype(np.float64).transpose(1, 0),
                                'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
                                'speaker': speaker_id[self.speaker],
                                'betas': self.betas
                            }
                    else:
                        if self.expression:
                            data_sample = {
                                'poses': seq_data[:, :165].astype(np.float64).transpose(1, 0),
                                'expression': seq_data[:, 165:].astype(np.float64).transpose(1, 0),
                                'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
                                # 'wv2_feat': wv2_feat.astype(np.float64).transpose(1, 0),
                                'speaker': speaker_id[self.speaker],
                                'aud_file': self.audio_fn,
                                'betas': self.betas
                            }
                        else:
                            data_sample = {
                                'poses': seq_data.astype(np.float64).transpose(1, 0),
                                'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
                                'speaker': speaker_id[self.speaker],
                                'betas': self.betas
                            }
                    return data_sample
                else:
                    data_sample = {
                        'poses': seq_data[:, :330].astype(np.float64).transpose(1, 0),
                        'expression': seq_data[:, 330:].astype(np.float64).transpose(1, 0),
                        # 'nzero': seq_data[:, 325:].astype(np.float64).transpose(1, 0),
                        'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
                        'aud_file': self.audio_fn,
                        'speaker': speaker_id[self.speaker],
                        'betas': self.betas
                    }
                    return data_sample

            def __len__(child):
                return len(child.index_list)

        if split == 'train':
            index_list = list(
                range(0,
                      min(self.complete_data.shape[0], self.audio_feat.shape[0]) - self.num_generate_length - self.num_pre_frames,
                      6))
        elif split in ['val', 'test']:
            index_list = list([0])
        if self.whole_video:
            index_list = list([0])

        self.all_dataset = __Worker__(index_list, normalization, normalize_stats, split)

    def __len__(self):
        return len(self.img_name_list)
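

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The paths, speaker
# name, and the stand-in config below are hypothetical placeholders; a real
# run needs the project's config object (providing dataset_load_mode and
# Model.encoder_choice) plus the preprocessed motion pickle and audio file.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        dataset_load_mode='json',                      # 'pickle' | 'csv' | 'json'
        Model=SimpleNamespace(encoder_choice='mfcc'),  # consumed by get_mfcc_ta when load_all=True
    )

    dataset = SmplxDataset(
        data_root='datasets/speaker/seq_0001',             # hypothetical folder of *.pkl annotations
        speaker='oliver',                                  # must be a key of data_utils.consts.speaker_id
        motion_fn='datasets/speaker/seq_0001/motion.pkl',  # hypothetical SMPL-X parameter pickle
        audio_fn='datasets/speaker/seq_0001/audio.wav',    # hypothetical matching audio
        audio_sr=16000,
        fps=30,
        feat_method='mfcc',
        audio_feat_dim=64,
        load_all=True,
        convert_to_6d=False,
        expression=True,
        config=cfg,
    )

    # Build the per-segment torch Dataset and fetch one training window.
    dataset.get_dataset(normalization=False, normalize_stats=None, split='train')
    sample = dataset.all_dataset[0]
    print(sample['poses'].shape, sample['aud_feat'].shape)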