Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*- | |
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | |
# holder of all proprietary rights on this computer program. | |
# You can only use this computer program if you have closed | |
# a license agreement with MPG or you get the right to use the computer | |
# program from someone who is authorized to grant you that right. | |
# Any use of the computer program without a valid license is prohibited and | |
# liable to prosecution. | |
# | |
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung | |
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
# for Intelligent Systems. All rights reserved. | |
# | |
# Contact: ps-license@tuebingen.mpg.de
import logging | |
import os | |
import os.path as osp | |
import pickle | |
from collections import namedtuple | |
from typing import Dict, Optional, Union | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
logging.getLogger("smplx").setLevel(logging.ERROR) | |
from .lbs import find_dynamic_lmk_idx_and_bcoords, lbs, vertices2landmarks | |
from .utils import ( | |
Array, | |
FLAMEOutput, | |
MANOOutput, | |
SMPLHOutput, | |
SMPLOutput, | |
SMPLXOutput, | |
Struct, | |
Tensor, | |
find_joint_kin_chain, | |
to_np, | |
to_tensor, | |
) | |
from .vertex_ids import vertex_ids as VERTEX_IDS | |
from .vertex_joint_selector import VertexJointSelector | |
# Generic result container for the body models. The field names are given as a
# single whitespace-separated string, which ``namedtuple`` splits into the same
# ten fields, in the same order.
ModelOutput = namedtuple(
    "ModelOutput",
    "vertices joints full_pose betas global_orient body_pose "
    "expression left_hand_pose right_hand_pose jaw_pose",
)
# Every field is optional: default all of them to None so that a model only has
# to fill in the entries it actually produces.
ModelOutput.__new__.__defaults__ = tuple(None for _ in ModelOutput._fields)
class SMPL(nn.Module):
    """SMPL body model wrapped as a ``torch.nn.Module``.

    The constructor loads the model data (template vertices, shape and pose
    blend shapes, joint regressor, kinematic tree and skinning weights) into
    buffers and, optionally, creates learnable member variables for the
    shape, global orientation, body pose and translation. ``forward``
    evaluates the model through :func:`lbs`.
    """

    NUM_JOINTS = 23
    NUM_BODY_JOINTS = 23
    SHAPE_SPACE_DIM = 300

    def __init__(
        self,
        model_path: str,
        kid_template_path: str = "",
        data_struct: Optional[Struct] = None,
        create_betas: bool = True,
        betas: Optional[Tensor] = None,
        num_betas: int = 10,
        create_global_orient: bool = True,
        global_orient: Optional[Tensor] = None,
        create_body_pose: bool = True,
        body_pose: Optional[Tensor] = None,
        create_transl: bool = True,
        transl: Optional[Tensor] = None,
        dtype=torch.float32,
        batch_size: int = 1,
        joint_mapper=None,
        gender: str = "neutral",
        age: str = "adult",
        vertex_ids: Optional[Dict[str, int]] = None,
        v_template: Optional[Union[Tensor, Array]] = None,
        v_personal: Optional[Union[Tensor, Array]] = None,
        **kwargs,
    ) -> None:
        """SMPL model constructor

        Parameters
        ----------
        model_path: str
            The path to the folder or to the file where the model
            parameters are stored
        kid_template_path: str, optional
            Path of the kid template vertices, loaded with ``np.load``.
            Only read when ``age == "kid"``. (default = "")
        data_struct: Struct
            A struct object. If given, then the parameters of the model are
            read from the object. Otherwise, the model tries to read the
            parameters from the given `model_path`. (default = None)
        create_global_orient: bool, optional
            Flag for creating a member variable for the global orientation
            of the body. (default = True)
        global_orient: torch.tensor, optional, Bx3
            The default value for the global orientation variable.
            (default = None)
        create_body_pose: bool, optional
            Flag for creating a member variable for the pose of the body.
            (default = True)
        body_pose: torch.tensor, optional, Bx(Body Joints * 3)
            The default value for the body pose variable.
            (default = None)
        num_betas: int, optional
            Number of shape components to use
            (default = 10).
        create_betas: bool, optional
            Flag for creating a member variable for the shape space
            (default = True).
        betas: torch.tensor, optional, Bx10
            The default value for the shape member variable.
            (default = None)
        create_transl: bool, optional
            Flag for creating a member variable for the translation
            of the body. (default = True)
        transl: torch.tensor, optional, Bx3
            The default value for the transl variable.
            (default = None)
        dtype: torch.dtype, optional
            The data type for the created variables
        batch_size: int, optional
            The batch size used for creating the member variables
        joint_mapper: object, optional
            An object that re-maps the joints. Useful if one wants to
            re-order the SMPL joints to some other convention (e.g. MSCOCO)
            (default = None)
        gender: str, optional
            Which gender to load
        age: str, optional
            When equal to "kid", an extra shape component built from the
            kid template is appended to the shape space.
            (default = "adult")
        vertex_ids: dict, optional
            A dictionary containing the indices of the extra vertices that
            will be selected
        v_template: torch.tensor or array, optional
            Template vertices used instead of the ones stored in the model
            data. (default = None)
        v_personal: torch.tensor or array, optional
            Per-vertex offsets added to the template. (default = None)
        **kwargs
            Forwarded to ``VertexJointSelector``.
        """
        self.gender = gender
        self.age = age

        if data_struct is None:
            if osp.isdir(model_path):
                model_fn = "SMPL_{}.{ext}".format(gender.upper(), ext="pkl")
                smpl_path = os.path.join(model_path, model_fn)
            else:
                smpl_path = model_path
            assert osp.exists(smpl_path), "Path {} does not exist!".format(smpl_path)

            # NOTE(security): unpickling executes arbitrary code; only load
            # model files that come from a trusted source.
            with open(smpl_path, "rb") as smpl_file:
                data_struct = Struct(**pickle.load(smpl_file, encoding="latin1"))

        super(SMPL, self).__init__()
        self.batch_size = batch_size

        shapedirs = data_struct.shapedirs
        if shapedirs.shape[-1] < self.SHAPE_SPACE_DIM:
            # The model ships with a reduced shape space: never use more than
            # 10 shape coefficients.
            num_betas = min(num_betas, 10)
        else:
            num_betas = min(num_betas, self.SHAPE_SPACE_DIM)

        if self.age == "kid":
            # Append the (centred) kid template offset as one extra shape
            # component.
            v_template_smil = np.load(kid_template_path)
            v_template_smil -= np.mean(v_template_smil, axis=0)
            v_template_diff = np.expand_dims(v_template_smil - data_struct.v_template, axis=2)
            shapedirs = np.concatenate((shapedirs[:, :, :num_betas], v_template_diff), axis=2)
            num_betas = num_betas + 1

        self._num_betas = num_betas
        shapedirs = shapedirs[:, :, :num_betas]
        # The shape components
        self.register_buffer("shapedirs", to_tensor(to_np(shapedirs), dtype=dtype))

        if vertex_ids is None:
            # SMPL and SMPL-H share the same topology, so any extra joints can
            # be drawn from the same place
            vertex_ids = VERTEX_IDS["smplh"]

        self.dtype = dtype
        self.joint_mapper = joint_mapper
        self.vertex_joint_selector = VertexJointSelector(vertex_ids=vertex_ids, **kwargs)

        self.faces = data_struct.f
        self.register_buffer(
            "faces_tensor",
            to_tensor(to_np(self.faces, dtype=np.int64), dtype=torch.long),
        )

        if create_betas:
            default_betas = self._initial_value(betas, [batch_size, num_betas], dtype)
            self.register_parameter("betas", nn.Parameter(default_betas, requires_grad=True))

        # The tensor that contains the global rotation of the model
        # It is separated from the pose of the joints in case we wish to
        # optimize only over one of them
        if create_global_orient:
            default_global_orient = self._initial_value(global_orient, [batch_size, 3], dtype)
            self.register_parameter(
                "global_orient", nn.Parameter(default_global_orient, requires_grad=True)
            )

        if create_body_pose:
            default_body_pose = self._initial_value(
                body_pose, [batch_size, self.NUM_BODY_JOINTS * 3], dtype
            )
            self.register_parameter(
                "body_pose", nn.Parameter(default_body_pose, requires_grad=True)
            )

        if create_transl:
            default_transl = self._initial_value(transl, [batch_size, 3], dtype)
            self.register_parameter("transl", nn.Parameter(default_transl, requires_grad=True))

        if v_template is None:
            v_template = data_struct.v_template
        if not torch.is_tensor(v_template):
            v_template = to_tensor(to_np(v_template), dtype=dtype)
        if v_personal is not None:
            v_personal = to_tensor(to_np(v_personal), dtype=dtype)
            # Out-of-place add: a tensor passed in by the caller must not be
            # modified behind their back.
            v_template = v_template + v_personal
        # The vertices of the template model
        self.register_buffer("v_template", v_template)

        j_regressor = to_tensor(to_np(data_struct.J_regressor), dtype=dtype)
        self.register_buffer("J_regressor", j_regressor)

        # Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*3 x 207
        num_pose_basis = data_struct.posedirs.shape[-1]
        # 207 x 20670
        posedirs = np.reshape(data_struct.posedirs, [-1, num_pose_basis]).T
        self.register_buffer("posedirs", to_tensor(to_np(posedirs), dtype=dtype))

        # indices of parents for each joints
        parents = to_tensor(to_np(data_struct.kintree_table[0])).long()
        parents[0] = -1
        self.register_buffer("parents", parents)

        self.register_buffer("lbs_weights", to_tensor(to_np(data_struct.weights), dtype=dtype))

    @staticmethod
    def _initial_value(value, shape, dtype) -> Tensor:
        """Build the tensor a learnable member variable is initialised with.

        ``None`` yields zeros of ``shape``; a tensor is copied, detached and
        cast to ``dtype``; anything else is converted with ``torch.tensor``.
        """
        if value is None:
            return torch.zeros(shape, dtype=dtype)
        if torch.is_tensor(value):
            return value.clone().detach().to(dtype=dtype)
        return torch.tensor(value, dtype=dtype)

    @property
    def num_betas(self):
        """Number of shape coefficients used by this model instance."""
        return self._num_betas

    @property
    def num_expression_coeffs(self):
        """Number of expression coefficients; SMPL has none."""
        return 0

    def create_mean_pose(self, data_struct) -> Tensor:
        """Hook for subclasses; the base model defines no mean pose."""
        pass

    def name(self) -> str:
        """Return the display name of the model."""
        return "SMPL"

    @torch.no_grad()
    def reset_params(self, **params_dict) -> None:
        """Reset the learnable parameters in place.

        Parameters named in ``params_dict`` are overwritten with the given
        values; every other parameter is filled with zeros. Runs without
        gradient tracking, because in-place writes to leaf tensors that
        require grad are otherwise rejected by autograd.
        """
        for param_name, param in self.named_parameters():
            if param_name in params_dict:
                param[:] = torch.as_tensor(
                    params_dict[param_name], dtype=param.dtype, device=param.device
                )
            else:
                param.fill_(0)

    def get_num_verts(self) -> int:
        """Return the number of vertices of the template mesh."""
        return self.v_template.shape[0]

    def get_num_faces(self) -> int:
        """Return the number of faces of the mesh."""
        return self.faces.shape[0]

    def extra_repr(self) -> str:
        """Return the extra lines shown in the module's ``repr``."""
        msg = [
            f"Gender: {self.gender.upper()}",
            f"Number of joints: {self.J_regressor.shape[0]}",
            f"Betas: {self.num_betas}",
        ]
        return "\n".join(msg)

    def forward(
        self,
        betas: Optional[Tensor] = None,
        body_pose: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts=True,
        return_full_pose: bool = False,
        pose2rot: bool = True,
        **kwargs,
    ) -> SMPLOutput:
        """Forward pass for the SMPL model

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3
            If given, ignore the member variable and use it as the global
            rotation of the body. Useful if someone wishes to predicts this
            with an external model. (default=None)
        betas: torch.tensor, optional, shape BxN_b
            If given, ignore the member variable `betas` and use it
            instead. For example, it can used if shape parameters
            `betas` are predicted from some external model.
            (default=None)
        body_pose: torch.tensor, optional, shape Bx(J*3)
            If given, ignore the member variable `body_pose` and use it
            instead. For example, it can used if someone predicts the
            pose of the body joints are predicted from some external model.
            It should be a tensor that contains joint rotations in
            axis-angle format. (default=None)
        transl: torch.tensor, optional, shape Bx3
            If given, ignore the member variable `transl` and use it
            instead. For example, it can used if the translation
            `transl` is predicted from some external model.
            (default=None)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Returns the full axis-angle pose vector (default=False)
        pose2rot: bool, optional
            Forwarded unchanged to ``lbs``. (default=True)

        Returns
        -------
        SMPLOutput
            ``vertices`` (None unless ``return_verts``), ``joints``,
            ``betas``, ``global_orient``, ``body_pose`` and ``full_pose``
            (None unless ``return_full_pose``).
        """
        # If no shape and pose parameters are passed along, then use the
        # ones from the module
        if global_orient is None:
            global_orient = self.global_orient
        if body_pose is None:
            body_pose = self.body_pose
        if betas is None:
            betas = self.betas

        if transl is None and hasattr(self, "transl"):
            transl = self.transl
        apply_trans = transl is not None

        full_pose = torch.cat([global_orient, body_pose], dim=1)

        batch_size = max(betas.shape[0], global_orient.shape[0], body_pose.shape[0])
        if betas.shape[0] != batch_size:
            # Broadcast the shape coefficients over the pose batch.
            num_repeats = int(batch_size / betas.shape[0])
            betas = betas.expand(num_repeats, -1)

        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=pose2rot,
        )

        joints = self.vertex_joint_selector(vertices, joints)
        # Map the joints to the current dataset
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        if apply_trans:
            offset = transl.unsqueeze(dim=1)
            joints += offset
            vertices += offset

        return SMPLOutput(
            vertices=vertices if return_verts else None,
            global_orient=global_orient,
            body_pose=body_pose,
            joints=joints,
            betas=betas,
            full_pose=full_pose if return_full_pose else None,
        )
