|
import os |
|
import sys |
|
import gradio as gr |
|
import time |
|
|
|
# Create the directory that generate() writes rendered videos into, so it
# exists before the first request is served.
os.makedirs("outputs", exist_ok=True)

# Put the current working directory first on the import search path.
# NOTE(review): presumably so the `mogen` package imported below resolves
# when the script is started from the repository root -- confirm.
sys.path.insert(0, '.')
|
|
|
import argparse |
|
import os.path as osp |
|
import mmcv |
|
import numpy as np |
|
import torch |
|
from mmcv.runner import load_checkpoint |
|
from mmcv.parallel import MMDataParallel |
|
from scipy.ndimage import gaussian_filter |
|
from IPython.display import Image |
|
|
|
from mogen.models.utils.imagebind_wrapper import ( |
|
extract_text_feature, |
|
extract_audio_feature, |
|
imagebind_huge |
|
) |
|
from mogen.models import build_architecture |
|
|
|
from mogen.utils.plot_utils import ( |
|
plot_3d_motion, |
|
add_audio, |
|
get_audio_length |
|
) |
|
from mogen.datasets.paramUtil import ( |
|
t2m_body_hand_kinematic_chain, |
|
t2m_kinematic_chain |
|
) |
|
from mogen.datasets.utils import recover_from_ric |
|
from mogen.datasets.pipelines import RetargetSkeleton |
|
|
|
|
|
def motion_temporal_filter(motion, sigma=1):
    """Smooth a motion sequence along its first (frame) axis with a Gaussian.

    All trailing dimensions are flattened into independent channels and each
    channel is filtered on its own with ``mode="nearest"`` edge handling.

    Args:
        motion: NumPy array whose first axis indexes frames. The number of
            values per frame must be divisible by 3 because the result is
            reshaped to ``(frames, -1, 3)``.
        sigma: Standard deviation of the Gaussian kernel, forwarded to
            ``gaussian_filter``.

    Returns:
        A new array of shape ``(frames, -1, 3)`` with the same dtype as the
        input. The caller's array is not modified.
    """
    num_frames = motion.shape[0]
    channels = motion.reshape(num_frames, -1)
    # ``reshape`` normally returns a view of the caller's buffer, so writing
    # the filtered columns back into it would silently overwrite the input.
    # Filter into a freshly allocated array instead.
    smoothed = np.empty_like(channels)
    for idx in range(channels.shape[1]):
        smoothed[:, idx] = gaussian_filter(
            channels[:, idx], sigma=sigma, mode="nearest")
    return smoothed.reshape(num_frames, -1, 3)
|
|
|
def plot_tomato(data, kinematic_chain, result_path, npy_path, fps, sigma=None):
    """Turn a predicted motion array into a rendered video.

    The motion is converted to joint positions with ``recover_from_ric``
    (52 joints), temporally smoothed, passed through the module-level
    ``rtg_skl`` retargeter and finally drawn with ``plot_3d_motion``.

    Args:
        data: NumPy motion array accepted by ``torch.from_numpy``.
        kinematic_chain: Joint chains forwarded to ``plot_3d_motion``.
        result_path: Output path handed to ``plot_3d_motion``.
        npy_path: If not ``None``, the processed joints are also saved here
            with ``np.save``.
        fps: Frame rate forwarded to ``plot_3d_motion``.
        sigma: Accepted for interface compatibility but not used; the
            smoothing strength is fixed at 2.5.
            NOTE(review): confirm whether callers expect this value to be
            honored.
    """
    motion_tensor = torch.from_numpy(data).float()
    keypoints = recover_from_ric(motion_tensor, 52).numpy()
    keypoints = motion_temporal_filter(keypoints, sigma=2.5)

    retarget_request = {
        "keypoints3d": keypoints,
        "meta_data": {"has_lhnd": True},
    }
    keypoints = rtg_skl(retarget_request)["keypoints3d"]

    plot_3d_motion(
        out_path=result_path,
        joints=keypoints,
        kinematic_chain=kinematic_chain,
        title=None,
        fps=fps,
    )

    if npy_path is None:
        return
    np.save(npy_path, keypoints)
