# data_utils/mesh_dataset.py
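"""Per-segment SMPL-X dataset utilities used by TalkSHOW.

Each SmplxDataset wraps one recorded segment (SMPL-X pose parameters plus the
matching audio features) and exposes a torch Dataset through get_dataset().
"""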
import pickle
import sys
import os
sys.path.append(os.getcwd())
import json
from glob import glob
from data_utils.utils import *
import torch.utils.data as data
from data_utils.consts import speaker_id
from data_utils.lower_body import count_part
import random
from data_utils.rotation_conversion import axis_angle_to_matrix, matrix_to_rotation_6d
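# Precomputed PCA components for the left/right hand pose spaces; to3d() uses them
# to expand the 12-dim hand codes stored in the pickles back to full axis-angle poses.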
with open('data_utils/hand_component.json') as file_obj:
comp = json.load(file_obj)
left_hand_c = np.asarray(comp['left'])
right_hand_c = np.asarray(comp['right'])
def to3d(data):
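    '''
    Expand the 12-dim PCA hand codes (columns 75:87 and 87:99) to full 45-dim
    axis-angle hand poses, turning a (T, 99) array into a (T, 165) array.
    '''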
left_hand_pose = np.einsum('bi,ij->bj', data[:, 75:87], left_hand_c[:12, :])
right_hand_pose = np.einsum('bi,ij->bj', data[:, 87:99], right_hand_c[:12, :])
data = np.concatenate((data[:, :75], left_hand_pose, right_hand_pose), axis=-1)
return data
class SmplxDataset():
'''
    Create a dataset for every segment and concatenate them.
'''
def __init__(self,
data_root,
speaker,
motion_fn,
audio_fn,
audio_sr,
fps,
feat_method='mel_spec',
audio_feat_dim=64,
audio_feat_win_size=None,
train=True,
load_all=False,
split_trans_zero=False,
limbscaling=False,
num_frames=25,
num_pre_frames=25,
num_generate_length=25,
context_info=False,
convert_to_6d=False,
expression=False,
config=None,
am=None,
am_sr=None,
whole_video=False
):
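        # One SmplxDataset wraps a single recorded segment: `motion_fn` holds the
        # SMPL-X parameters, `audio_fn` the matching audio track, and
        # config.dataset_load_mode selects how the annotations are read.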
self.data_root = data_root
self.speaker = speaker
self.feat_method = feat_method
self.audio_fn = audio_fn
self.audio_sr = audio_sr
self.fps = fps
self.audio_feat_dim = audio_feat_dim
self.audio_feat_win_size = audio_feat_win_size
self.context_info = context_info # for aud feat
self.convert_to_6d = convert_to_6d
self.expression = expression
self.train = train
self.load_all = load_all
self.split_trans_zero = split_trans_zero
self.limbscaling = limbscaling
self.num_frames = num_frames
self.num_pre_frames = num_pre_frames
self.num_generate_length = num_generate_length
# print('num_generate_length ', self.num_generate_length)
self.config = config
self.am_sr = am_sr
self.whole_video = whole_video
load_mode = self.config.dataset_load_mode
if load_mode == 'pickle':
raise NotImplementedError
elif load_mode == 'csv':
with open(data_root, 'rb') as f:
u = pickle._Unpickler(f)
data = u.load()
self.data = data[0]
if self.load_all:
self._load_npz_all()
elif load_mode == 'json':
self.annotations = glob(data_root + '/*pkl')
if len(self.annotations) == 0:
                raise FileNotFoundError(data_root + ' is empty')
self.annotations = sorted(self.annotations)
self.img_name_list = self.annotations
if self.load_all:
self._load_them_all(am, am_sr, motion_fn)
def _load_npz_all(self):
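        # Loader for the 'csv' pickle format: build one pose vector per frame from
        # the per-frame SMPL-X parameters and extract audio features for the clip.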
self.loaded_data = {}
self.complete_data = []
data = self.data
shape = data['body_pose_axis'].shape[0]
self.betas = data['betas']
self.img_name_list = []
for index in range(shape):
img_name = f'{index:6d}'
self.img_name_list.append(img_name)
jaw_pose = data['jaw_pose'][index]
leye_pose = data['leye_pose'][index]
reye_pose = data['reye_pose'][index]
global_orient = data['global_orient'][index]
body_pose = data['body_pose_axis'][index]
left_hand_pose = data['left_hand_pose'][index]
right_hand_pose = data['right_hand_pose'][index]
full_body = np.concatenate(
(jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose))
assert full_body.shape[0] == 99
if self.convert_to_6d:
full_body = to3d(full_body)
full_body = torch.from_numpy(full_body)
full_body = matrix_to_rotation_6d(axis_angle_to_matrix(full_body))
full_body = np.asarray(full_body)
if self.expression:
expression = data['expression'][index]
full_body = np.concatenate((full_body, expression))
# full_body = np.concatenate((full_body, non_zero))
else:
full_body = to3d(full_body)
if self.expression:
expression = data['expression'][index]
full_body = np.concatenate((full_body, expression))
self.loaded_data[img_name] = full_body.reshape(-1)
self.complete_data.append(full_body.reshape(-1))
self.complete_data = np.array(self.complete_data)
if self.audio_feat_win_size is not None:
self.audio_feat = get_mfcc_old(self.audio_fn).transpose(1, 0)
# print(self.audio_feat.shape)
else:
if self.feat_method == 'mel_spec':
self.audio_feat = get_melspec(self.audio_fn, fps=self.fps, sr=self.audio_sr, n_mels=self.audio_feat_dim)
elif self.feat_method == 'mfcc':
self.audio_feat = get_mfcc(self.audio_fn,
smlpx=True,
sr=self.audio_sr,
n_mfcc=self.audio_feat_dim,
win_size=self.audio_feat_win_size
)
def _load_them_all(self, am, am_sr, motion_fn):
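        # Loader for the 'json'/pkl format: read the whole segment at once, stack
        # every frame into self.complete_data and compute per-frame audio features.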
self.loaded_data = {}
self.complete_data = []
        with open(motion_fn, 'rb') as f:
            data = pickle.load(f)
self.betas = np.array(data['betas'])
jaw_pose = np.array(data['jaw_pose'])
leye_pose = np.array(data['leye_pose'])
reye_pose = np.array(data['reye_pose'])
global_orient = np.array(data['global_orient']).squeeze()
body_pose = np.array(data['body_pose_axis'])
left_hand_pose = np.array(data['left_hand_pose'])
right_hand_pose = np.array(data['right_hand_pose'])
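        # Per frame: 3 (jaw) + 3 (leye) + 3 (reye) + 3 (global_orient) + 63 (body)
        # + 12 + 12 (hand PCA codes) = 99 parameters.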
full_body = np.concatenate(
(jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose), axis=1)
assert full_body.shape[1] == 99
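        # Optionally expand the hand PCA codes to full axis-angle (165 dims) and
        # convert every joint to a 6D rotation representation: 55 joints * 6 = 330.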
if self.convert_to_6d:
full_body = to3d(full_body)
full_body = torch.from_numpy(full_body)
full_body = matrix_to_rotation_6d(axis_angle_to_matrix(full_body.reshape(-1, 55, 3))).reshape(-1, 330)
full_body = np.asarray(full_body)
if self.expression:
expression = np.array(data['expression'])
full_body = np.concatenate((full_body, expression), axis=1)
else:
full_body = to3d(full_body)
expression = np.array(data['expression'])
full_body = np.concatenate((full_body, expression), axis=1)
self.complete_data = full_body
self.complete_data = np.array(self.complete_data)
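        # Audio features are computed for the whole clip, one feature vector per
        # video frame, so they can be sliced with the same indices as the poses.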
if self.audio_feat_win_size is not None:
self.audio_feat = get_mfcc_old(self.audio_fn).transpose(1, 0)
else:
# if self.feat_method == 'mel_spec':
# self.audio_feat = get_melspec(self.audio_fn, fps=self.fps, sr=self.audio_sr, n_mels=self.audio_feat_dim)
# elif self.feat_method == 'mfcc':
self.audio_feat = get_mfcc_ta(self.audio_fn,
smlpx=True,
fps=30,
sr=self.audio_sr,
n_mfcc=self.audio_feat_dim,
win_size=self.audio_feat_win_size,
type=self.feat_method,
am=am,
am_sr=am_sr,
encoder_choice=self.config.Model.encoder_choice,
)
# with open(audio_file, 'w', encoding='utf-8') as file:
# file.write(json.dumps(self.audio_feat.__array__().tolist(), indent=0, ensure_ascii=False))
def get_dataset(self, normalization=False, normalize_stats=None, split='train'):
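        '''
        Build a torch Dataset over this segment and store it in self.all_dataset.
        The inner __Worker__ class names its instance `child` so the enclosing
        SmplxDataset remains reachable as `self` through the closure.
        '''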
class __Worker__(data.Dataset):
def __init__(child, index_list, normalization, normalize_stats, split='train') -> None:
super().__init__()
child.index_list = index_list
child.normalization = normalization
child.normalize_stats = normalize_stats
child.split = split
def __getitem__(child, index):
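                # Sample a window of num_generate_length + num_pre_frames frames
                # (or the whole sequence for val/test/whole_video) and return the
                # poses, expression and audio features transposed to (C, T).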
num_generate_length = self.num_generate_length
num_pre_frames = self.num_pre_frames
seq_len = num_generate_length + num_pre_frames
# print(num_generate_length)
index = child.index_list[index]
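                # Jitter the window start by 0 or 3 frames as light augmentation;
                # fall back to the original index if the window would run past the end.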
index_new = index + random.randrange(0, 5, 3)
if index_new + seq_len > self.complete_data.shape[0]:
index_new = index
index = index_new
if child.split in ['val', 'pre', 'test'] or self.whole_video:
index = 0
seq_len = self.complete_data.shape[0]
seq_data = []
assert index + seq_len <= self.complete_data.shape[0]
# print(seq_len)
seq_data = self.complete_data[index:(index + seq_len), :]
seq_data = np.array(seq_data)
'''
audio feature,
'''
if not self.context_info:
if not self.whole_video:
audio_feat = self.audio_feat[index:index + seq_len, ...]
if audio_feat.shape[0] < seq_len:
audio_feat = np.pad(audio_feat, [[0, seq_len - audio_feat.shape[0]], [0, 0]],
mode='reflect')
assert audio_feat.shape[0] == seq_len and audio_feat.shape[1] == self.audio_feat_dim
else:
audio_feat = self.audio_feat
else: # including feature and history
if self.audio_feat_win_size is None:
audio_feat = self.audio_feat[index:index + seq_len + num_pre_frames, ...]
if audio_feat.shape[0] < seq_len + num_pre_frames:
audio_feat = np.pad(audio_feat,
[[0, seq_len + self.num_frames - audio_feat.shape[0]], [0, 0]],
mode='constant')
assert audio_feat.shape[0] == self.num_frames + seq_len and audio_feat.shape[
1] == self.audio_feat_dim
if child.normalization:
data_mean = child.normalize_stats['mean'].reshape(1, -1)
data_std = child.normalize_stats['std'].reshape(1, -1)
seq_data[:, :330] = (seq_data[:, :330] - data_mean) / data_std
                if child.split in ['train', 'test']:
if self.convert_to_6d:
if self.expression:
data_sample = {
                                'poses': seq_data[:, :330].astype(np.float64).transpose(1, 0),
                                'expression': seq_data[:, 330:].astype(np.float64).transpose(1, 0),
                                # 'nzero': seq_data[:, 375:].astype(np.float).transpose(1, 0),
                                'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
'speaker': speaker_id[self.speaker],
'betas': self.betas,
'aud_file': self.audio_fn,
}
else:
data_sample = {
                                'poses': seq_data[:, :330].astype(np.float64).transpose(1, 0),
                                'nzero': seq_data[:, 330:].astype(np.float64).transpose(1, 0),
                                'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
'speaker': speaker_id[self.speaker],
'betas': self.betas
}
else:
if self.expression:
data_sample = {
                                'poses': seq_data[:, :165].astype(np.float64).transpose(1, 0),
                                'expression': seq_data[:, 165:].astype(np.float64).transpose(1, 0),
                                'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
                                # 'wv2_feat': wv2_feat.astype(np.float).transpose(1, 0),
'speaker': speaker_id[self.speaker],
'aud_file': self.audio_fn,
'betas': self.betas
}
else:
data_sample = {
                                'poses': seq_data.astype(np.float64).transpose(1, 0),
                                'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
'speaker': speaker_id[self.speaker],
'betas': self.betas
}
return data_sample
else:
data_sample = {
                        'poses': seq_data[:, :330].astype(np.float64).transpose(1, 0),
                        'expression': seq_data[:, 330:].astype(np.float64).transpose(1, 0),
                        # 'nzero': seq_data[:, 325:].astype(np.float).transpose(1, 0),
                        'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
'aud_file': self.audio_fn,
'speaker': speaker_id[self.speaker],
'betas': self.betas
}
return data_sample
def __len__(child):
return len(child.index_list)
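        # Training windows start every 6 frames; val/test (and whole_video) always
        # use a single window starting at frame 0 that covers the full sequence.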
if split == 'train':
index_list = list(
range(0, min(self.complete_data.shape[0], self.audio_feat.shape[0]) - self.num_generate_length - self.num_pre_frames,
6))
        elif split in ['val', 'test']:
            index_list = [0]
        if self.whole_video:
            index_list = [0]
self.all_dataset = __Worker__(index_list, normalization, normalize_stats, split)
def __len__(self):
return len(self.img_name_list)
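
if __name__ == '__main__':
    # Minimal usage sketch (not part of the original pipeline). The paths, speaker
    # name, sample rate and config fields below are illustrative assumptions; adapt
    # them to the actual TalkSHOW data layout and configuration object.
    from types import SimpleNamespace

    demo_config = SimpleNamespace(
        dataset_load_mode='json',
        Model=SimpleNamespace(encoder_choice='mfcc'),
    )
    demo_dataset = SmplxDataset(
        data_root='data/demo_segment',             # directory containing *.pkl annotations (assumed)
        speaker='oliver',                          # must be a key of data_utils.consts.speaker_id
        motion_fn='data/demo_segment/motion.pkl',  # assumed path
        audio_fn='data/demo_segment/audio.wav',    # assumed path
        audio_sr=22000,
        fps=30,
        feat_method='mfcc',
        load_all=True,
        config=demo_config,
    )
    demo_dataset.get_dataset(normalization=False, split='train')
    loader = data.DataLoader(demo_dataset.all_dataset, batch_size=1, shuffle=True)
    print('samples in segment:', len(demo_dataset.all_dataset))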