class SMPLLayer(SMPL):
    """SMPL as a stateless layer.

    No learnable member variables are created; every quantity is passed to
    ``forward``, with rotations given as rotation matrices.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Create a SMPL module without any member variables.

        All arguments are forwarded to ``SMPL.__init__``; the ``create_*``
        flags are forced to ``False``.
        """
        super(SMPLLayer, self).__init__(
            *args,
            create_body_pose=False,
            create_betas=False,
            create_global_orient=False,
            create_transl=False,
            **kwargs,
        )

    def forward(
        self,
        betas: Optional[Tensor] = None,
        body_pose: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts=True,
        return_full_pose: bool = False,
        pose2rot: bool = True,
        **kwargs,
    ) -> SMPLOutput:
        """Forward pass for the SMPL model

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3x3
            Global rotation of the body. Useful if someone wishes to
            predicts this with an external model. It is expected to be in
            rotation matrix format. Defaults to the identity rotation.
            (default=None)
        betas: torch.tensor, optional, shape BxN_b
            Shape parameters. For example, it can used if shape parameters
            `betas` are predicted from some external model. Defaults to
            zeros. (default=None)
        body_pose: torch.tensor, optional, shape BxJx3x3
            Body pose. For example, it can used if someone predicts the
            pose of the body joints are predicted from some external model.
            It should be a tensor that contains joint rotations in
            rotation matrix format. Defaults to identity rotations.
            (default=None)
        transl: torch.tensor, optional, shape Bx3
            Translation vector of the body.
            For example, it can used if the translation
            `transl` is predicted from some external model.
            Defaults to zeros. (default=None)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Returns the full pose as rotation matrices, i.e. the global
            orientation followed by the body joints along dim 1.
            (default=False)
        pose2rot: bool, optional
            Accepted for signature compatibility and ignored: ``lbs`` is
            always called with ``pose2rot=False``.

        Returns
        -------
        SMPLOutput
            ``vertices`` (None unless ``return_verts``), ``joints``,
            ``betas``, ``global_orient``, ``body_pose`` and ``full_pose``
            (None unless ``return_full_pose``).
        """
        # The batch size is the largest leading dimension among the inputs
        # that were actually provided (1 if none were).
        batch_size = 1
        for var in (betas, global_orient, body_pose, transl):
            if var is not None:
                batch_size = max(batch_size, len(var))

        device, dtype = self.shapedirs.device, self.shapedirs.dtype
        identity = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3)

        if global_orient is None:
            global_orient = identity.expand(batch_size, -1, -1, -1).contiguous()
        if body_pose is None:
            body_pose = identity.expand(
                batch_size, self.NUM_BODY_JOINTS, -1, -1
            ).contiguous()
        if betas is None:
            # Use the backing attribute set by SMPL.__init__; it is valid
            # whether `num_betas` is exposed as a method or as a property.
            betas = torch.zeros([batch_size, self._num_betas], dtype=dtype, device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        full_pose = torch.cat(
            [
                global_orient.reshape(-1, 1, 3, 3),
                body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
            ],
            dim=1,
        )

        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )

        joints = self.vertex_joint_selector(vertices, joints)
        # Map the joints to the current dataset
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        # `transl` is never None at this point (it defaults to zeros above).
        offset = transl.unsqueeze(dim=1)
        joints += offset
        vertices += offset

        return SMPLOutput(
            vertices=vertices if return_verts else None,
            global_orient=global_orient,
            body_pose=body_pose,
            joints=joints,
            betas=betas,
            full_pose=full_pose if return_full_pose else None,
        )