|
|
|
def create_lmm():
    """Build the LMM demo model and load its checkpoint.

    The architecture is built from ``configs/lmm/lmm_small_demo.py`` and the
    weights are loaded from ``pretrained/lmm_small_demo.pth`` onto the CPU
    first. Depending on the module-level ``device`` the model is then either
    kept on the CPU or wrapped in ``MMDataParallel`` on GPU 0.

    Returns:
        The model, switched to evaluation mode.
    """
    cfg = mmcv.Config.fromfile("configs/lmm/lmm_small_demo.py")
    lmm = build_architecture(cfg.model)
    load_checkpoint(lmm, "pretrained/lmm_small_demo.pth", map_location='cpu')

    on_cpu = device == 'cpu'
    lmm = lmm.cpu() if on_cpu else MMDataParallel(lmm, device_ids=[0])

    lmm.eval()
    return lmm
|
|
|
|
|
# Use the GPU when one is available; otherwise fall back to the CPU branch
# that create_lmm() already provides instead of crashing during start-up.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Models are loaded once at import time and shared by every request.
model_lmm = create_lmm()
model_imagebind = imagebind_huge(pretrained=True)
model_imagebind.eval()
model_imagebind.to(device)
rtg_skl = RetargetSkeleton(tgt_skel_file='data/motionverse/statistics/skeleton.npy')

# Statistics used to de-normalise the model output (pred * std + mean).
mean_path = "data/mean.npy"
std_path = "data/std.npy"
mean = np.load(mean_path)
std = np.load(std_path)
|
|
|
def show_generation_result(model, text, audio_path, motion_length, result_path):
    """Generate a motion clip from text and/or speech and render it to video.

    Args:
        model: Motion generation model. It is moved to the module-level
            ``device`` and called with keyword inputs; the first element of
            its result must provide ``'pred_motion'``.
        text: Prompt string, or ``None`` for no text conditioning. When only
            audio is supplied, a default "standing and speaking" prompt is
            substituted.
        audio_path: Path of the driving audio file, or ``None``. When given,
            the number of frames is derived from the audio duration at
            20 fps (capped at 200) and ``add_audio`` is called on the
            rendered file.
        motion_length: Requested number of frames; only used when
            ``audio_path`` is ``None``.
        result_path: Path the rendered video is written to.
    """
    fps = 20
    if audio_path is not None:
        num_frames = min(200, int(get_audio_length(audio_path) * fps) + 1)
    else:
        # UI widgets may deliver a float; tensor sizes and slices need an int.
        num_frames = int(motion_length)

    motion = torch.zeros(1, num_frames, 669, device=device)
    # Ten mask channels per frame, all enabled except the last one.
    motion_mask = torch.ones(1, num_frames, 10, device=device)
    motion_mask[:, :, 9] = 0

    kinematic_chain = t2m_body_hand_kinematic_chain
    motion_metas = [{
        'meta_data': dict(
            framerate=fps,
            dataset_name="humanml3d_t2m",
            rotation_type="h3d_rot"),
    }]

    if text is None and audio_path is not None:
        text = "A person is standing and speaking."

    model = model.to(device)
    model_inputs = {
        'motion': motion,
        'motion_mask': motion_mask,
        # The tensor form goes to the model; the plain int ``num_frames`` is
        # kept for slicing the prediction below.
        'motion_length': torch.tensor(
            [num_frames], dtype=torch.long, device=device),
        'motion_metas': motion_metas,
        'num_intervals': 1,
        'inference_kwargs': {},
    }

    if text is not None:
        text_word_feat, text_seq_feat = extract_text_feature(
            [text], model_imagebind, device)
        # Sanity checks on the feature extractor's output layout.
        assert tuple(text_word_feat.shape[:3]) == (1, 77, 1024)
        assert tuple(text_seq_feat.shape[:2]) == (1, 1024)
        model_inputs['text_word_feat'] = text_word_feat
        model_inputs['text_seq_feat'] = text_seq_feat
        model_inputs['text_cond'] = torch.tensor([1.0], device=device)
    else:
        # Zero placeholders live on the same device as every other input.
        model_inputs['text_word_feat'] = torch.zeros(
            1, 77, 1024, device=device)
        model_inputs['text_seq_feat'] = torch.zeros(1, 1024, device=device)
        model_inputs['text_cond'] = torch.tensor([0.0], device=device)

    if audio_path is not None:
        speech_word_feat, speech_seq_feat = extract_audio_feature(
            [audio_path], model_imagebind, device)
        assert tuple(speech_word_feat.shape[:3]) == (1, 229, 768)
        assert tuple(speech_seq_feat.shape[:2]) == (1, 1024)
        model_inputs['speech_word_feat'] = speech_word_feat
        model_inputs['speech_seq_feat'] = speech_seq_feat
        model_inputs['speech_cond'] = torch.tensor([1.0], device=device)
    else:
        model_inputs['speech_word_feat'] = torch.zeros(
            1, 229, 768, device=device)
        model_inputs['speech_seq_feat'] = torch.zeros(1, 1024, device=device)
        model_inputs['speech_cond'] = torch.tensor([0.0], device=device)

    with torch.no_grad():
        output = model(**model_inputs)[0]['pred_motion'][:num_frames]
    # Undo the normalisation using the module-level statistics.
    pred_motion = output.cpu().detach().numpy() * std + mean

    plot_tomato(pred_motion, kinematic_chain, result_path, None, fps, 2)

    if audio_path is not None:
        add_audio(result_path, [audio_path])
