TalkSHOWLIVE / nets /body_ae.py
vscode69's picture
second half
99afdfe
import os
import sys
sys.path.append(os.getcwd())
from nets.base import TrainWrapperBaseClass
from nets.spg.s2glayers import Discriminator as D_S2G
from nets.spg.vqvae_1d import AE as s2g_body
import torch
import torch.optim as optim
import torch.nn.functional as F
from data_utils.lower_body import c_index, c_index_3d, c_index_6d
def separate_aa(aa):
    """Split packed 5-value groups into separately normalized parts.

    Everything from dim 2 onward is regrouped into chunks of 5 values. The
    first 3 values of each chunk form ``axis`` and the remaining 2 form
    ``angle``. Each part is L2-normalized along its last dimension with
    ``F.normalize``.

    Returns:
        ``(axis, angle)`` with shapes ``(d0, d1, k, 3)`` and ``(d0, d1, k, 2)``,
        where ``k`` is the number of 5-value chunks per ``(d0, d1)`` entry.
    """
    groups = aa.reshape(aa.shape[0], aa.shape[1], -1, 5)
    axis_raw = groups[..., :3]
    angle_raw = groups[..., 3:5]
    return F.normalize(axis_raw, dim=-1), F.normalize(angle_raw, dim=-1)
class TrainWrapper(TrainWrapperBaseClass):
    '''
    A wrapper that receives a batch from data_utils, computes the body
    autoencoder's reconstruction losses and performs the optimizer step itself.
    '''

    def __init__(self, args, config):
        self.args = args
        self.config = config
        self.device = torch.device(self.args.gpu)
        self.global_step = 0

        # Adversarial training is hard-disabled for this model.
        self.gan = False
        self.convert_to_6d = self.config.Data.pose.convert_to_6d
        self.preleng = self.config.Data.pose.pre_pose_length
        self.expression = self.config.Data.pose.expression
        self.epoch = 0
        # Defined by TrainWrapperBaseClass; presumably sets self.each_dim (used
        # below) and self.full_dim (used by extract) -- TODO confirm.
        self.init_params()
        self.num_classes = 4

        # Autoencoder input width is the sum of each_dim[1] and each_dim[2].
        self.g = s2g_body(self.each_dim[1] + self.each_dim[2], embedding_dim=64, num_embeddings=0,
                          num_hiddens=1024, num_residual_layers=2, num_residual_hiddens=512).to(self.device)
        if self.gan:
            # NOTE(review): unreachable while self.gan is False. self.pose is
            # never assigned in this class, and no discriminator optimizer is
            # created, so enabling GAN training needs more work.
            self.discriminator = D_S2G(
                pose_dim=110 + 64, pose=self.pose
            ).to(self.device)
        else:
            self.discriminator = None

        # Channel indices kept from the pose tensor, chosen per representation.
        if self.convert_to_6d:
            self.c_index = c_index_6d
        else:
            self.c_index = c_index_3d

        # Runs last, after the networks exist; presumably the base class calls
        # init_optimizer here -- TODO confirm.
        super().__init__(args, config)

    def init_optimizer(self):
        '''Create the Adam optimizer for the generator (autoencoder).'''
        self.g_optimizer = optim.Adam(
            self.g.parameters(),
            lr=self.config.Train.learning_rate.generator_learning_rate,
            betas=[0.9, 0.999]
        )

    def state_dict(self):
        '''Return a checkpoint dict of model weights and optimizer states.

        The discriminator entries are None when there is no discriminator, or
        (for 'discriminator_optim') when no discriminator optimizer was created.
        '''
        has_disc = self.discriminator is not None
        # getattr guard: nothing in this class creates discriminator_optimizer.
        disc_optim = getattr(self, 'discriminator_optimizer', None) if has_disc else None
        return {
            'g': self.g.state_dict(),
            'g_optim': self.g_optimizer.state_dict(),
            'discriminator': self.discriminator.state_dict() if has_disc else None,
            'discriminator_optim': disc_optim.state_dict() if disc_optim is not None else None,
        }

    def __call__(self, bat):
        '''Run one training step on the batch dict ``bat``.

        Reads ``bat['poses']``, keeps the channels in ``self.c_index`` and drops
        the first ``self.preleng`` entries of dim 2. It then trains ``self.g``
        on the result via ``vq_train``, which also steps the optimizer.

        Returns:
            ``(total_loss, loss_dict)``. ``total_loss`` is always None because
            the backward/step already happens inside ``vq_train``.
            ``loss_dict`` maps ``'g' + <loss name>`` to Python floats.
        '''
        self.global_step += 1

        total_loss = None
        loss_dict = {}

        # Only poses feed the autoencoder, so bat['aud_feat'] is deliberately
        # not copied to the device.
        poses = bat['poses'].to(self.device).to(torch.float32)

        # Presumably poses is (batch, channels, frames) -- TODO confirm. The
        # permute moves dim 2 before the channel dim.
        poses = poses[:, self.c_index, :]
        gt_poses = poses[:, :, self.preleng:].permute(0, 2, 1)

        loss_dict, _ = self.vq_train(gt_poses, 'g', self.g, loss_dict, 0)
        return total_loss, loss_dict

    def vq_train(self, gt, name, model, dict, total_loss, pre=None):
        '''Reconstruct ``gt`` with ``model``, take one optimizer step, log losses.

        Args:
            gt: target poses, passed to the model as ``gt_poses``.
            name: model prefix. The optimizer used is ``self.<name>_optimizer``,
                and logged keys are ``name + <loss name>``.
            model: callable invoked as ``model(gt_poses=gt, pre_state=pre)``.
            dict: mapping updated in place with float loss values. The name
                shadows the builtin but is kept for caller compatibility.
            total_loss: returned unchanged.
            pre: optional previous state forwarded to the model and get_loss.

        Returns:
            ``(dict, total_loss)``.

        Raises:
            AttributeError: if ``self`` has no ``<name>_optimizer``.
        '''
        x_recon = model(gt_poses=gt, pre_state=pre)
        loss, loss_dict = self.get_loss(pred_poses=x_recon, gt_poses=gt, pre=pre)

        # Derive the optimizer from the model name. Special-casing only 'g'
        # left the optimizer name unbound for any other model name.
        optimizer = getattr(self, name + '_optimizer')
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        for key, value in loss_dict.items():
            dict[name + key] = value.item()
        return dict, total_loss

    def get_loss(self,
                 pred_poses,
                 gt_poses,
                 pre=None
                 ):
        '''Compute the autoencoder training loss.

        The loss is the mean absolute reconstruction error plus the mean
        absolute error of finite differences along dim 1. When ``pre`` is given,
        it also adds the mean absolute error of the step from ``pre[:, -1]`` to
        index 0 along dim 1.

        Returns:
            ``(gen_loss, loss_dict)``. ``loss_dict`` holds the individual loss
            tensors under 'rec_loss', 'velocity_loss' and, when ``pre`` is
            given, 'f0_vel'.
        '''
        loss_dict = {}

        rec_loss = torch.mean(torch.abs(pred_poses - gt_poses))
        # First-order differences between consecutive entries along dim 1.
        v_pr = pred_poses[:, 1:] - pred_poses[:, :-1]
        v_gt = gt_poses[:, 1:] - gt_poses[:, :-1]
        velocity_loss = torch.mean(torch.abs(v_pr - v_gt))

        gen_loss = rec_loss + velocity_loss
        loss_dict['rec_loss'] = rec_loss
        loss_dict['velocity_loss'] = velocity_loss

        if pre is not None:
            # Difference from the last entry of pre to the first prediction/target.
            v0_pr = pred_poses[:, 0] - pre[:, -1]
            v0_gt = gt_poses[:, 0] - pre[:, -1]
            f0_vel = torch.mean(torch.abs(v0_pr - v0_gt))
            gen_loss = gen_loss + f0_vel
            loss_dict['f0_vel'] = f0_vel

        return gen_loss, loss_dict

    def load_state_dict(self, state_dict):
        '''Load only the generator weights from a checkpoint produced by state_dict.'''
        self.g.load_state_dict(state_dict['g'])

    def extract(self, x):
        '''Encode ``x`` with the generator's encoder.

        Switches ``self.g`` to eval mode and leaves it there. If ``x.shape[2]``
        exceeds ``self.full_dim``, the channels are reduced first. A 239-wide
        input drops its first 102 channels, and then ``self.c_index`` is applied.

        Returns:
            ``(feat.transpose(1, 2), x)``, where ``x`` is the possibly-reduced
            input actually fed to the encoder.
        '''
        self.g.eval()
        if x.shape[2] > self.full_dim:
            if x.shape[2] == 239:
                x = x[:, :, 102:]
            x = x[:, :, self.c_index]
        feat = self.g.encode(x)
        return feat.transpose(1, 2), x