class SMPLH(SMPL):
    """SMPL+H: the SMPL body extended with articulated hands.

    Hand poses are either full axis-angle vectors or, when ``use_pca`` is
    set, coefficients of a low-dimensional PCA pose space.
    """

    # The hand joints are replaced by MANO
    NUM_BODY_JOINTS = SMPL.NUM_JOINTS - 2
    NUM_HAND_JOINTS = 15
    NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS

    def __init__(
        self,
        model_path,
        kid_template_path: str = "",
        data_struct: Optional[Struct] = None,
        create_left_hand_pose: bool = True,
        left_hand_pose: Optional[Tensor] = None,
        create_right_hand_pose: bool = True,
        right_hand_pose: Optional[Tensor] = None,
        use_pca: bool = True,
        num_pca_comps: int = 6,
        flat_hand_mean: bool = False,
        batch_size: int = 1,
        gender: str = "neutral",
        age: str = "adult",
        dtype=torch.float32,
        vertex_ids=None,
        use_compressed: bool = True,
        ext: str = "pkl",
        **kwargs,
    ) -> None:
        """SMPLH model constructor

        Parameters
        ----------
        model_path: str
            The path to the folder or to the file where the model
            parameters are stored
        kid_template_path: str, optional
            Forwarded to ``SMPL.__init__``. (default = "")
        data_struct: Struct
            A struct object. If given, then the parameters of the model are
            read from the object. Otherwise, the model tries to read the
            parameters from the given `model_path`. (default = None)
        create_left_hand_pose: bool, optional
            Flag for creating a member variable for the pose of the left
            hand. (default = True)
        left_hand_pose: torch.tensor, optional, BxP
            The default value for the left hand pose member variable.
            (default = None)
        create_right_hand_pose: bool, optional
            Flag for creating a member variable for the pose of the right
            hand. (default = True)
        right_hand_pose: torch.tensor, optional, BxP
            The default value for the right hand pose member variable.
            (default = None)
        use_pca: bool, optional
            If True, the hand pose variables hold `num_pca_comps`
            coefficients that `forward` projects through the hand
            components; otherwise they hold 3 values per hand joint.
            (default = True)
        num_pca_comps: int, optional
            The number of PCA components to use for each hand.
            (default = 6)
        flat_hand_mean: bool, optional
            If True, the mean hand pose is set to zeros; otherwise the
            mean stored in the model data (`hands_meanl` / `hands_meanr`)
            is used. (default = False)
        batch_size: int, optional
            The batch size used for creating the member variables
        gender: str, optional
            Which gender to load
        age: str, optional
            Forwarded to ``SMPL.__init__``. (default = "adult")
        dtype: torch.dtype, optional
            The data type for the created variables
        vertex_ids: dict, optional
            A dictionary containing the indices of the extra vertices that
            will be selected
        use_compressed: bool, optional
            Forwarded to ``SMPL.__init__`` as an extra keyword argument.
            (default = True)
        ext: str, optional
            Model file format, either "pkl" or "npz". (default = "pkl")

        Raises
        ------
        ValueError
            If the model has to be loaded from disk and `ext` is neither
            "pkl" nor "npz".
        """
        # If no data structure is passed, then load the data from the given
        # model folder
        if data_struct is None:
            # Load the model
            if osp.isdir(model_path):
                model_fn = "SMPLH_{}.{ext}".format(gender.upper(), ext=ext)
                smplh_path = os.path.join(model_path, model_fn)
            else:
                smplh_path = model_path
            assert osp.exists(smplh_path), "Path {} does not exist!".format(smplh_path)

            # NOTE(security): both loaders below can unpickle data, which
            # executes arbitrary code; only load trusted model files.
            if ext == "pkl":
                with open(smplh_path, "rb") as smplh_file:
                    model_data = pickle.load(smplh_file, encoding="latin1")
            elif ext == "npz":
                model_data = np.load(smplh_path, allow_pickle=True)
            else:
                raise ValueError("Unknown extension: {}".format(ext))
            data_struct = Struct(**model_data)

        if vertex_ids is None:
            vertex_ids = VERTEX_IDS["smplh"]

        super(SMPLH, self).__init__(
            model_path=model_path,
            kid_template_path=kid_template_path,
            data_struct=data_struct,
            batch_size=batch_size,
            vertex_ids=vertex_ids,
            gender=gender,
            age=age,
            use_compressed=use_compressed,
            dtype=dtype,
            ext=ext,
            **kwargs,
        )

        self.use_pca = use_pca
        self.num_pca_comps = num_pca_comps
        self.flat_hand_mean = flat_hand_mean

        # Keep only the leading PCA components of each hand.
        left_hand_components = data_struct.hands_componentsl[:num_pca_comps]
        right_hand_components = data_struct.hands_componentsr[:num_pca_comps]
        self.np_left_hand_components = left_hand_components
        self.np_right_hand_components = right_hand_components
        if self.use_pca:
            self.register_buffer(
                "left_hand_components", torch.tensor(left_hand_components, dtype=dtype)
            )
            self.register_buffer(
                "right_hand_components", torch.tensor(right_hand_components, dtype=dtype)
            )

        if self.flat_hand_mean:
            left_hand_mean = np.zeros_like(data_struct.hands_meanl)
            right_hand_mean = np.zeros_like(data_struct.hands_meanr)
        else:
            left_hand_mean = data_struct.hands_meanl
            right_hand_mean = data_struct.hands_meanr

        self.register_buffer("left_hand_mean", to_tensor(left_hand_mean, dtype=self.dtype))
        self.register_buffer("right_hand_mean", to_tensor(right_hand_mean, dtype=self.dtype))

        # Create the parameters for the pose of each hand
        hand_pose_dim = num_pca_comps if use_pca else 3 * self.NUM_HAND_JOINTS

        def _hand_pose_init(value):
            # Zeros by default; tensors are copied without going through
            # torch.tensor(), which warns when handed a tensor.
            if value is None:
                return torch.zeros([batch_size, hand_pose_dim], dtype=dtype)
            if torch.is_tensor(value):
                return value.clone().detach().to(dtype=dtype)
            return torch.tensor(value, dtype=dtype)

        if create_left_hand_pose:
            left_hand_pose_param = nn.Parameter(
                _hand_pose_init(left_hand_pose), requires_grad=True
            )
            self.register_parameter("left_hand_pose", left_hand_pose_param)

        if create_right_hand_pose:
            right_hand_pose_param = nn.Parameter(
                _hand_pose_init(right_hand_pose), requires_grad=True
            )
            self.register_parameter("right_hand_pose", right_hand_pose_param)

        # Create the buffer for the mean pose.
        pose_mean_tensor = self.create_mean_pose(data_struct, flat_hand_mean=flat_hand_mean)
        if not torch.is_tensor(pose_mean_tensor):
            pose_mean_tensor = torch.tensor(pose_mean_tensor, dtype=dtype)
        self.register_buffer("pose_mean", pose_mean_tensor)

    def create_mean_pose(self, data_struct, flat_hand_mean=False):
        """Assemble the mean of the full pose vector.

        The global orientation and the body joints have a zero mean; the
        hand entries are taken from the ``left_hand_mean`` and
        ``right_hand_mean`` buffers, which already account for the flat
        hand setting. ``data_struct`` and ``flat_hand_mean`` are not read
        here.
        """
        global_orient_mean = torch.zeros([3], dtype=self.dtype)
        body_pose_mean = torch.zeros([self.NUM_BODY_JOINTS * 3], dtype=self.dtype)

        pose_mean = torch.cat(
            [
                global_orient_mean,
                body_pose_mean,
                self.left_hand_mean,
                self.right_hand_mean,
            ],
            dim=0,
        )
        return pose_mean

    def name(self) -> str:
        """Return the display name of the model."""
        return "SMPL+H"

    def extra_repr(self):
        """Extend the parent's ``repr`` lines with the hand settings."""
        msg = [super(SMPLH, self).extra_repr()]
        if self.use_pca:
            msg.append(f"Number of PCA components: {self.num_pca_comps}")
        msg.append(f"Flat hand mean: {self.flat_hand_mean}")
        return "\n".join(msg)

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        body_pose: Optional[Tensor] = None,
        left_hand_pose: Optional[Tensor] = None,
        right_hand_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        pose2rot: bool = True,
        **kwargs,
    ) -> SMPLHOutput:
        """Forward pass for the SMPL+H model

        Parameters
        ----------
        betas, global_orient, body_pose, left_hand_pose, right_hand_pose:
            torch.tensor, optional
            If given, used instead of the member variable of the same
            name. (default=None)
        transl: torch.tensor, optional
            If given, used as the translation instead of the member
            variable `transl`. No translation is applied when neither is
            available. (default=None)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Return the full pose vector, with the mean pose added.
            (default=False)
        pose2rot: bool, optional
            Forwarded unchanged to ``lbs``. (default=True)

        Returns
        -------
        SMPLHOutput
            ``vertices`` (None unless ``return_verts``), ``joints``,
            ``betas``, ``global_orient``, ``body_pose``, the hand poses
            (after the PCA projection when ``use_pca`` is set) and
            ``full_pose`` (None unless ``return_full_pose``).
        """
        # If no shape and pose parameters are passed along, then use the
        # ones from the module
        if global_orient is None:
            global_orient = self.global_orient
        if body_pose is None:
            body_pose = self.body_pose
        if betas is None:
            betas = self.betas
        if left_hand_pose is None:
            left_hand_pose = self.left_hand_pose
        if right_hand_pose is None:
            right_hand_pose = self.right_hand_pose

        if transl is None and hasattr(self, "transl"):
            transl = self.transl
        apply_trans = transl is not None

        if self.use_pca:
            # Project the PCA coefficients through the hand components.
            left_hand_pose = torch.einsum(
                "bi,ij->bj", left_hand_pose, self.left_hand_components
            )
            right_hand_pose = torch.einsum(
                "bi,ij->bj", right_hand_pose, self.right_hand_components
            )

        full_pose = torch.cat([global_orient, body_pose, left_hand_pose, right_hand_pose], dim=1)
        full_pose += self.pose_mean

        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=pose2rot,
        )

        # Add any extra joints that might be needed
        joints = self.vertex_joint_selector(vertices, joints)
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        if apply_trans:
            offset = transl.unsqueeze(dim=1)
            joints += offset
            vertices += offset

        return SMPLHOutput(
            vertices=vertices if return_verts else None,
            joints=joints,
            betas=betas,
            global_orient=global_orient,
            body_pose=body_pose,
            left_hand_pose=left_hand_pose,
            right_hand_pose=right_hand_pose,
            full_pose=full_pose if return_full_pose else None,
        )
class SMPLHLayer(SMPLH):
    """SMPL+H as a stateless layer.

    No learnable member variables are created; every quantity is passed to
    ``forward``, with rotations given as rotation matrices.
    """

    def __init__(self, *args, **kwargs) -> None:
        """SMPL+H as a layer model constructor

        All arguments are forwarded to ``SMPLH.__init__``; the ``create_*``
        flags are forced to ``False``.
        """
        super(SMPLHLayer, self).__init__(
            *args,
            create_global_orient=False,
            create_body_pose=False,
            create_left_hand_pose=False,
            create_right_hand_pose=False,
            create_betas=False,
            create_transl=False,
            **kwargs,
        )

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        body_pose: Optional[Tensor] = None,
        left_hand_pose: Optional[Tensor] = None,
        right_hand_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        pose2rot: bool = True,
        **kwargs,
    ) -> SMPLHOutput:
        """Forward pass for the SMPL+H model

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3x3
            Global rotation of the body. Useful if someone wishes to
            predicts this with an external model. It is expected to be in
            rotation matrix format. Defaults to the identity rotation.
            (default=None)
        betas: torch.tensor, optional, shape BxN_b
            Shape parameters. For example, it can used if shape parameters
            `betas` are predicted from some external model. Defaults to
            zeros. (default=None)
        body_pose: torch.tensor, optional, shape BxJx3x3
            Pose of the body joints. It should be a tensor that contains
            joint rotations in rotation matrix format. Defaults to
            identity rotations. (default=None)
        left_hand_pose: torch.tensor, optional, shape Bx15x3x3
            If given, contains the pose of the left hand.
            It should be a tensor that contains joint rotations in
            rotation matrix format. Defaults to identity rotations.
            (default=None)
        right_hand_pose: torch.tensor, optional, shape Bx15x3x3
            If given, contains the pose of the right hand.
            It should be a tensor that contains joint rotations in
            rotation matrix format. Defaults to identity rotations.
            (default=None)
        transl: torch.tensor, optional, shape Bx3
            Translation vector of the body.
            For example, it can used if the translation
            `transl` is predicted from some external model.
            Defaults to zeros. (default=None)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Returns the full pose as rotation matrices: global orientation,
            body joints, left hand and right hand along dim 1. The mean
            pose is not added here. (default=False)
        pose2rot: bool, optional
            Accepted for signature compatibility and ignored: ``lbs`` is
            always called with ``pose2rot=False``.

        Returns
        -------
        SMPLHOutput
            ``vertices`` (None unless ``return_verts``), ``joints``,
            ``betas``, ``global_orient``, ``body_pose``, ``left_hand_pose``,
            ``right_hand_pose`` and ``full_pose`` (None unless
            ``return_full_pose``).
        """
        # The batch size is the largest leading dimension among the inputs
        # that were actually provided (1 if none were).
        batch_size = 1
        for var in (
            betas,
            global_orient,
            body_pose,
            transl,
            left_hand_pose,
            right_hand_pose,
        ):
            if var is not None:
                batch_size = max(batch_size, len(var))

        device, dtype = self.shapedirs.device, self.shapedirs.dtype
        identity = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3)

        def _identity_rotations(num_joints):
            # One identity rotation per joint and batch element.
            return identity.expand(batch_size, num_joints, -1, -1).contiguous()

        # Joint counts come from the class constants so that the defaults
        # always agree with the reshape calls below.
        if global_orient is None:
            global_orient = _identity_rotations(1)
        if body_pose is None:
            body_pose = _identity_rotations(self.NUM_BODY_JOINTS)
        if left_hand_pose is None:
            left_hand_pose = _identity_rotations(self.NUM_HAND_JOINTS)
        if right_hand_pose is None:
            right_hand_pose = _identity_rotations(self.NUM_HAND_JOINTS)
        if betas is None:
            # Use the backing attribute set by SMPL.__init__; it is valid
            # whether `num_betas` is exposed as a method or as a property.
            betas = torch.zeros([batch_size, self._num_betas], dtype=dtype, device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        # Concatenate all pose vectors
        full_pose = torch.cat(
            [
                global_orient.reshape(-1, 1, 3, 3),
                body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
                left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
                right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
            ],
            dim=1,
        )

        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )

        # Add any extra joints that might be needed
        joints = self.vertex_joint_selector(vertices, joints)
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        # `transl` is never None at this point (it defaults to zeros above).
        offset = transl.unsqueeze(dim=1)
        joints += offset
        vertices += offset

        return SMPLHOutput(
            vertices=vertices if return_verts else None,
            joints=joints,
            betas=betas,
            global_orient=global_orient,
            body_pose=body_pose,
            left_hand_pose=left_hand_pose,
            right_hand_pose=right_hand_pose,
            full_pose=full_pose if return_full_pose else None,
        )
