# -*- coding: utf-8 -*-
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# Using this computer program means that you agree to the terms
# in the LICENSE file included with this software distribution.
# Any use not explicitly granted by the LICENSE is prohibited.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# For comments or questions, please email us at [email protected]
# For commercial licensing contact, please contact [email protected]
import os | |
import cv2 | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision | |
from skimage.io import imread | |
from .models.encoders import MLP, HRNEncoder, ResnetEncoder | |
from .models.moderators import TempSoftmaxFusion | |
from .models.SMPLX import SMPLX | |
from .utils import rotation_converter as converter | |
from .utils import tensor_cropper, util | |
from .utils.config import cfg | |
class PIXIE(object):
    """SMPL-X parameter regressor for body, head and hand images.

    Separate encoders embed body, head and hand images; MLP regressors predict
    groups of SMPL-X parameters. For body images, head/hand features extracted from
    the body feature are fused (via moderators) with features from head/hand crops.
    """

    # Parts cropped from a body image and fused back into the body prediction.
    _PARTS = ("head", "left_hand", "right_hand")

    def __init__(self, config=None, device="cuda:0"):
        """Build the networks, load the pretrained weights and set up the croppers.

        Args:
            config: configuration object; the module-level ``cfg`` is used when None.
            device: torch device the networks are moved to.
        """
        self.cfg = cfg if config is None else config
        self.device = device
        # parameter setting: {list name: {parameter name: dimension}}
        self.param_list_dict = self._build_param_list_dict()
        # Build the models
        self._create_model()
        # Set up the cropping modules used to generate face/hand crops from the body predictions
        self._setup_cropper()

    def _build_param_list_dict(self):
        """Return ``{list_name: {param_name: dimension}}`` for every list in ``self.cfg.params``.

        Dimensions come from ``self.cfg.model.get("n_" + param_name)``. The instance
        config is used throughout, so a ``config`` passed to ``__init__`` is honoured.
        """
        params = self.cfg.params
        return {
            list_name: {name: self.cfg.model.get("n_" + name) for name in params.get(list_name)}
            for list_name in params.keys()
        }

    def forward(self, data):
        """Encode body images (with part fusion) and decode the fused body parameters.

        Args:
            data: body image batch, passed to ``encode`` as ``{"body": {"image": data}}``.
        Returns:
            the prediction dict produced by ``decode`` for ``param_type="body"``.
        """
        param_dict = self.encode(
            {"body": {"image": data}},
            threthold=True,
            keep_local=True,
            copy_and_paste=False,
        )
        opdict = self.decode(param_dict["body"], param_type="body")
        return opdict

    def _setup_cropper(self):
        """Create one Cropper per crop type ("head", "hand") from ``self.cfg.dataset``."""
        self.Cropper = {}
        for crop_part in ["head", "hand"]:
            data_cfg = self.cfg.dataset[crop_part]
            # fixed scale at the midpoint of the configured range, no translation jitter
            scale_size = (data_cfg.scale_min + data_cfg.scale_max) * 0.5
            self.Cropper[crop_part] = tensor_cropper.Cropper(
                crop_size=data_cfg.image_size,
                scale=[scale_size, scale_size],
                trans_scale=0,
            )

    def _create_model(self):
        """Build all sub-networks and the SMPL-X model, then load the pretrained weights.

        Populates ``self.Encoder``, ``self.Regressor``, ``self.Extractor`` and
        ``self.Moderator`` (dicts keyed by the config's network names), ``self.smplx``,
        ``self.part_indices`` and ``self.model_dict`` (checkpoint key -> state_dict).

        Raises:
            ValueError: if the config requests an unsupported network type.
            FileNotFoundError: if ``cfg.pretrained_modelpath`` does not exist.
        """
        self.model_dict = {}
        # Build all image encoders
        # Hand encoder only works for right hand, for left hand, flip inputs and flip the results back
        encoder_classes = {"resnet50": ResnetEncoder, "hrnet": HRNEncoder}
        self.Encoder = {}
        for key in self.cfg.network.encoder.keys():
            encoder_type = self.cfg.network.encoder.get(key).type
            if encoder_type not in encoder_classes:
                raise ValueError(f"unsupported encoder type {encoder_type!r} for '{key}'")
            self.Encoder[key] = encoder_classes[encoder_type]().to(self.device)
            self.model_dict[f"Encoder_{key}"] = self.Encoder[key].state_dict()
        # Build the parameter regressors: 2048-d feature -> all parameters of "<key>_list"
        self.Regressor = {}
        for key in self.cfg.network.regressor.keys():
            regressor_cfg = self.cfg.network.regressor.get(key)
            if regressor_cfg.type != "mlp":
                raise ValueError(f"unsupported regressor type {regressor_cfg.type!r} for '{key}'")
            n_output = sum(self.param_list_dict[f"{key}_list"].values())
            channels = [2048] + regressor_cfg.channels + [n_output]
            self.Regressor[key] = MLP(channels=channels).to(self.device)
            self.model_dict[f"Regressor_{key}"] = self.Regressor[key].state_dict()
        # Build the extractors
        # to extract separate head/left hand/right hand feature from body feature
        self.Extractor = {}
        for key in self.cfg.network.extractor.keys():
            extractor_cfg = self.cfg.network.extractor.get(key)
            if extractor_cfg.type != "mlp":
                raise ValueError(f"unsupported extractor type {extractor_cfg.type!r} for '{key}'")
            channels = [2048] + extractor_cfg.channels + [2048]
            self.Extractor[key] = MLP(channels=channels).to(self.device)
            self.model_dict[f"Extractor_{key}"] = self.Extractor[key].state_dict()
        # Build the moderators: input is a (body, part) feature pair, output has 2 channels
        self.Moderator = {}
        for key in self.cfg.network.moderator.keys():
            moderator_cfg = self.cfg.network.moderator.get(key)
            channels = [2048 * 2] + moderator_cfg.channels + [2]
            self.Moderator[key] = TempSoftmaxFusion(
                detach_inputs=moderator_cfg.detach_inputs,
                detach_feature=moderator_cfg.detach_feature,
                channels=channels,
            ).to(self.device)
            self.model_dict[f"Moderator_{key}"] = self.Moderator[key].state_dict()
        # Build the SMPL-X body model, which we also use to represent faces and
        # hands, using the relevant parts only
        self.smplx = SMPLX(self.cfg.model).to(self.device)
        self.part_indices = self.smplx.part_indices
        # -- resume model
        model_path = self.cfg.pretrained_modelpath
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"pixie trained model path: {model_path} does not exist!")
        # map_location lets a checkpoint saved on GPU load on the configured device (e.g. CPU).
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=self.device)
        # state_dict() tensors share storage with the module parameters, so copying into
        # them updates the networks (assumes util.copy_state_dict copies in place)
        for key in self.model_dict.keys():
            util.copy_state_dict(self.model_dict[key], checkpoint[key])
        # eval mode
        for module in [self.Encoder, self.Regressor, self.Moderator, self.Extractor]:
            for net in module.values():
                net.eval()

    def decompose_code(self, code, num_dict):
        """Split a flattened parameter vector into a dictionary of parameters.

        Args:
            code: tensor whose second dimension holds the concatenated parameters.
            num_dict: ordered ``{name: dimension}``; consecutive column slices of
                ``code`` are assigned to the names in iteration order.
        Returns:
            ``{name: code[:, start:end]}`` (slices, not copies).
        """
        code_dict = {}
        start = 0
        for key in num_dict:
            end = start + int(num_dict[key])
            code_dict[key] = code[:, start:end]
            start = end
        return code_dict

    def _regress(self, regressor_key, feature, param_list_key):
        """Run ``self.Regressor[regressor_key]`` on ``feature`` and split the output by ``param_list_key``."""
        return self.decompose_code(
            self.Regressor[regressor_key](feature),
            self.param_list_dict[param_list_key],
        )

    @staticmethod
    def _rename_right_to_left(hand_dict):
        """Rename right-hand keys of a hand-share prediction to left-hand keys (in place).

        The hand-share regressor only predicts right-hand parameters; when it is run on
        a left-hand feature its outputs are relabelled here.
        """
        hand_dict["left_hand_pose"] = hand_dict.pop("right_hand_pose")
        hand_dict["left_wrist_pose"] = hand_dict.pop("right_wrist_pose")
        return hand_dict

    def _body_params_from_features(self, share_features, body_dict):
        """Compose full body parameters from body-only params and the part share features.

        Args:
            share_features: dict with "head_share", "right_hand_share" and
                "left_hand_share" features.
            body_dict: parameters predicted by the body regressor.
        Returns:
            ``{**body_dict, **head_share, **left_hand_share, **right_hand_share}``.
        """
        head_share_dict = self._regress(
            "head_share", share_features["head_share"], "head_share_list"
        )
        right_hand_share_dict = self._regress(
            "hand_share", share_features["right_hand_share"], "hand_share_list"
        )
        left_hand_share_dict = self._rename_right_to_left(
            self._regress("hand_share", share_features["left_hand_share"], "hand_share_list")
        )
        return {
            **body_dict,
            **head_share_dict,
            **left_hand_share_dict,
            **right_hand_share_dict,
        }

    @staticmethod
    def _confident_part_fusion(f_body, f_part, f_weight, threshold=0.7):
        """Blend body/part features using the part weight ``f_weight[:, 1]``.

        Weights above ``threshold`` are treated as 1.0 (fully trust the part). Advanced
        indexing copies the weight column, so ``f_weight`` itself is not modified.
        Assumes ``f_weight`` is [bz, 2] with the part weight in column 1 — TODO confirm
        against TempSoftmaxFusion.
        """
        part_w = f_weight[:, [1]]
        part_w[part_w > threshold] = 1.0
        return f_body * (1.0 - part_w) + f_part * part_w

    def part_from_body(self, image, part_key, points_dict, crop_joints=None):
        """Crop a part (head/left_hand/right_hand) out of a body image; points follow the crop.

        Args:
            image: body image tensor; its last two dims are used as the point scale.
            part_key: one of "head", "left_hand", "right_hand".
            points_dict: name -> points; must contain "smplx_kpt". Points are expected
                to be normalized to [-1, 1].
            crop_joints: optional points used to place the crop instead of
                ``points_dict["smplx_kpt"]``.
        Returns:
            (cropped_image, cropped_points_dict), with every entry of ``points_dict``
            transformed into the crop's normalized coordinates.
        """
        assert part_key in ["head", "left_hand", "right_hand"]
        assert "smplx_kpt" in points_dict.keys()
        # the head crop is placed from the face keypoints; hands use their own indices
        indices_key = "face" if part_key == "head" else part_key
        part_indices = self.part_indices[indices_key]
        source_points = points_dict["smplx_kpt"] if crop_joints is None else crop_joints
        points_for_crop = source_points[:, part_indices]
        # both hands share the "hand" cropper
        cropper = self.Cropper["hand" if "hand" in part_key else part_key]
        points_scale = image.shape[-2:]
        cropped_image, tform = cropper.crop(image, points_for_crop, points_scale)
        cropped_points_dict = {
            points_key: cropper.transform_points(points, tform, points_scale, normalize=True)
            for points_key, points in points_dict.items()
        }
        return cropped_image, cropped_points_dict

    def encode(
        self,
        data,
        threthold=True,
        keep_local=True,
        copy_and_paste=False,
        body_only=False,
    ):
        """Encode images to smplx parameters

        Args:
            data: dict
                key: image_type (body/head/hand)
                value:
                    image: [bz, 3, 224, 224], range [0,1]
                    head_image: optional, well-cropped head from body image
                    left_hand_image: optional, well-cropped left hand from body image
                    right_hand_image: optional, well-cropped right hand from body image
                For "body", if any part crop is missing, all three are cropped from a
                copy of ``image`` resized with ``torchvision.transforms.Resize(1024)``,
                placed using a coarse (unfused) body estimate, and stored back into
                ``data["body"]`` (the caller's dict is modified).
            threthold: for hands, a moderator part weight above 0.7 is treated as 1.0
                (fully trust the hand crop). Name kept for backward compatibility.
            keep_local: take expression and hand poses directly from the part crops.
            copy_and_paste: replace the body part features by the crop features
                instead of using the moderator output.
            body_only: for "body", return the coarse estimate without part fusion;
                ``param_dict["moderator_weight"]`` is then None.
        Returns:
            param_dict: dict
                key: image_type (body/head/hand); for body input also
                "body_head"/"body_left_hand"/"body_right_hand" (crop predictions)
                and "moderator_weight"
                value: param_dict
        """
        for key in data.keys():
            assert key in ["body", "head", "hand"]
        feature = {}
        param_dict = {}
        for key in data.keys():
            # encode feature
            feature[key] = {key: self.Encoder[key](data[key]["image"])}
            if key in ("head", "hand"):
                # predict head/hand-only parameters from part feature
                part_dict = self._regress(key, feature[key][key], f"{key}_list")
                # part input: skip feature fusion, the share feature is the part feature
                feature[key][f"{key}_share"] = feature[key][key]
                share_dict = self._regress(
                    f"{key}_share", feature[key][f"{key}_share"], f"{key}_share_list"
                )
                param_dict[key] = {**share_dict, **part_dict}
            elif key == "body":
                body_data = data[key]
                f_body = feature[key]["body"]
                # extract part share features from the body feature
                for part_name in self._PARTS:
                    feature[key][f"{part_name}_share"] = self.Extractor[f"{part_name}_share"](
                        f_body
                    )
                # body-only parameters are never fused; regress them unconditionally so
                # they also exist when all part crops are supplied by the caller
                body_dict = self._regress("body", f_body, "body_list")

                crops_missing = any(f"{p}_image" not in body_data for p in self._PARTS)
                if crops_missing or body_only:
                    # - run without fusion to get a coarse estimation, for cropping parts
                    param_dict[key] = self._body_params_from_features(feature[key], body_dict)
                    if body_only:
                        param_dict["moderator_weight"] = None
                        return param_dict
                    prediction_body_only = self.decode(param_dict[key], param_type="body")
                    points_dict = {
                        "smplx_kpt": prediction_body_only["smplx_kpt"],
                        "trans_verts": prediction_body_only["transformed_vertices"],
                    }
                    image_hd = torchvision.transforms.Resize(1024)(body_data["image"])
                    for part_name in self._PARTS:
                        cropped_image, _ = self.part_from_body(image_hd, part_name, points_dict)
                        body_data[f"{part_name}_image"] = cropped_image

                # -- encode features from part crops, then fuse using the moderator weight
                fusion_weight = {}
                for part_name in self._PARTS:
                    part = part_name.split("_")[-1]  # "head" or "hand"
                    share_key = f"{part_name}_share"
                    cropped_image = body_data[f"{part_name}_image"]
                    # hand networks handle right hands only: mirror the left-hand crop
                    if part_name == "left_hand":
                        cropped_image = torch.flip(cropped_image, dims=(-1, ))
                    # run part regressor
                    f_part = self.Encoder[part](cropped_image)
                    part_dict = self._regress(part, f_part, f"{part}_list")
                    part_share_dict = self._regress(f"{part}_share", f_part, f"{part}_share_list")
                    param_dict[f"body_{part_name}"] = {**part_dict, **part_share_dict}
                    # moderator assigns weights, then features are integrated
                    f_body_out, _, f_weight = self.Moderator[f"{part}_share"](
                        feature[key][share_key], f_part, work=True
                    )
                    if copy_and_paste:
                        # copy and paste strategy always trusts the results from part
                        feature[key][share_key] = f_part
                    elif threthold and part == "hand":
                        feature[key][share_key] = self._confident_part_fusion(
                            feature[key][share_key], f_part, f_weight
                        )
                    else:
                        feature[key][share_key] = f_body_out
                    fusion_weight[part_name] = f_weight
                # save weights from moderator, usable for optimization/part-specific tasks
                param_dict["moderator_weight"] = fusion_weight
                # -- predict parameters from the fused body feature
                param_dict[key] = self._body_params_from_features(feature[key], body_dict)
                # copy tex/light params from the head crop prediction
                param_dict[key]["tex"] = param_dict["body_head"]["tex"]
                param_dict[key]["light"] = param_dict["body_head"]["light"]
                if keep_local:
                    # local changes don't affect the whole body pose, so trust the parts
                    param_dict[key]["exp"] = param_dict["body_head"]["exp"]
                    param_dict[key]["right_hand_pose"] = param_dict["body_right_hand"][
                        "right_hand_pose"]
                    # the left-hand crop was mirrored, so its pose is labelled
                    # "right_hand_pose"; convert_pose flips it back for body decoding
                    param_dict[key]["left_hand_pose"] = param_dict["body_left_hand"][
                        "right_hand_pose"]
        return param_dict

    def convert_pose(self, param_dict, param_type):
        """Convert pose parameters to rotation matrices and fill in missing SMPL-X inputs.

        Every ``*pose*`` entry except the jaw goes through ``converter.batch_cont2matrix``;
        for body/head the jaw pose goes through ``converter.batch_euler2matrix``. Head and
        hand dicts are then completed with default poses/shape/expression from the SMPL-X
        model so they can be decoded like a body dict. ``param_dict`` is modified in place.

        Args:
            param_dict: smplx parameters
            param_type: should be one of body/head/hand
        Returns:
            param_dict: the same, updated dict
        Raises:
            ValueError: for an unknown param_type when asserts are disabled.
        """
        assert param_type in ["body", "head", "hand"]
        # network outputs are continuous rotation representations (jaw: Euler angles),
        # while SMPL-X expects rotation matrices
        for key in param_dict:
            if "pose" in key and "jaw" not in key:
                param_dict[key] = converter.batch_cont2matrix(param_dict[key])
        if param_type in ("body", "head"):
            param_dict["jaw_pose"] = converter.batch_euler2matrix(param_dict["jaw_pose"]
                                                                  )[:, None, :, :]

        def _batched(pose):
            # broadcast an unbatched default pose to the batch size (no copy)
            return pose.unsqueeze(0).expand(batch_size, -1, -1, -1)

        # complement params if they are not in the given param dict
        if param_type == "head":
            batch_size = param_dict["shape"].shape[0]
            n_partbody = self.param_list_dict["body_list"]["partbody_pose"]
            param_dict["abs_head_pose"] = param_dict["head_pose"].clone()
            param_dict["global_pose"] = param_dict["head_pose"]
            param_dict["partbody_pose"] = _batched(self.smplx.body_pose)[:, :n_partbody]
            param_dict["neck_pose"] = _batched(self.smplx.neck_pose)
            # NOTE(review): neck_pose is reused as the default wrist rotation,
            # presumably because it is an identity rotation — confirm in SMPLX
            param_dict["left_wrist_pose"] = _batched(self.smplx.neck_pose)
            param_dict["left_hand_pose"] = _batched(self.smplx.left_hand_pose)
            param_dict["right_wrist_pose"] = _batched(self.smplx.neck_pose)
            param_dict["right_hand_pose"] = _batched(self.smplx.right_hand_pose)
        elif param_type == "hand":
            batch_size = param_dict["right_hand_pose"].shape[0]
            n_partbody = self.param_list_dict["body_list"]["partbody_pose"]
            param_dict["abs_right_wrist_pose"] = param_dict["right_wrist_pose"].clone()
            dtype = param_dict["right_hand_pose"].dtype
            device = param_dict["right_hand_pose"].device
            # global orientation diag(1, -1, -1): a 180 degree rotation about the x axis
            x_180_pose = torch.eye(3, dtype=dtype, device=device).unsqueeze(0)
            x_180_pose[0, 2, 2] = -1.0
            x_180_pose[0, 1, 1] = -1.0
            param_dict["global_pose"] = x_180_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
            param_dict["shape"] = self.smplx.shape_params.expand(batch_size, -1)
            param_dict["exp"] = self.smplx.expression_params.expand(batch_size, -1)
            param_dict["head_pose"] = _batched(self.smplx.head_pose)
            param_dict["neck_pose"] = _batched(self.smplx.neck_pose)
            param_dict["jaw_pose"] = _batched(self.smplx.jaw_pose)
            param_dict["partbody_pose"] = _batched(self.smplx.body_pose)[:, :n_partbody]
            # NOTE(review): neck_pose reused as the default wrist rotation — see head branch
            param_dict["left_wrist_pose"] = _batched(self.smplx.neck_pose)
            param_dict["left_hand_pose"] = _batched(self.smplx.left_hand_pose)
        elif param_type == "body":
            # the predcition from the head and hand share regressor is always absolute pose
            param_dict["abs_head_pose"] = param_dict["head_pose"].clone()
            param_dict["abs_right_wrist_pose"] = param_dict["right_wrist_pose"].clone()
            param_dict["abs_left_wrist_pose"] = param_dict["left_wrist_pose"].clone()
            # the body-hand share regressor works for the right hand: the left-hand
            # parameters were predicted as if mirrored, so flip them back to the left side
            param_dict["left_wrist_pose"] = util.flip_pose(param_dict["left_wrist_pose"])
            param_dict["left_hand_pose"] = util.flip_pose(param_dict["left_hand_pose"])
        else:
            raise ValueError(f"param_type must be one of body/head/hand, got {param_type!r}")
        return param_dict

    def decode(self, param_dict, param_type):
        """Decode SMPL-X parameters into vertices, joints and projected keypoints.

        Poses still in network format (2-D jaw pose, or 6-D right wrist pose) are first
        converted in place with ``convert_pose``.

        Args:
            param_dict: smplx parameters
            param_type: should be one of body/head/hand
        Returns:
            predictions: dict with "vertices", "transformed_vertices", "face_kpt",
                "smplx_kpt", "smplx_kpt3d", "joints", "cam", plus every entry of the
                (converted) ``param_dict``.
        """
        needs_conversion = (
            "jaw_pose" in param_dict.keys() and len(param_dict["jaw_pose"].shape) == 2
        ) or param_dict["right_wrist_pose"].shape[-1] == 6
        if needs_conversion:
            self.convert_pose(param_dict, param_type)
        # assemble body_pose: neck at index 11, head at 14, left/right wrist at 19/20,
        # the remaining slots come from partbody_pose
        partbody_pose = param_dict["partbody_pose"]
        param_dict["body_pose"] = torch.cat(
            [
                partbody_pose[:, :11],
                param_dict["neck_pose"],
                partbody_pose[:, 11:11 + 2],
                param_dict["head_pose"],
                partbody_pose[:, 13:13 + 4],
                param_dict["left_wrist_pose"],
                param_dict["right_wrist_pose"],
            ],
            dim=1,
        )
        # change absolute head&hand pose to relative pose according to rest body pose
        if param_type in ("head", "body"):
            param_dict["body_pose"] = self.smplx.pose_abs2rel(
                param_dict["global_pose"], param_dict["body_pose"], abs_joint="head"
            )
        if param_type in ("hand", "body"):
            for wrist in ("left_wrist", "right_wrist"):
                param_dict["body_pose"] = self.smplx.pose_abs2rel(
                    param_dict["global_pose"],
                    param_dict["body_pose"],
                    abs_joint=wrist,
                )
        if self.cfg.model.check_pose:
            # check if pose is natural (relative rotation), if not, set relative to 0 (especially for head pose)
            # xyz: pitch(positive for looking down), yaw(positive for looking left), roll(rolling chin to left)
            for pose_ind in [14]:  # head [15-1, 20-1, 21-1]:
                curr_pose = param_dict["body_pose"][:, pose_ind]
                euler_pose = converter._compute_euler_from_matrix(curr_pose)
                # per-axis limits in degrees; out-of-range (or NaN) angles are zeroed in place
                for i, max_angle in enumerate([20, 70, 10]):
                    euler_pose_curr = euler_pose[:, i]
                    max_rad = max_angle * np.pi / 180
                    out_of_range = euler_pose_curr != torch.clamp(
                        euler_pose_curr, min=-max_rad, max=max_rad
                    )
                    euler_pose_curr[out_of_range] = 0.0
                param_dict["body_pose"][:, pose_ind] = converter.batch_euler2matrix(euler_pose)
        # SMPLX
        verts, landmarks, joints = self.smplx(
            shape_params=param_dict["shape"],
            expression_params=param_dict["exp"],
            global_pose=param_dict["global_pose"],
            body_pose=param_dict["body_pose"],
            jaw_pose=param_dict["jaw_pose"],
            left_hand_pose=param_dict["left_hand_pose"],
            right_hand_pose=param_dict["right_hand_pose"],
        )
        smplx_kpt3d = joints.clone()
        # projection with the "<param_type>_cam" parameters
        cam = param_dict[param_type + "_cam"]
        trans_verts = util.batch_orth_proj(verts, cam)
        predicted_landmarks = util.batch_orth_proj(landmarks, cam)[:, :, :2]
        predicted_joints = util.batch_orth_proj(joints, cam)[:, :, :2]
        prediction = {
            "vertices": verts,
            "transformed_vertices": trans_verts,
            # move the last 17 face keypoints to the front, to match the "standard"
            # 68-keypoint ordering
            "face_kpt": torch.cat(
                [predicted_landmarks[:, -17:], predicted_landmarks[:, :-17]], dim=1
            ),
            "smplx_kpt": predicted_joints,
            "smplx_kpt3d": smplx_kpt3d,
            "joints": joints,
            "cam": cam,
        }
        prediction.update(param_dict)
        return prediction

    def decode_Tpose(self, param_dict):
        """Return body mesh vertices in T pose; supports body and head param dicts only.

        Only shape, expression and jaw pose are passed to SMPL-X; the other poses are
        left to the SMPL-X model's defaults.
        """
        verts, _, _ = self.smplx(
            shape_params=param_dict["shape"],
            expression_params=param_dict["exp"],
            jaw_pose=param_dict["jaw_pose"],
        )
        return verts