|
|
|
def generate(prompt, audio_path, length):
    """Gradio callback: run one generation and return the video path.

    Args:
        prompt: Text prompt; an empty or missing prompt disables text
            conditioning.
        audio_path: Path of the uploaded/recorded audio, or ``None`` when the
            user supplied none. A path ending in ``placeholder.wav`` is
            treated as "no audio".
        length: Requested motion length in frames.

    Returns:
        Path of the rendered ``.mp4`` inside the ``outputs`` directory.
    """
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs("outputs", exist_ok=True)
    result_path = "outputs/" + str(int(time.time())) + ".mp4"
    print(audio_path)
    # Guard against None first: calling .endswith on a missing upload used to
    # raise AttributeError and made text-only requests fail.
    if audio_path is not None and audio_path.endswith("placeholder.wav"):
        audio_path = None
    if not prompt:
        prompt = None
    show_generation_result(model_lmm, prompt, audio_path, length, result_path)
    return result_path
|
|
|
# Waveform appearance of the audio input widget.
waveform_style = gr.WaveformOptions(
    waveform_color="#01C6FF",
    waveform_progress_color="#0066B4",
    skip_length=2,
    show_controls=False,
)

# Audio input: uploaded or recorded clips, handed to generate() as a file path.
input_audio = gr.Audio(
    sources=["upload", "microphone"],
    type='filepath',
    format='wav',
    label="Audio (1-10s, overwrite motion length):",
    show_label=True,
    min_length=1,
    max_length=10,
    waveform_options=waveform_style,
)

# Free-form text prompt.
input_text = gr.Textbox(label="Text prompt:")

# Requested clip length in frames.
input_length = gr.Slider(20, 200, value=60, label="Motion length (fps 20):")

# Each example row is [prompt, audio file, motion length].
demo_examples = [
    ["A person walks in a circle.", "examples/placeholder.m4a", 120],
    ["A person jumps forward.", "examples/placeholder.m4a", 100],
    ["A person is stretching arms.", "examples/placeholder.m4a", 80],
    ["", "examples/surprise.m4a", 200],
    ["", "examples/angry.m4a", 200],
]

demo = gr.Interface(
    fn=generate,
    inputs=[input_text, input_audio, input_length],
    outputs=gr.Video(label="Video:"),
    examples=demo_examples,
    title="LMM: Large Motion Model for Unified Multi-Modal Motion Generation",
    description="\nThis is an interactive demo for LMM. For more information, feel free to visit our project page(https://github.com/mingyuan-zhang/LMM).",
)

# Enable request queuing, then start the web server.
demo.queue()
demo.launch()