class SMPLX(SMPLH): | |
""" | |
SMPL-X (SMPL eXpressive) is a unified body model, with shape parameters | |
trained jointly for the face, hands and body. | |
SMPL-X uses standard vertex based linear blend skinning with learned | |
corrective blend shapes, has N=10475 vertices and K=54 joints, | |
which includes joints for the neck, jaw, eyeballs and fingers. | |
""" | |
NUM_BODY_JOINTS = SMPLH.NUM_BODY_JOINTS # 21 | |
NUM_HAND_JOINTS = 15 | |
NUM_FACE_JOINTS = 3 | |
NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS + NUM_FACE_JOINTS | |
EXPRESSION_SPACE_DIM = 100 | |
NECK_IDX = 12 | |
def __init__( | |
self, | |
model_path: str, | |
kid_template_path: str = "", | |
num_expression_coeffs: int = 10, | |
create_expression: bool = True, | |
expression: Optional[Tensor] = None, | |
create_jaw_pose: bool = True, | |
jaw_pose: Optional[Tensor] = None, | |
create_leye_pose: bool = True, | |
leye_pose: Optional[Tensor] = None, | |
create_reye_pose=True, | |
reye_pose: Optional[Tensor] = None, | |
use_face_contour: bool = False, | |
batch_size: int = 1, | |
gender: str = "neutral", | |
age: str = "adult", | |
dtype=torch.float32, | |
ext: str = "npz", | |
**kwargs, | |
) -> None: | |
"""SMPLX model constructor | |
Parameters | |
---------- | |
model_path: str | |
The path to the folder or to the file where the model | |
parameters are stored | |
num_expression_coeffs: int, optional | |
Number of expression components to use | |
(default = 10). | |
create_expression: bool, optional | |
Flag for creating a member variable for the expression space | |
(default = True). | |
expression: torch.tensor, optional, Bx10 | |
The default value for the expression member variable. | |
(default = None) | |
create_jaw_pose: bool, optional | |
Flag for creating a member variable for the jaw pose. | |
(default = False) | |
jaw_pose: torch.tensor, optional, Bx3 | |
The default value for the jaw pose variable. | |
(default = None) | |
create_leye_pose: bool, optional | |
Flag for creating a member variable for the left eye pose. | |
(default = False) | |
leye_pose: torch.tensor, optional, Bx10 | |
The default value for the left eye pose variable. | |
(default = None) | |
create_reye_pose: bool, optional | |
Flag for creating a member variable for the right eye pose. | |
(default = False) | |
reye_pose: torch.tensor, optional, Bx10 | |
The default value for the right eye pose variable. | |
(default = None) | |
use_face_contour: bool, optional | |
Whether to compute the keypoints that form the facial contour | |
batch_size: int, optional | |
The batch size used for creating the member variables | |
gender: str, optional | |
Which gender to load | |
dtype: torch.dtype | |
The data type for the created variables | |
""" | |
# Load the model | |
from huggingface_hub import hf_hub_download | |
model_fn = "SMPLX_{}.{ext}".format(gender.upper(), ext=ext) | |
smplx_path = hf_hub_download( | |
repo_id=model_path, use_auth_token=os.environ["ICON"], filename=f"models/{model_fn}" | |
) | |
if ext == "pkl": | |
with open(smplx_path, "rb") as smplx_file: | |
model_data = pickle.load(smplx_file, encoding="latin1") | |
elif ext == "npz": | |
model_data = np.load(smplx_path, allow_pickle=True) | |
else: | |
raise ValueError("Unknown extension: {}".format(ext)) | |
data_struct = Struct(**model_data) | |
super(SMPLX, self).__init__( | |
model_path=model_path, | |
kid_template_path=kid_template_path, | |
data_struct=data_struct, | |
dtype=dtype, | |
batch_size=batch_size, | |
vertex_ids=VERTEX_IDS["smplx"], | |
gender=gender, | |
age=age, | |
ext=ext, | |
**kwargs, | |
) | |
lmk_faces_idx = data_struct.lmk_faces_idx | |
self.register_buffer("lmk_faces_idx", torch.tensor(lmk_faces_idx, dtype=torch.long)) | |
lmk_bary_coords = data_struct.lmk_bary_coords | |
self.register_buffer("lmk_bary_coords", torch.tensor(lmk_bary_coords, dtype=dtype)) | |
self.use_face_contour = use_face_contour | |
if self.use_face_contour: | |
dynamic_lmk_faces_idx = data_struct.dynamic_lmk_faces_idx | |
dynamic_lmk_faces_idx = torch.tensor(dynamic_lmk_faces_idx, dtype=torch.long) | |
self.register_buffer("dynamic_lmk_faces_idx", dynamic_lmk_faces_idx) | |
dynamic_lmk_bary_coords = data_struct.dynamic_lmk_bary_coords | |
dynamic_lmk_bary_coords = torch.tensor(dynamic_lmk_bary_coords, dtype=dtype) | |
self.register_buffer("dynamic_lmk_bary_coords", dynamic_lmk_bary_coords) | |
neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents) | |
self.register_buffer("neck_kin_chain", torch.tensor(neck_kin_chain, dtype=torch.long)) | |
if create_jaw_pose: | |
if jaw_pose is None: | |
default_jaw_pose = torch.zeros([batch_size, 3], dtype=dtype) | |
else: | |
default_jaw_pose = torch.tensor(jaw_pose, dtype=dtype) | |
jaw_pose_param = nn.Parameter(default_jaw_pose, requires_grad=True) | |
self.register_parameter("jaw_pose", jaw_pose_param) | |
if create_leye_pose: | |
if leye_pose is None: | |
default_leye_pose = torch.zeros([batch_size, 3], dtype=dtype) | |
else: | |
default_leye_pose = torch.tensor(leye_pose, dtype=dtype) | |
leye_pose_param = nn.Parameter(default_leye_pose, requires_grad=True) | |
self.register_parameter("leye_pose", leye_pose_param) | |
if create_reye_pose: | |
if reye_pose is None: | |
default_reye_pose = torch.zeros([batch_size, 3], dtype=dtype) | |
else: | |
default_reye_pose = torch.tensor(reye_pose, dtype=dtype) | |
reye_pose_param = nn.Parameter(default_reye_pose, requires_grad=True) | |
self.register_parameter("reye_pose", reye_pose_param) | |
shapedirs = data_struct.shapedirs | |
if len(shapedirs.shape) < 3: | |
shapedirs = shapedirs[:, :, None] | |
if shapedirs.shape[-1] < self.SHAPE_SPACE_DIM + self.EXPRESSION_SPACE_DIM: | |
# print(f'WARNING: You are using a {self.name()} model, with only' | |
# ' 10 shape and 10 expression coefficients.') | |
expr_start_idx = 10 | |
expr_end_idx = 20 | |
num_expression_coeffs = min(num_expression_coeffs, 10) | |
else: | |
expr_start_idx = self.SHAPE_SPACE_DIM | |
expr_end_idx = self.SHAPE_SPACE_DIM + num_expression_coeffs | |
num_expression_coeffs = min(num_expression_coeffs, self.EXPRESSION_SPACE_DIM) | |
self._num_expression_coeffs = num_expression_coeffs | |
expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx] | |
self.register_buffer("expr_dirs", to_tensor(to_np(expr_dirs), dtype=dtype)) | |
if create_expression: | |
if expression is None: | |
default_expression = torch.zeros([batch_size, self.num_expression_coeffs], | |
dtype=dtype) | |
else: | |
default_expression = torch.tensor(expression, dtype=dtype) | |
expression_param = nn.Parameter(default_expression, requires_grad=True) | |
self.register_parameter("expression", expression_param) | |
def name(self) -> str:
    """Return the human-readable identifier of this body model."""
    model_name = "SMPL-X"
    return model_name
@property
def num_expression_coeffs(self):
    """Number of expression coefficients used by this model instance.

    Exposed as a read-only property because the rest of the class reads it
    as a plain attribute (e.g. as a tensor dimension in ``torch.zeros`` and
    inside the ``extra_repr`` f-string); as a regular method those call
    sites would receive a bound method instead of an integer.
    """
    return self._num_expression_coeffs
def create_mean_pose(self, data_struct, flat_hand_mean=False):
    """Assemble the flat mean-pose vector of the full model.

    The global orientation, body, jaw and both eye segments are all-zero;
    the two hand segments are taken from ``self.left_hand_mean`` and
    ``self.right_hand_mean``. Neither ``data_struct`` nor ``flat_hand_mean``
    is read here: whether the hand means are flat is decided by whoever
    filled those two attributes.

    Returns
    -------
    pose_mean: np.ndarray
        The concatenation (along axis 0) of all the segments above.
    """

    def zero_segment(length):
        # All-zero mean for a pose segment with `length` entries.
        return torch.zeros([length], dtype=self.dtype)

    mean_segments = [
        zero_segment(3),    # global orientation
        zero_segment(self.NUM_BODY_JOINTS * 3),    # body joints
        zero_segment(3),    # jaw
        zero_segment(3),    # left eye
        zero_segment(3),    # right eye
        self.left_hand_mean,
        self.right_hand_mean,
    ]
    return np.concatenate(mean_segments, axis=0)
def extra_repr(self):
    """Extend the parent's extra representation with the expression count."""
    parent_repr = super(SMPLX, self).extra_repr()
    expression_line = f"Number of Expression Coefficients: {self.num_expression_coeffs}"
    return "\n".join((parent_repr, expression_line))
