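"""Gradio demo for LMM (Large Motion Model).

Generates 3D human motion from a text prompt and/or a speech audio clip,
renders the motion to a video, and muxes the driving audio back in when given.
"""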
import os
import sys
import gradio as gr
import time
os.makedirs("outputs", exist_ok=True)
sys.path.insert(0, '.')
import argparse
import os.path as osp
import mmcv
import numpy as np
import torch
from mmcv.runner import load_checkpoint
from mmcv.parallel import MMDataParallel
from scipy.ndimage import gaussian_filter
from mogen.models.utils.imagebind_wrapper import (
extract_text_feature,
extract_audio_feature,
imagebind_huge
)
from mogen.models import build_architecture
from mogen.utils.plot_utils import (
plot_3d_motion,
add_audio,
get_audio_length
)
from mogen.datasets.paramUtil import (
t2m_body_hand_kinematic_chain,
t2m_kinematic_chain
)
from mogen.datasets.utils import recover_from_ric
from mogen.datasets.pipelines import RetargetSkeleton
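
# Temporal smoothing: run a 1D Gaussian filter over every joint coordinate
# channel along the time axis to reduce jitter in the generated motion.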
def motion_temporal_filter(motion, sigma=1):
motion = motion.reshape(motion.shape[0], -1)
for i in range(motion.shape[1]):
motion[:, i] = gaussian_filter(motion[:, i], sigma=sigma, mode="nearest")
return motion.reshape(motion.shape[0], -1, 3)
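
# Recover 3D joint positions from the model's RIC representation, smooth and
# retarget them onto the reference skeleton, then render the motion as a video.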
def plot_tomato(data, kinematic_chain, result_path, npy_path, fps, sigma=2.5):
    joints = recover_from_ric(torch.from_numpy(data).float(), 52).numpy()
    joints = motion_temporal_filter(joints, sigma=sigma)
    joints = rtg_skl({"keypoints3d": joints, "meta_data": {"has_lhnd": True}})["keypoints3d"]
plot_3d_motion(
out_path=result_path,
joints=joints,
kinematic_chain=kinematic_chain,
title=None,
fps=fps)
if npy_path is not None:
np.save(npy_path, joints)
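
# Build the LMM architecture from its mmcv config, load the pretrained weights,
# and wrap the model in MMDataParallel when running on GPU.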
def create_lmm():
config_path = "configs/lmm/lmm_small_demo.py"
ckpt_path = "pretrained/lmm_small_demo.pth"
cfg = mmcv.Config.fromfile(config_path)
model = build_architecture(cfg.model)
load_checkpoint(model, ckpt_path, map_location='cpu')
if device == 'cpu':
model = model.cpu()
else:
model = MMDataParallel(model, device_ids=[0])
model.eval()
return model
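
# Global setup: inference device, LMM and ImageBind models, skeleton retargeting,
# and the mean/std statistics used to de-normalize predicted motions.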
# device = 'cpu'
device = 'cuda'
model_lmm = create_lmm()
model_imagebind = imagebind_huge(pretrained=True)
model_imagebind.eval()
model_imagebind.to(device)
rtg_skl = RetargetSkeleton(tgt_skel_file='data/motionverse/statistics/skeleton.npy')
mean_path = "data/mean.npy"
std_path = "data/std.npy"
mean = np.load(mean_path)
std = np.load(std_path)
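
# Run one generation pass: build the model inputs (motion placeholder, masks,
# ImageBind text/audio conditioning features), sample a motion with LMM,
# de-normalize it, and render the result to `result_path`.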
def show_generation_result(model, text, audio_path, motion_length, result_path):
fps = 20
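    # When audio is provided, its duration determines the motion length (capped at 200 frames).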
if audio_path is not None:
motion_length = min(200, int(get_audio_length(audio_path) * fps) + 1)
motion = torch.zeros(1, motion_length, 669).to(device)
motion_mask = torch.ones(1, motion_length).to(device)
motion_mask[0, :motion_length] = 1
motion_mask = motion_mask.unsqueeze(-1).repeat(1, 1, 10)
motion_mask[:, :, 9] = 0
dataset_name = "humanml3d_t2m"
kinematic_chain = t2m_body_hand_kinematic_chain
rotation_type = "h3d_rot"
motion_metas = [{
'meta_data': dict(framerate=fps, dataset_name=dataset_name, rotation_type=rotation_type)
}]
motion_length = torch.Tensor([motion_length]).long().to(device)
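    # Speech-only requests fall back to a generic text prompt.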
if text is None and audio_path is not None:
text = "A person is standing and speaking."
model = model.to(device)
input = {
'motion': motion,
'motion_mask': motion_mask,
'motion_length': motion_length,
'motion_metas': motion_metas,
'num_intervals': 1
}
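    # Text conditioning: ImageBind word-level (1 x 77 x 1024) and sequence-level
    # (1 x 1024) features; zero features with text_cond = 0 when no prompt is given.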
if text is not None:
text_word_feat, text_seq_feat = \
extract_text_feature([text], model_imagebind, device)
assert text_word_feat.shape[0] == 1
assert text_word_feat.shape[1] == 77
assert text_word_feat.shape[2] == 1024
assert text_seq_feat.shape[0] == 1
assert text_seq_feat.shape[1] == 1024
input['text_word_feat'] = text_word_feat
input['text_seq_feat'] = text_seq_feat
        input['text_cond'] = torch.Tensor([1.0]).to(device)
else:
        input['text_word_feat'] = torch.zeros(1, 77, 1024).to(device)
        input['text_seq_feat'] = torch.zeros(1, 1024).to(device)
        input['text_cond'] = torch.Tensor([0.0]).to(device)
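    # Speech conditioning: ImageBind audio features (1 x 229 x 768 word-level,
    # 1 x 1024 sequence-level); zero features with speech_cond = 0 when no audio is given.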
if audio_path is not None:
speech_word_feat, speech_seq_feat = \
extract_audio_feature([audio_path], model_imagebind, device)
assert speech_word_feat.shape[0] == 1
assert speech_word_feat.shape[1] == 229
assert speech_word_feat.shape[2] == 768
assert speech_seq_feat.shape[0] == 1
assert speech_seq_feat.shape[1] == 1024
input['speech_word_feat'] = speech_word_feat
input['speech_seq_feat'] = speech_seq_feat
        input['speech_cond'] = torch.Tensor([1.0]).to(device)
else:
        input['speech_word_feat'] = torch.zeros(1, 229, 768).to(device)
        input['speech_seq_feat'] = torch.zeros(1, 1024).to(device)
        input['speech_cond'] = torch.Tensor([0.0]).to(device)
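    # Sample a motion from the model and map it back to the original data scale.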
with torch.no_grad():
input['inference_kwargs'] = {}
output = model(**input)[0]['pred_motion'][:motion_length]
pred_motion = output.cpu().detach().numpy()
pred_motion = pred_motion * std + mean
plot_tomato(pred_motion, kinematic_chain, result_path, None, fps, 2)
if audio_path is not None:
add_audio(result_path, [audio_path])
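
# Gradio callback: map the UI inputs (text prompt, optional audio, motion-length
# slider) to the path of a generated video.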
def generate(prompt, audio_path, length):
    os.makedirs("outputs", exist_ok=True)
    result_path = "outputs/" + str(int(time.time())) + ".mp4"
    print(audio_path)
    # The bundled example placeholder audio means "no audio input".
    if audio_path is not None and audio_path.endswith("placeholder.wav"):
        audio_path = None
    if not prompt or not prompt.strip():
        prompt = None
show_generation_result(model_lmm, prompt, audio_path, length, result_path)
return result_path
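
# Gradio UI: an audio input (1-10 s; when provided, the audio duration overrides
# the motion-length slider), a text prompt box, and a length slider.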
input_audio = gr.Audio(
type='filepath',
format='wav',
label="Audio (1-10s, overwrite motion length):",
show_label=True,
sources=["upload", "microphone"],
min_length=1,
max_length=10,
waveform_options=gr.WaveformOptions(
waveform_color="#01C6FF",
waveform_progress_color="#0066B4",
skip_length=2,
show_controls=False,
),
)
input_text = gr.Textbox(
label="Text prompt:"
)
demo = gr.Interface(
fn=generate,
    inputs=[input_text, input_audio, gr.Slider(20, 200, value=60, label="Motion length (at 20 fps):")],
outputs=gr.Video(label="Video:"),
examples=[
["A person walks in a circle.", "examples/placeholder.m4a", 120],
["A person jumps forward.", "examples/placeholder.m4a", 100],
["A person is stretching arms.", "examples/placeholder.m4a", 80],
["", "examples/surprise.m4a", 200],
["", "examples/angry.m4a", 200],
],
title="LMM: Large Motion Model for Unified Multi-Modal Motion Generation",
description="\nThis is an interactive demo for LMM. For more information, feel free to visit our project page(https://github.com/mingyuan-zhang/LMM).")
demo.queue()
demo.launch()