def forward(
    self,
    betas: Optional[Tensor] = None,
    global_orient: Optional[Tensor] = None,
    body_pose: Optional[Tensor] = None,
    left_hand_pose: Optional[Tensor] = None,
    right_hand_pose: Optional[Tensor] = None,
    transl: Optional[Tensor] = None,
    expression: Optional[Tensor] = None,
    jaw_pose: Optional[Tensor] = None,
    leye_pose: Optional[Tensor] = None,
    reye_pose: Optional[Tensor] = None,
    return_verts: bool = True,
    return_full_pose: bool = False,
    pose2rot: bool = True,
    return_joint_transformation: bool = False,
    return_vertex_transformation: bool = False,
    pose_type: str = 'posed',
    **kwargs,
) -> SMPLXOutput:
    """
    Forward pass for the SMPLX model

    Every tensor argument that is left as None falls back to the member
    variable of the same name.

    Parameters
    ----------
    global_orient: torch.tensor, optional, shape Bx3
        If given, ignore the member variable and use it as the global
        rotation of the body. Useful if someone wishes to predict this
        with an external model. (default=None)
    betas: torch.tensor, optional, shape BxN_b
        If given, ignore the member variable `betas` and use it
        instead. For example, it can be used if shape parameters
        `betas` are predicted from some external model.
        (default=None)
    expression: torch.tensor, optional, shape BxN_e
        If given, ignore the member variable `expression` and use it
        instead. For example, it can be used if expression parameters
        `expression` are predicted from some external model.
    body_pose: torch.tensor, optional, shape Bx(J*3)
        If given, ignore the member variable `body_pose` and use it
        instead. For example, it can be used if the pose of the body
        joints is predicted from some external model.
        It should be a tensor that contains joint rotations in
        axis-angle format. (default=None)
    left_hand_pose: torch.tensor, optional, shape BxP
        If given, ignore the member variable `left_hand_pose` and
        use this instead. It should either contain PCA coefficients or
        joint rotations in axis-angle format.
    right_hand_pose: torch.tensor, optional, shape BxP
        If given, ignore the member variable `right_hand_pose` and
        use this instead. It should either contain PCA coefficients or
        joint rotations in axis-angle format.
    jaw_pose: torch.tensor, optional, shape Bx3
        If given, ignore the member variable `jaw_pose` and
        use this instead. It should contain joint rotations in
        axis-angle format.
    leye_pose: torch.tensor, optional
        If given, ignore the member variable `leye_pose` and use this
        instead.
    reye_pose: torch.tensor, optional
        If given, ignore the member variable `reye_pose` and use this
        instead.
    transl: torch.tensor, optional, shape Bx3
        If given, ignore the member variable `transl` and use it
        instead. For example, it can be used if the translation
        `transl` is predicted from some external model.
        (default=None)
    return_verts: bool, optional
        Return the vertices. (default=True)
    return_full_pose: bool, optional
        Returns the full axis-angle pose vector (default=False)
    pose2rot: bool, optional
        Forwarded unchanged to `lbs`. (default=True)
    return_joint_transformation: bool, optional
        Include the joint transformations computed by `lbs` in the
        output. (default=False)
    return_vertex_transformation: bool, optional
        Include the vertex transformations computed by `lbs` in the
        output. (default=False)
    pose_type: str, optional
        "t-pose" zeroes the complete pose vector. "a-pose" zeroes
        everything except body joints 15 and 16, which are rotated by
        -45 and +45 degrees about the z axis. "da-pose" zeroes everything
        except body joints 0 and 1, which are rotated by +30 and -30
        degrees about the z axis. Any other value keeps the pose that was
        passed in. (default='posed')

    Returns
    -------
    output: SMPLXOutput
        Holds the vertices, joints (with the face landmarks appended)
        and the parameters that were used to compute them.
    """
    # If no shape and pose parameters are passed along, then use the
    # ones from the module
    global_orient = self.global_orient if global_orient is None else global_orient
    body_pose = self.body_pose if body_pose is None else body_pose
    betas = self.betas if betas is None else betas
    left_hand_pose = self.left_hand_pose if left_hand_pose is None else left_hand_pose
    right_hand_pose = self.right_hand_pose if right_hand_pose is None else right_hand_pose
    jaw_pose = self.jaw_pose if jaw_pose is None else jaw_pose
    leye_pose = self.leye_pose if leye_pose is None else leye_pose
    reye_pose = self.reye_pose if reye_pose is None else reye_pose
    expression = self.expression if expression is None else expression

    apply_trans = transl is not None or hasattr(self, "transl")
    if transl is None and hasattr(self, "transl"):
        transl = self.transl

    if self.use_pca:
        # Project the PCA coefficients back to per-joint axis-angle values.
        left_hand_pose = torch.einsum("bi,ij->bj", [left_hand_pose, self.left_hand_components])
        right_hand_pose = torch.einsum(
            "bi,ij->bj", [right_hand_pose, self.right_hand_components]
        )

    full_pose = torch.cat(
        [
            global_orient,
            body_pose,
            jaw_pose,
            leye_pose,
            reye_pose,
            left_hand_pose,
            right_hand_pose,
        ],
        dim=1,
    )

    # Canonical poses: body joint index -> rotation about the z axis (degrees).
    canonical_z_rotations = {
        "a-pose": {15: -45, 16: 45},
        "da-pose": {0: 30, 1: -30},
    }
    if pose_type == "t-pose":
        full_pose *= 0.0
    elif pose_type in canonical_z_rotations:
        body_pose = torch.zeros_like(body_pose).view(body_pose.shape[0], -1, 3)
        for joint_idx, angle_deg in canonical_z_rotations[pose_type].items():
            body_pose[:, joint_idx] = torch.tensor([0., 0., angle_deg * np.pi / 180.])
        body_pose = body_pose.view(body_pose.shape[0], -1)
        # Everything except the canonical body pose is zeroed out.
        full_pose = torch.cat(
            [
                global_orient * 0.,
                body_pose,
                jaw_pose * 0.,
                leye_pose * 0.,
                reye_pose * 0.,
                left_hand_pose * 0.,
                right_hand_pose * 0.,
            ],
            dim=1,
        )

    # NOTE: the model's mean pose (self.pose_mean) is intentionally not added
    # to full_pose in this implementation.

    batch_size = max(betas.shape[0], global_orient.shape[0], body_pose.shape[0])

    # Concatenate the shape and expression coefficients
    scale = int(batch_size / betas.shape[0])
    if scale > 1:
        betas = betas.expand(scale, -1)
    shape_components = torch.cat([betas, expression], dim=-1)
    shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)

    lbs_args = (
        shape_components,
        full_pose,
        self.v_template,
        shapedirs,
        self.posedirs,
        self.J_regressor,
        self.parents,
        self.lbs_weights,
    )
    # Bound up front so the output construction never touches an
    # undefined name, whichever lbs variant is called.
    joint_transformation = None
    vertex_transformation = None
    if return_joint_transformation or return_vertex_transformation:
        vertices, joints, joint_transformation, vertex_transformation = lbs(
            *lbs_args,
            pose2rot=pose2rot,
            return_transformation=True,
        )
    else:
        vertices, joints = lbs(*lbs_args, pose2rot=pose2rot)

    lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1).contiguous()
    # Use the runtime batch size (not the constructor's self.batch_size) so the
    # barycentric coordinates always match lmk_faces_idx and the vertices.
    lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(batch_size, 1, 1)
    if self.use_face_contour:
        dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords(
            vertices,
            full_pose,
            self.dynamic_lmk_faces_idx,
            self.dynamic_lmk_bary_coords,
            self.neck_kin_chain,
            pose2rot=True,
        )
        lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
        lmk_bary_coords = torch.cat([
            lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords
        ], 1)
    landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)

    # Add any extra joints that might be needed
    joints = self.vertex_joint_selector(vertices, joints)
    # Add the landmarks to the joints
    joints = torch.cat([joints, landmarks], dim=1)
    # Map the joints to the current dataset
    if self.joint_mapper is not None:
        joints = self.joint_mapper(joints=joints, vertices=vertices)

    if apply_trans:
        joints += transl.unsqueeze(dim=1)
        vertices += transl.unsqueeze(dim=1)

    output = SMPLXOutput(
        vertices=vertices if return_verts else None,
        joints=joints,
        betas=betas,
        expression=expression,
        global_orient=global_orient,
        body_pose=body_pose,
        left_hand_pose=left_hand_pose,
        right_hand_pose=right_hand_pose,
        jaw_pose=jaw_pose,
        full_pose=full_pose if return_full_pose else None,
        joint_transformation=joint_transformation if return_joint_transformation else None,
        vertex_transformation=vertex_transformation if return_vertex_transformation else None,
    )
    return output
class SMPLXLayer(SMPLX):
    """SMPL-X variant without member pose/shape variables.

    All parameters are supplied to `forward`, with poses given as rotation
    matrices (the underlying `lbs` call uses ``pose2rot=False``).
    """

    def __init__(self, *args, **kwargs) -> None:
        # Just create a SMPLX module without any member variables
        super(SMPLXLayer, self).__init__(
            create_global_orient=False,
            create_body_pose=False,
            create_left_hand_pose=False,
            create_right_hand_pose=False,
            create_jaw_pose=False,
            create_leye_pose=False,
            create_reye_pose=False,
            create_betas=False,
            create_expression=False,
            create_transl=False,
            *args,
            **kwargs,
        )

    @staticmethod
    def _identity_rotations(batch_size, num_joints, device, dtype):
        """Return a contiguous (batch_size, num_joints, 3, 3) stack of identity matrices."""
        eye = torch.eye(3, device=device, dtype=dtype)
        return eye.view(1, 1, 3, 3).expand(batch_size, num_joints, -1, -1).contiguous()

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        body_pose: Optional[Tensor] = None,
        left_hand_pose: Optional[Tensor] = None,
        right_hand_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        expression: Optional[Tensor] = None,
        jaw_pose: Optional[Tensor] = None,
        leye_pose: Optional[Tensor] = None,
        reye_pose: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        **kwargs,
    ) -> SMPLXOutput:
        """
        Forward pass for the SMPLX model

        Any pose left as None is replaced by identity rotations, and any of
        `betas`, `expression` or `transl` left as None is replaced by zeros.
        The batch size is the largest leading dimension among the tensors
        that were passed in (1 if none were).

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3x3
            The global rotation of the body, in rotation matrix format.
            Useful if someone wishes to predict this with an external
            model. (default=None)
        betas: torch.tensor, optional, shape BxN_b
            Shape coefficients. For example, it can be used if shape
            parameters `betas` are predicted from some external model.
            (default=None)
        expression: torch.tensor, optional, shape BxN_e
            Expression coefficients.
            For example, it can be used if expression parameters
            `expression` are predicted from some external model.
        body_pose: torch.tensor, optional, shape BxJx3x3
            The pose of the body joints. It should be a tensor that
            contains joint rotations in rotation matrix format.
            (default=None)
        left_hand_pose: torch.tensor, optional, shape Bx15x3x3
            If given, contains the pose of the left hand.
            It should be a tensor that contains joint rotations in
            rotation matrix format. (default=None)
        right_hand_pose: torch.tensor, optional, shape Bx15x3x3
            If given, contains the pose of the right hand.
            It should be a tensor that contains joint rotations in
            rotation matrix format. (default=None)
        jaw_pose: torch.tensor, optional, shape Bx3x3
            Jaw pose. It should contain joint rotations in
            rotation matrix format.
        leye_pose: torch.tensor, optional, shape Bx3x3
            Left eye pose in rotation matrix format.
        reye_pose: torch.tensor, optional, shape Bx3x3
            Right eye pose in rotation matrix format.
        transl: torch.tensor, optional, shape Bx3
            Translation vector of the body.
            For example, it can be used if the translation
            `transl` is predicted from some external model.
            (default=None)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Returns the full pose vector (default=False)

        Returns
        -------
        output: SMPLXOutput
            A data class that contains the posed vertices and joints
        """
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # Infer the batch size from every tensor input, including the eye
        # poses, so that the defaults created below match whatever the
        # caller supplied.
        model_vars = [
            betas,
            global_orient,
            body_pose,
            transl,
            expression,
            left_hand_pose,
            right_hand_pose,
            jaw_pose,
            leye_pose,
            reye_pose,
        ]
        batch_size = 1
        for var in model_vars:
            if var is not None:
                batch_size = max(batch_size, len(var))

        if global_orient is None:
            global_orient = self._identity_rotations(batch_size, 1, device, dtype)
        if body_pose is None:
            body_pose = self._identity_rotations(
                batch_size, self.NUM_BODY_JOINTS, device, dtype
            )
        if left_hand_pose is None:
            left_hand_pose = self._identity_rotations(
                batch_size, self.NUM_HAND_JOINTS, device, dtype
            )
        if right_hand_pose is None:
            right_hand_pose = self._identity_rotations(
                batch_size, self.NUM_HAND_JOINTS, device, dtype
            )
        if jaw_pose is None:
            jaw_pose = self._identity_rotations(batch_size, 1, device, dtype)
        if leye_pose is None:
            leye_pose = self._identity_rotations(batch_size, 1, device, dtype)
        if reye_pose is None:
            reye_pose = self._identity_rotations(batch_size, 1, device, dtype)
        if expression is None:
            expression = torch.zeros([batch_size, self.num_expression_coeffs],
                                     dtype=dtype,
                                     device=device)
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        # Concatenate all pose vectors
        full_pose = torch.cat(
            [
                global_orient.reshape(-1, 1, 3, 3),
                body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
                jaw_pose.reshape(-1, 1, 3, 3),
                leye_pose.reshape(-1, 1, 3, 3),
                reye_pose.reshape(-1, 1, 3, 3),
                left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
                right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
            ],
            dim=1,
        )
        shape_components = torch.cat([betas, expression], dim=-1)
        shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)

        vertices, joints = lbs(
            shape_components,
            full_pose,
            self.v_template,
            shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )

        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1).contiguous()
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(batch_size, 1, 1)
        if self.use_face_contour:
            dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords(
                vertices,
                full_pose,
                self.dynamic_lmk_faces_idx,
                self.dynamic_lmk_bary_coords,
                self.neck_kin_chain,
                pose2rot=False,
            )
            lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
            lmk_bary_coords = torch.cat([
                lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords
            ], 1)
        landmarks = vertices2landmarks(
            vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords
        )

        # Add any extra joints that might be needed
        joints = self.vertex_joint_selector(vertices, joints)
        # Add the landmarks to the joints
        joints = torch.cat([joints, landmarks], dim=1)
        # Map the joints to the current dataset
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints=joints, vertices=vertices)

        # transl is always a tensor here (zeros when the caller gave none).
        joints += transl.unsqueeze(dim=1)
        vertices += transl.unsqueeze(dim=1)

        output = SMPLXOutput(
            vertices=vertices if return_verts else None,
            joints=joints,
            betas=betas,
            expression=expression,
            global_orient=global_orient,
            body_pose=body_pose,
            left_hand_pose=left_hand_pose,
            right_hand_pose=right_hand_pose,
            jaw_pose=jaw_pose,
            transl=transl,
            full_pose=full_pose if return_full_pose else None,
        )
        return output
class MANO(SMPL):
    """MANO hand model built on top of the SMPL machinery."""

    # The hand joints are replaced by MANO
    NUM_BODY_JOINTS = 1
    NUM_HAND_JOINTS = 15
    NUM_JOINTS = NUM_BODY_JOINTS + NUM_HAND_JOINTS

    def __init__(
        self,
        model_path: str,
        is_rhand: bool = True,
        data_struct: Optional[Struct] = None,
        create_hand_pose: bool = True,
        hand_pose: Optional[Tensor] = None,
        use_pca: bool = True,
        num_pca_comps: int = 6,
        flat_hand_mean: bool = False,
        batch_size: int = 1,
        dtype=torch.float32,
        vertex_ids=None,
        use_compressed: bool = True,
        ext: str = "pkl",
        **kwargs,
    ) -> None:
        """MANO model constructor

        Parameters
        ----------
        model_path: str
            The path to the folder or to the file where the model
            parameters are stored
        is_rhand: bool, optional
            Selects the right (True) or left (False) hand model file when
            `model_path` is a folder. When `model_path` is a file, the
            flag is instead derived from whether its basename contains
            "RIGHT". (default = True)
        data_struct: Struct
            A struct object. If given, then the parameters of the model are
            read from the object. Otherwise, the model tries to read the
            parameters from the given `model_path`. (default = None)
        create_hand_pose: bool, optional
            Flag for creating a member variable for the pose of the
            hand. (default = True)
        hand_pose: torch.tensor, optional, BxP
            The default value for the hand pose member variable.
            (default = None)
        use_pca: bool, optional
            Whether the hand pose is expressed with PCA coefficients.
            Forced to False when `num_pca_comps` is 45. (default = True)
        num_pca_comps: int, optional
            The number of PCA components to use for each hand.
            (default = 6)
        flat_hand_mean: bool, optional
            If True, the mean hand pose is all zeros; if False, the mean
            stored in the model data (`hands_mean`) is used.
            (default = False)
        batch_size: int, optional
            The batch size used for creating the member variables
        dtype: torch.dtype, optional
            The data type for the created variables
        vertex_ids: dict, optional
            A dictionary containing the indices of the extra vertices that
            will be selected
        use_compressed: bool, optional
            Forwarded unchanged to the parent constructor.
        ext: str, optional
            File format of the model file, "pkl" or "npz". (default = "pkl")

        Raises
        ------
        ValueError
            If `ext` is neither "pkl" nor "npz".
        """
        # Set before the parent constructor runs, as in the original code.
        self.num_pca_comps = num_pca_comps
        self.is_rhand = is_rhand
        # If no data structure is passed, then load the data from the given
        # model folder
        if data_struct is None:
            # Load the model
            if osp.isdir(model_path):
                model_fn = "MANO_{}.{ext}".format("RIGHT" if is_rhand else "LEFT", ext=ext)
                mano_path = os.path.join(model_path, model_fn)
            else:
                mano_path = model_path
                self.is_rhand = "RIGHT" in os.path.basename(model_path)
            assert osp.exists(mano_path), "Path {} does not exist!".format(mano_path)

            # NOTE(security): both loaders below unpickle data and must only be
            # used on trusted model files.
            if ext == "pkl":
                with open(mano_path, "rb") as mano_file:
                    model_data = pickle.load(mano_file, encoding="latin1")
            elif ext == "npz":
                # Read all arrays eagerly so the archive's file handle can be
                # closed instead of leaking.
                with np.load(mano_path, allow_pickle=True) as npz_file:
                    model_data = dict(npz_file)
            else:
                raise ValueError("Unknown extension: {}".format(ext))
            data_struct = Struct(**model_data)

        if vertex_ids is None:
            vertex_ids = VERTEX_IDS["smplh"]

        super(MANO, self).__init__(
            model_path=model_path,
            data_struct=data_struct,
            batch_size=batch_size,
            vertex_ids=vertex_ids,
            use_compressed=use_compressed,
            dtype=dtype,
            ext=ext,
            **kwargs,
        )

        # add only MANO tips to the extra joints
        self.vertex_joint_selector.extra_joints_idxs = to_tensor(
            list(VERTEX_IDS["mano"].values()), dtype=torch.long
        )

        # 45 components cover the whole 15 x 3 pose, so PCA is switched off.
        self.use_pca = use_pca and self.num_pca_comps != 45
        self.flat_hand_mean = flat_hand_mean

        hand_components = data_struct.hands_components[:num_pca_comps]
        self.np_hand_components = hand_components
        if self.use_pca:
            self.register_buffer("hand_components", torch.tensor(hand_components, dtype=dtype))

        if self.flat_hand_mean:
            hand_mean = np.zeros_like(data_struct.hands_mean)
        else:
            hand_mean = data_struct.hands_mean
        self.register_buffer("hand_mean", to_tensor(hand_mean, dtype=self.dtype))

        # Create the parameter for the pose of the hand
        hand_pose_dim = num_pca_comps if self.use_pca else 3 * self.NUM_HAND_JOINTS
        if create_hand_pose:
            if hand_pose is None:
                default_hand_pose = torch.zeros([batch_size, hand_pose_dim], dtype=dtype)
            else:
                default_hand_pose = torch.tensor(hand_pose, dtype=dtype)
            hand_pose_param = nn.Parameter(default_hand_pose, requires_grad=True)
            self.register_parameter("hand_pose", hand_pose_param)

        # Create the buffer for the mean pose.
        pose_mean = self.create_mean_pose(data_struct, flat_hand_mean=flat_hand_mean)
        pose_mean_tensor = pose_mean.clone().to(dtype)
        self.register_buffer("pose_mean", pose_mean_tensor)

    def name(self) -> str:
        """Return the human-readable identifier of this model."""
        return "MANO"

    def create_mean_pose(self, data_struct, flat_hand_mean=False):
        """Return the mean pose: a zero global orientation followed by `self.hand_mean`.

        Neither argument is read here; whether the hand mean is flat was
        already decided when the `hand_mean` buffer was registered.
        """
        global_orient_mean = torch.zeros([3], dtype=self.dtype)
        pose_mean = torch.cat([global_orient_mean, self.hand_mean], dim=0)
        return pose_mean

    def extra_repr(self):
        """Extend the parent's representation with the PCA and hand-mean settings."""
        msg = [super(MANO, self).extra_repr()]
        if self.use_pca:
            msg.append(f"Number of PCA components: {self.num_pca_comps}")
        msg.append(f"Flat hand mean: {self.flat_hand_mean}")
        return "\n".join(msg)

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        hand_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        **kwargs,
    ) -> MANOOutput:
        """Forward pass for the MANO model

        Arguments left as None fall back to the member variable of the same
        name. When PCA is enabled, `hand_pose` is projected through
        `self.hand_components` before use, and the mean pose is added to the
        concatenated pose vector.

        NOTE: both `vertices` and `joints` of the returned `MANOOutput` are
        None when `return_verts` is False.
        """
        # If no shape and pose parameters are passed along, then use the
        # ones from the module
        global_orient = self.global_orient if global_orient is None else global_orient
        betas = self.betas if betas is None else betas
        hand_pose = self.hand_pose if hand_pose is None else hand_pose

        apply_trans = transl is not None or hasattr(self, "transl")
        if transl is None and hasattr(self, "transl"):
            transl = self.transl

        if self.use_pca:
            hand_pose = torch.einsum("bi,ij->bj", [hand_pose, self.hand_components])

        full_pose = torch.cat([global_orient, hand_pose], dim=1)
        full_pose += self.pose_mean

        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=True,
        )

        # The pre-selected extra joints (vertex_joint_selector) are
        # deliberately not appended here.
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        if apply_trans:
            joints = joints + transl.unsqueeze(dim=1)
            vertices = vertices + transl.unsqueeze(dim=1)

        output = MANOOutput(
            vertices=vertices if return_verts else None,
            joints=joints if return_verts else None,
            betas=betas,
            global_orient=global_orient,
            hand_pose=hand_pose,
            full_pose=full_pose if return_full_pose else None,
        )
        return output
class MANOLayer(MANO):
    """MANO variant without member pose/shape variables.

    All parameters are supplied to `forward`, with poses given as rotation
    matrices (the underlying `lbs` call uses ``pose2rot=False``).
    """

    def __init__(self, *args, **kwargs) -> None:
        """MANO as a layer model constructor"""
        super(MANOLayer, self).__init__(
            create_global_orient=False,
            create_hand_pose=False,
            create_betas=False,
            create_transl=False,
            *args,
            **kwargs,
        )

    def name(self) -> str:
        """Return the human-readable identifier of this model."""
        return "MANO"

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        hand_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        **kwargs,
    ) -> MANOOutput:
        """Forward pass for the MANO model

        Poses left as None are replaced by identity rotations; `betas` and
        `transl` left as None are replaced by zeros. The batch size is the
        largest leading dimension among the tensors that were passed in
        (1 if none were).

        NOTE: both `vertices` and `joints` of the returned `MANOOutput` are
        None when `return_verts` is False.
        """
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # Infer the batch size from every supplied tensor rather than from
        # global_orient alone, so the defaults created below match inputs
        # such as a batched hand_pose passed on its own.
        batch_size = 1
        for var in (betas, global_orient, hand_pose, transl):
            if var is not None:
                batch_size = max(batch_size, len(var))

        identity = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3)
        if global_orient is None:
            global_orient = identity.expand(batch_size, -1, -1, -1).contiguous()
        if hand_pose is None:
            hand_pose = identity.expand(
                batch_size, self.NUM_HAND_JOINTS, -1, -1
            ).contiguous()
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        # Reshape before concatenating (as SMPLXLayer does) so that a
        # Bx3x3 global orientation is accepted as well as Bx1x3x3.
        full_pose = torch.cat(
            [
                global_orient.reshape(-1, 1, 3, 3),
                hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
            ],
            dim=1,
        )

        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )

        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        # transl is always a tensor here (zeros when the caller gave none).
        joints = joints + transl.unsqueeze(dim=1)
        vertices = vertices + transl.unsqueeze(dim=1)

        output = MANOOutput(
            vertices=vertices if return_verts else None,
            joints=joints if return_verts else None,
            betas=betas,
            global_orient=global_orient,
            hand_pose=hand_pose,
            full_pose=full_pose if return_full_pose else None,
        )
        return output
class FLAME(SMPL):
    """FLAME model with optional learnable neck, jaw, eye and expression
    parameters.

    On top of the base class it registers expression blend shapes
    (``expr_dirs``) and the barycentric embeddings used to regress facial
    landmarks from the posed mesh.
    """

    NUM_JOINTS = 5
    SHAPE_SPACE_DIM = 300
    EXPRESSION_SPACE_DIM = 100
    NECK_IDX = 0

    def __init__(
        self,
        model_path: str,
        data_struct=None,
        num_expression_coeffs=10,
        create_expression: bool = True,
        expression: Optional[Tensor] = None,
        create_neck_pose: bool = True,
        neck_pose: Optional[Tensor] = None,
        create_jaw_pose: bool = True,
        jaw_pose: Optional[Tensor] = None,
        create_leye_pose: bool = True,
        leye_pose: Optional[Tensor] = None,
        create_reye_pose=True,
        reye_pose: Optional[Tensor] = None,
        use_face_contour=False,
        batch_size: int = 1,
        gender: str = "neutral",
        dtype: torch.dtype = torch.float32,
        ext="pkl",
        **kwargs,
    ) -> None:
        """FLAME model constructor

        Parameters
        ----------
        model_path: str
            The path to the folder that contains ``FLAME_<GENDER>.<ext>``,
            ``flame_static_embedding.pkl`` and, when ``use_face_contour``
            is set, ``flame_dynamic_embedding.npy``.
        data_struct: Struct, optional
            Ignored: the model data is always loaded from ``model_path``.
        num_expression_coeffs: int, optional
            Number of expression components to use. Clamped to what the
            loaded model provides. (default = 10)
        create_expression: bool, optional
            Flag for creating a member variable for the expression space
            (default = True).
        expression: torch.tensor, optional, Bx10
            The default value for the expression member variable.
            (default = None)
        create_neck_pose: bool, optional
            Flag for creating a member variable for the neck pose.
            (default = True)
        neck_pose: torch.tensor, optional, Bx3
            The default value for the neck pose variable.
            (default = None)
        create_jaw_pose: bool, optional
            Flag for creating a member variable for the jaw pose.
            (default = True)
        jaw_pose: torch.tensor, optional, Bx3
            The default value for the jaw pose variable.
            (default = None)
        create_leye_pose: bool, optional
            Flag for creating a member variable for the left eye pose.
            (default = True)
        leye_pose: torch.tensor, optional, Bx3
            The default value for the left eye pose variable.
            (default = None)
        create_reye_pose: bool, optional
            Flag for creating a member variable for the right eye pose.
            (default = True)
        reye_pose: torch.tensor, optional, Bx3
            The default value for the right eye pose variable.
            (default = None)
        use_face_contour: bool, optional
            Whether to compute the keypoints that form the facial contour
        batch_size: int, optional
            The batch size used for creating the member variables
        gender: str, optional
            Which gender to load
        dtype: torch.dtype
            The data type for the created variables
        ext: str, optional
            Extension of the model file, either ``"pkl"`` or ``"npz"``.
            (default = "pkl")

        Raises
        ------
        AssertionError: If the model file does not exist.
        ValueError: If ``ext`` is neither ``"npz"`` nor ``"pkl"``.
        """
        model_fn = f"FLAME_{gender.upper()}.{ext}"
        flame_path = os.path.join(model_path, model_fn)
        assert osp.exists(flame_path), "Path {} does not exist!".format(flame_path)
        if ext == "npz":
            file_data = np.load(flame_path, allow_pickle=True)
        elif ext == "pkl":
            # NOTE(review): unpickling can execute arbitrary code; only load
            # model files that come from a trusted source.
            with open(flame_path, "rb") as smpl_file:
                file_data = pickle.load(smpl_file, encoding="latin1")
        else:
            raise ValueError("Unknown extension: {}".format(ext))
        data_struct = Struct(**file_data)

        super(FLAME, self).__init__(
            model_path=model_path,
            data_struct=data_struct,
            dtype=dtype,
            batch_size=batch_size,
            gender=gender,
            ext=ext,
            **kwargs,
        )

        self.use_face_contour = use_face_contour
        self.vertex_joint_selector.extra_joints_idxs = to_tensor([], dtype=torch.long)

        # Register the optional learnable pose parameters. Each one defaults
        # to zeros of shape [batch_size, 3] unless an initial value is given.
        pose_specs = (
            ("neck_pose", create_neck_pose, neck_pose),
            ("jaw_pose", create_jaw_pose, jaw_pose),
            ("leye_pose", create_leye_pose, leye_pose),
            ("reye_pose", create_reye_pose, reye_pose),
        )
        for pose_name, should_create, initial_value in pose_specs:
            if not should_create:
                continue
            if initial_value is None:
                default_pose = torch.zeros([batch_size, 3], dtype=dtype)
            else:
                default_pose = torch.tensor(initial_value, dtype=dtype)
            self.register_parameter(pose_name, nn.Parameter(default_pose, requires_grad=True))

        shapedirs = data_struct.shapedirs
        if len(shapedirs.shape) < 3:
            shapedirs = shapedirs[:, :, None]

        # Clamp the requested number of coefficients BEFORE computing the
        # slice so that `expr_dirs` and the `expression` parameter always
        # have the same number of components.
        if shapedirs.shape[-1] < self.SHAPE_SPACE_DIM + self.EXPRESSION_SPACE_DIM:
            # Reduced model: 10 shape followed by 10 expression components.
            expr_start_idx = 10
            num_expression_coeffs = min(num_expression_coeffs, 10)
        else:
            expr_start_idx = self.SHAPE_SPACE_DIM
            num_expression_coeffs = min(num_expression_coeffs, self.EXPRESSION_SPACE_DIM)
        expr_end_idx = expr_start_idx + num_expression_coeffs
        self._num_expression_coeffs = num_expression_coeffs

        expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx]
        self.register_buffer("expr_dirs", to_tensor(to_np(expr_dirs), dtype=dtype))

        if create_expression:
            if expression is None:
                default_expression = torch.zeros([batch_size, self.num_expression_coeffs],
                                                 dtype=dtype)
            else:
                default_expression = torch.tensor(expression, dtype=dtype)
            expression_param = nn.Parameter(default_expression, requires_grad=True)
            self.register_parameter("expression", expression_param)

        # The pickle file that contains the barycentric coordinates for
        # regressing the landmarks
        landmark_bcoord_filename = osp.join(model_path, "flame_static_embedding.pkl")
        with open(landmark_bcoord_filename, "rb") as fp:
            landmarks_data = pickle.load(fp, encoding="latin1")

        lmk_faces_idx = landmarks_data["lmk_face_idx"].astype(np.int64)
        self.register_buffer("lmk_faces_idx", torch.tensor(lmk_faces_idx, dtype=torch.long))
        lmk_bary_coords = landmarks_data["lmk_b_coords"]
        self.register_buffer("lmk_bary_coords", torch.tensor(lmk_bary_coords, dtype=dtype))

        if self.use_face_contour:
            face_contour_path = os.path.join(model_path, "flame_dynamic_embedding.npy")
            # The .npy file stores a pickled dict inside a 0-d array; `[()]`
            # extracts the dict itself.
            contour_embeddings = np.load(face_contour_path, allow_pickle=True,
                                         encoding="latin1")[()]

            dynamic_lmk_faces_idx = np.array(contour_embeddings["lmk_face_idx"], dtype=np.int64)
            dynamic_lmk_faces_idx = torch.tensor(dynamic_lmk_faces_idx, dtype=torch.long)
            self.register_buffer("dynamic_lmk_faces_idx", dynamic_lmk_faces_idx)

            dynamic_lmk_b_coords = torch.tensor(contour_embeddings["lmk_b_coords"], dtype=dtype)
            self.register_buffer("dynamic_lmk_bary_coords", dynamic_lmk_b_coords)

            neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents)
            self.register_buffer("neck_kin_chain", torch.tensor(neck_kin_chain, dtype=torch.long))

    @property
    def num_expression_coeffs(self):
        """Number of expression coefficients used by this model.

        Exposed as a property because the class reads it as a plain
        attribute (tensor sizes, ``extra_repr``).
        """
        return self._num_expression_coeffs

    def name(self) -> str:
        """Return the identifier string of this model type ("FLAME")."""
        return "FLAME"

    def extra_repr(self):
        """Extend the parent representation with the FLAME-specific settings."""
        msg = [
            super(FLAME, self).extra_repr(),
            f"Number of Expression Coefficients: {self.num_expression_coeffs}",
            f"Use face contour: {self.use_face_contour}",
        ]
        return "\n".join(msg)

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        neck_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        expression: Optional[Tensor] = None,
        jaw_pose: Optional[Tensor] = None,
        leye_pose: Optional[Tensor] = None,
        reye_pose: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        pose2rot: bool = True,
        **kwargs,
    ) -> FLAMEOutput:
        """
        Forward pass for the FLAME model

        Parameters
        ----------
        betas: torch.tensor, optional, shape Bx10
            If given, ignore the member variable `betas` and use it
            instead. For example, it can used if shape parameters
            `betas` are predicted from some external model.
            (default=None)
        global_orient: torch.tensor, optional, shape Bx3
            If given, ignore the member variable and use it as the global
            rotation of the body. Useful if someone wishes to predicts this
            with an external model. (default=None)
        neck_pose: torch.tensor, optional, shape Bx3
            If given, ignore the member variable `neck_pose` and
            use this instead. (default=None)
        transl: torch.tensor, optional, shape Bx3
            If given, ignore the member variable `transl` and use it
            instead. For example, it can used if the translation
            `transl` is predicted from some external model.
            (default=None)
        expression: torch.tensor, optional, shape Bx10
            If given, ignore the member variable `expression` and use it
            instead. For example, it can used if expression parameters
            `expression` are predicted from some external model.
        jaw_pose: torch.tensor, optional, shape Bx3
            If given, ignore the member variable `jaw_pose` and
            use this instead. (default=None)
        leye_pose: torch.tensor, optional, shape Bx3
            If given, ignore the member variable `leye_pose` and
            use this instead. (default=None)
        reye_pose: torch.tensor, optional, shape Bx3
            If given, ignore the member variable `reye_pose` and
            use this instead. (default=None)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Returns the full pose vector (default=False)
        pose2rot: bool, optional
            Forwarded to `lbs` and to the dynamic landmark lookup; tells
            them whether the poses still have to be converted to rotation
            matrices. (default=True)

        Returns
        -------
        output: FLAMEOutput
        """
        # If no shape and pose parameters are passed along, then use the
        # ones from the module
        global_orient = self.global_orient if global_orient is None else global_orient
        jaw_pose = self.jaw_pose if jaw_pose is None else jaw_pose
        neck_pose = self.neck_pose if neck_pose is None else neck_pose
        leye_pose = self.leye_pose if leye_pose is None else leye_pose
        reye_pose = self.reye_pose if reye_pose is None else reye_pose
        betas = self.betas if betas is None else betas
        expression = self.expression if expression is None else expression

        if transl is None and hasattr(self, "transl"):
            transl = self.transl
        apply_trans = transl is not None

        full_pose = torch.cat([global_orient, neck_pose, jaw_pose, leye_pose, reye_pose], dim=1)

        batch_size = max(betas.shape[0], global_orient.shape[0], jaw_pose.shape[0])
        # Broadcast a single set of shape coefficients over the batch.
        scale = int(batch_size / betas.shape[0])
        if scale > 1:
            betas = betas.expand(scale, -1)

        # Concatenate the shape and expression coefficients
        shape_components = torch.cat([betas, expression], dim=-1)
        shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)

        vertices, joints = lbs(
            shape_components,
            full_pose,
            self.v_template,
            shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=pose2rot,
        )

        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1).contiguous()
        # Tile with the batch size of THIS call, not the one the module was
        # constructed with, so that inputs of any batch size are handled.
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(batch_size, 1, 1)

        if self.use_face_contour:
            # Use the caller's pose2rot here too, so the pose is interpreted
            # the same way as in the `lbs` call above.
            dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords(
                vertices,
                full_pose,
                self.dynamic_lmk_faces_idx,
                self.dynamic_lmk_bary_coords,
                self.neck_kin_chain,
                pose2rot=pose2rot,
            )
            lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
            lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1)

        landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)

        # Add any extra joints that might be needed
        joints = self.vertex_joint_selector(vertices, joints)
        # Add the landmarks to the joints
        joints = torch.cat([joints, landmarks], dim=1)

        # Map the joints to the current dataset
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints=joints, vertices=vertices)

        if apply_trans:
            offset = transl.unsqueeze(dim=1)
            joints = joints + offset
            vertices = vertices + offset

        output = FLAMEOutput(
            vertices=vertices if return_verts else None,
            joints=joints,
            betas=betas,
            expression=expression,
            global_orient=global_orient,
            neck_pose=neck_pose,
            jaw_pose=jaw_pose,
            full_pose=full_pose if return_full_pose else None,
        )
        return output
class FLAMELayer(FLAME):
    """FLAME as a stateless layer: no learnable parameters are created and
    all poses passed to ``forward`` are rotation matrices."""

    def __init__(self, *args, **kwargs) -> None:
        """ FLAME as a layer model constructor """
        super(FLAMELayer, self).__init__(
            *args,
            create_betas=False,
            create_expression=False,
            create_global_orient=False,
            create_neck_pose=False,
            create_jaw_pose=False,
            create_leye_pose=False,
            create_reye_pose=False,
            **kwargs,
        )

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        neck_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        expression: Optional[Tensor] = None,
        jaw_pose: Optional[Tensor] = None,
        leye_pose: Optional[Tensor] = None,
        reye_pose: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        pose2rot: bool = True,
        **kwargs,
    ) -> FLAMEOutput:
        """
        Forward pass for the FLAME layer

        Every input that is not supplied is replaced by a neutral default
        (identity rotation or zeros). The batch size is the largest leading
        dimension among the supplied inputs, or 1 when none is given.

        Parameters
        ----------
        betas: torch.tensor, optional, shape BxN_b
            Shape parameters. For example, it can used if shape parameters
            `betas` are predicted from some external model.
            (default=None)
        global_orient: torch.tensor, optional
            Global rotation of the body in rotation matrix format. The
            default created here has shape Bx1x3x3 and all poses are
            concatenated along dim 1, so supplied poses should use a
            matching layout. (default=None)
        neck_pose: torch.tensor, optional
            Neck rotation in rotation matrix format. (default=None)
        transl: torch.tensor, optional, shape Bx3
            Translation vector of the body.
            For example, it can used if the translation
            `transl` is predicted from some external model.
            (default=None)
        expression: torch.tensor, optional, shape BxN_e
            Expression parameters. For example, it can used if expression
            parameters `expression` are predicted from some external model.
            (default=None)
        jaw_pose: torch.tensor, optional
            Jaw rotation in rotation matrix format. (default=None)
        leye_pose: torch.tensor, optional
            Left eye rotation in rotation matrix format. (default=None)
        reye_pose: torch.tensor, optional
            Right eye rotation in rotation matrix format. (default=None)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Returns the concatenated rotation matrices (default=False)
        pose2rot: bool, optional
            Accepted for signature compatibility; this layer always runs
            with ``pose2rot=False``.

        Returns
        -------
        output: FLAMEOutput
        """
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # Infer the batch size from every supplied input, not only from
        # `global_orient`; otherwise size-1 defaults would be built that
        # cannot be concatenated with the supplied tensors.
        model_vars = (
            betas, global_orient, neck_pose, transl, expression, jaw_pose, leye_pose, reye_pose
        )
        batch_size = max((var.shape[0] for var in model_vars if var is not None), default=1)

        def identity_rotation() -> Tensor:
            # One 3x3 identity per batch element, shape [batch_size, 1, 3, 3].
            eye = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3)
            return eye.expand(batch_size, -1, -1, -1).contiguous()

        if global_orient is None:
            global_orient = identity_rotation()
        if neck_pose is None:
            neck_pose = identity_rotation()
        if jaw_pose is None:
            jaw_pose = identity_rotation()
        if leye_pose is None:
            leye_pose = identity_rotation()
        if reye_pose is None:
            reye_pose = identity_rotation()
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
        if expression is None:
            expression = torch.zeros([batch_size, self.num_expression_coeffs],
                                     dtype=dtype,
                                     device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        full_pose = torch.cat([global_orient, neck_pose, jaw_pose, leye_pose, reye_pose], dim=1)

        shape_components = torch.cat([betas, expression], dim=-1)
        shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)

        # The poses are already rotation matrices, hence pose2rot=False.
        vertices, joints = lbs(
            shape_components,
            full_pose,
            self.v_template,
            shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )

        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1).contiguous()
        # Tile with the batch size of THIS call, not the one the module was
        # constructed with.
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(batch_size, 1, 1)

        if self.use_face_contour:
            dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords(
                vertices,
                full_pose,
                self.dynamic_lmk_faces_idx,
                self.dynamic_lmk_bary_coords,
                self.neck_kin_chain,
                pose2rot=False,
            )
            lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
            lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1)

        landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)

        # Add any extra joints that might be needed
        joints = self.vertex_joint_selector(vertices, joints)
        # Add the landmarks to the joints
        joints = torch.cat([joints, landmarks], dim=1)

        # Map the joints to the current dataset
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints=joints, vertices=vertices)

        offset = transl.unsqueeze(dim=1)
        joints = joints + offset
        vertices = vertices + offset

        output = FLAMEOutput(
            vertices=vertices if return_verts else None,
            joints=joints,
            betas=betas,
            expression=expression,
            global_orient=global_orient,
            neck_pose=neck_pose,
            jaw_pose=jaw_pose,
            full_pose=full_pose if return_full_pose else None,
        )
        return output
def build_layer(model_path: str,
                model_type: str = "smpl",
                **kwargs) -> Union[SMPLLayer, SMPLHLayer, SMPLXLayer, MANOLayer, FLAMELayer]:
    """Build the layer variant of a body model from a path and a model type

    Parameters
    ----------
    model_path: str
        Either the path to the model you wish to load or a folder,
        where each subfolder contains the different types, i.e.:
        model_path:
        |
        |-- smpl
            |-- SMPL_FEMALE
            |-- SMPL_NEUTRAL
            |-- SMPL_MALE
        |-- smplh
            |-- SMPLH_FEMALE
            |-- SMPLH_MALE
        |-- smplx
            |-- SMPLX_FEMALE
            |-- SMPLX_NEUTRAL
            |-- SMPLX_MALE
        |-- mano
            |-- MANO RIGHT
            |-- MANO LEFT
        |-- flame
            |-- FLAME_FEMALE
            |-- FLAME_MALE
            |-- FLAME_NEUTRAL
    model_type: str, optional
        When model_path is a folder, then this parameter specifies the
        type of model to be loaded. Otherwise the type is derived from
        the part of the file name that precedes the first underscore.
    **kwargs: dict
        Keyword arguments forwarded to the layer constructor

    Returns
    -------
    body_model: nn.Module
        The PyTorch module that implements the corresponding body model

    Raises
    ------
    ValueError: In case the model type is not one of SMPL, SMPLH,
        SMPLX, MANO or FLAME
    """
    if osp.isdir(model_path):
        # A folder: descend into the sub-folder of the requested type.
        model_path = os.path.join(model_path, model_type)
    else:
        # A file: the model type is the file-name prefix.
        model_type = osp.basename(model_path).split("_")[0].lower()

    kind = model_type.lower()

    # SMPL-family names must match exactly ...
    exact_layers = {"smpl": SMPLLayer, "smplh": SMPLHLayer, "smplx": SMPLXLayer}
    if kind in exact_layers:
        return exact_layers[kind](model_path, **kwargs)
    # ... whereas MANO and FLAME only need to appear in the name.
    if "mano" in kind:
        return MANOLayer(model_path, **kwargs)
    if "flame" in kind:
        return FLAMELayer(model_path, **kwargs)
    raise ValueError(f"Unknown model type {model_type}, exiting!")
def create(model_path: str,
           model_type: str = "smpl",
           **kwargs) -> Union[SMPL, SMPLH, SMPLX, MANO, FLAME]:
    """Create a body model from a path and a model type

    Parameters
    ----------
    model_path: str
        Either the path to the model you wish to load or a folder,
        where each subfolder contains the different types, i.e.:
        model_path:
        |
        |-- smpl
            |-- SMPL_FEMALE
            |-- SMPL_NEUTRAL
            |-- SMPL_MALE
        |-- smplh
            |-- SMPLH_FEMALE
            |-- SMPLH_MALE
        |-- smplx
            |-- SMPLX_FEMALE
            |-- SMPLX_NEUTRAL
            |-- SMPLX_MALE
        |-- mano
            |-- MANO RIGHT
            |-- MANO LEFT
        |-- flame
            |-- FLAME_FEMALE
            |-- FLAME_MALE
            |-- FLAME_NEUTRAL
    model_type: str, optional
        When model_path is a folder, then this parameter specifies the
        type of model to be loaded. Otherwise the type is derived from
        the part of the file name that precedes the first underscore.
    **kwargs: dict
        Keyword arguments forwarded to the model constructor

    Returns
    -------
    body_model: nn.Module
        The PyTorch module that implements the corresponding body model

    Raises
    ------
    ValueError: In case the model type is not one of SMPL, SMPLH,
        SMPLX, MANO or FLAME
    """
    if osp.isdir(model_path):
        # A folder: descend into the sub-folder of the requested type.
        model_path = os.path.join(model_path, model_type)
    else:
        # A file: the model type is the file-name prefix.
        model_type = osp.basename(model_path).split("_")[0].lower()

    kind = model_type.lower()

    # SMPL-family names must match exactly ...
    exact_models = {"smpl": SMPL, "smplh": SMPLH, "smplx": SMPLX}
    if kind in exact_models:
        return exact_models[kind](model_path, **kwargs)
    # ... whereas MANO and FLAME only need to appear in the name.
    if "mano" in kind:
        return MANO(model_path, **kwargs)
    if "flame" in kind:
        return FLAME(model_path, **kwargs)
    raise ValueError(f"Unknown model type {model_type}, exiting!")