import os
import sys
import gradio as gr
import time

os.makedirs("outputs", exist_ok=True)
# Make the local `mogen` package importable; must run before the mogen imports.
sys.path.insert(0, '.')

import argparse
import os.path as osp
import mmcv
import numpy as np
import torch
from mmcv.runner import load_checkpoint
from mmcv.parallel import MMDataParallel
from scipy.ndimage import gaussian_filter
from IPython.display import Image

from mogen.models.utils.imagebind_wrapper import (
    extract_text_feature,
    extract_audio_feature,
    imagebind_huge
)
from mogen.models import build_architecture

from mogen.utils.plot_utils import (
    plot_3d_motion,
    add_audio,
    get_audio_length
)
from mogen.datasets.paramUtil import (
    t2m_body_hand_kinematic_chain,
    t2m_kinematic_chain
)
from mogen.datasets.utils import recover_from_ric
from mogen.datasets.pipelines import RetargetSkeleton


def motion_temporal_filter(motion, sigma=1):
    """Gaussian-smooth a joint sequence along the time axis.

    Args:
        motion: array whose first axis is frames; it is flattened to
            (frames, -1) and every column is filtered independently.
        sigma: standard deviation of the Gaussian kernel, in frames.

    Returns:
        The smoothed array reshaped to (frames, -1, 3).

    Note: the columns are written back into the reshaped array, so when
    ``reshape`` returns a view the caller's array is modified in place.
    """
    motion = motion.reshape(motion.shape[0], -1)
    for i in range(motion.shape[1]):
        # mode="nearest" repeats the edge frames instead of padding with zeros.
        motion[:, i] = gaussian_filter(motion[:, i], sigma=sigma, mode="nearest")
    return motion.reshape(motion.shape[0], -1, 3)


def plot_tomato(data, kinematic_chain, result_path, npy_path, fps, sigma=None):
    """Turn denormalized motion features into joints and render them to video.

    Args:
        data: numpy motion-feature array, decoded by ``recover_from_ric``
            with 52 joints.
        kinematic_chain: joint chain used by ``plot_3d_motion`` for drawing.
        result_path: output video path.
        npy_path: if not None, the processed joints are saved there.
        fps: frame rate of the rendered video.
        sigma: currently ignored (see note below).
    """
    joints = recover_from_ric(torch.from_numpy(data).float(), 52).numpy()
    # NOTE(review): the `sigma` argument is not used; smoothing is fixed at 2.5.
    joints = motion_temporal_filter(joints, sigma=2.5)
    # Retarget onto the skeleton loaded into the module-level `rtg_skl`.
    joints = rtg_skl({"keypoints3d": joints, "meta_data": {"has_lhnd": True}})["keypoints3d"]
    plot_3d_motion(
        out_path=result_path,
        joints=joints,
        kinematic_chain=kinematic_chain,
        title=None,
        fps=fps)
    if npy_path is not None:
        np.save(npy_path, joints)


def create_lmm():
    """Build the LMM model from its demo config and load pretrained weights.

    Reads the module-level ``device``: on 'cpu' the model stays on the CPU,
    otherwise it is wrapped in ``MMDataParallel`` on GPU 0. The model is
    returned in eval mode.
    """
    config_path = "configs/lmm/lmm_small_demo.py"
    ckpt_path = "pretrained/lmm_small_demo.pth"
    cfg = mmcv.Config.fromfile(config_path)
    model = build_architecture(cfg.model)
    load_checkpoint(model, ckpt_path, map_location='cpu')
    if device == 'cpu':
        model = model.cpu()
    else:
        model = MMDataParallel(model, device_ids=[0])
    model.eval()
    return model


# Prefer the GPU, but fall back to CPU instead of crashing when CUDA is absent.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# If localhost requests are routed through a proxy, uncomment:
# os.environ["NO_PROXY"] = os.environ["no_proxy"] = "localhost, 127.0.0.1:7860"
model_lmm = create_lmm()
model_imagebind = imagebind_huge(pretrained=True)
model_imagebind.eval()
model_imagebind.to(device)
rtg_skl = RetargetSkeleton(tgt_skel_file='data/motionverse/statistics/skeleton.npy')

# Statistics used to denormalize the model output (pred * std + mean).
mean_path = "data/mean.npy"
std_path = "data/std.npy"
mean = np.load(mean_path)
std = np.load(std_path)


def show_generation_result(model, text, audio_path, motion_length, result_path):
    """Generate a motion from text and/or audio and render it to a video file.

    Args:
        model: the LMM model returned by ``create_lmm``.
        text: text prompt, or None. If None while audio is given, a default
            speaking prompt is used.
        audio_path: path to an audio file, or None. When given, the motion
            length is derived from the audio duration (capped at 200 frames)
            and the audio is added to the output video.
        motion_length: requested number of frames when no audio is given.
        result_path: path of the output video.
    """
    fps = 20
    if audio_path is not None:
        motion_length = min(200, int(get_audio_length(audio_path) * fps) + 1)
    # Keep a plain int for allocation and slicing; int() also tolerates a
    # float value coming from the UI slider.
    num_frames = int(motion_length)
    # 669 is the per-frame motion feature size this model expects.
    motion = torch.zeros(1, num_frames, 669).to(device)
    motion_mask = torch.ones(1, num_frames).to(device)
    # Expand the mask to 10 channels and zero the last one.
    # NOTE(review): channel semantics are defined by the model — confirm there.
    motion_mask = motion_mask.unsqueeze(-1).repeat(1, 1, 10)
    motion_mask[:, :, 9] = 0
    dataset_name = "humanml3d_t2m"
    kinematic_chain = t2m_body_hand_kinematic_chain
    rotation_type = "h3d_rot"
    motion_metas = [{
        'meta_data': dict(framerate=fps, dataset_name=dataset_name, rotation_type=rotation_type)
    }]
    motion_length_tensor = torch.tensor([num_frames], dtype=torch.long, device=device)
    if text is None and audio_path is not None:
        text = "A person is standing and speaking."
    model = model.to(device)
    input = {
        'motion': motion,
        'motion_mask': motion_mask,
        'motion_length': motion_length_tensor,
        'motion_metas': motion_metas,
        'num_intervals': 1
    }
    if text is not None:
        text_word_feat, text_seq_feat = \
            extract_text_feature([text], model_imagebind, device)
        assert text_word_feat.shape[0] == 1
        assert text_word_feat.shape[1] == 77
        assert text_word_feat.shape[2] == 1024
        assert text_seq_feat.shape[0] == 1
        assert text_seq_feat.shape[1] == 1024
        input['text_word_feat'] = text_word_feat
        input['text_seq_feat'] = text_seq_feat
        input['text_cond'] = torch.Tensor([1.0] * 1).to(device)
    else:
        # Zero placeholders with the same shapes; text_cond=0 marks them absent.
        input['text_word_feat'] = torch.zeros(1, 77, 1024).to(device)
        input['text_seq_feat'] = torch.zeros(1, 1024).to(device)
        input['text_cond'] = torch.Tensor([0] * 1).to(device)
    if audio_path is not None:
        speech_word_feat, speech_seq_feat = \
            extract_audio_feature([audio_path], model_imagebind, device)
        assert speech_word_feat.shape[0] == 1
        assert speech_word_feat.shape[1] == 229
        assert speech_word_feat.shape[2] == 768
        assert speech_seq_feat.shape[0] == 1
        assert speech_seq_feat.shape[1] == 1024
        input['speech_word_feat'] = speech_word_feat
        input['speech_seq_feat'] = speech_seq_feat
        input['speech_cond'] = torch.Tensor([1.0] * 1).to(device)
    else:
        # Zero placeholders with the same shapes; speech_cond=0 marks them absent.
        input['speech_word_feat'] = torch.zeros(1, 229, 768).to(device)
        input['speech_seq_feat'] = torch.zeros(1, 1024).to(device)
        input['speech_cond'] = torch.Tensor([0] * 1).to(device)

    with torch.no_grad():
        input['inference_kwargs'] = {}
        output = model(**input)[0]['pred_motion'][:num_frames]
        pred_motion = output.cpu().detach().numpy()
        pred_motion = pred_motion * std + mean

    plot_tomato(pred_motion, kinematic_chain, result_path, None, fps, 2)
    if audio_path is not None:
        add_audio(result_path, [audio_path])


def generate(prompt, audio_path, length):
    """Gradio callback: generate a motion video and return its file path.

    Args:
        prompt: text from the textbox; empty or None means "no text".
        audio_path: file path from the audio input, or None when it is empty.
            A path ending in "placeholder.wav" is also treated as "no audio".
        length: motion length in frames from the slider (ignored when audio
            is given).

    Returns:
        Path of the rendered .mp4 under ``outputs/``.
    """
    os.makedirs("outputs", exist_ok=True)
    result_path = "outputs/" + str(int(time.time())) + ".mp4"
    print(audio_path)
    # The examples pass examples/placeholder.m4a, presumably re-encoded to
    # .wav by gr.Audio(format='wav') — hence the .wav suffix check.
    if audio_path is None or audio_path.endswith("placeholder.wav"):
        audio_path = None
    if not prompt:
        prompt = None
    show_generation_result(model_lmm, prompt, audio_path, length, result_path)
    return result_path


input_audio = gr.Audio(
    type='filepath',
    format='wav',
    label="Audio (1-10s, overwrite motion length):",
    show_label=True,
    sources=["upload", "microphone"],
    min_length=1,
    max_length=10,
    waveform_options=gr.WaveformOptions(
        waveform_color="#01C6FF",
        waveform_progress_color="#0066B4",
        skip_length=2,
        show_controls=False,
    ),
)

input_text = gr.Textbox(
    label="Text prompt:"
)

demo = gr.Interface(
    fn=generate,
    inputs=[input_text, input_audio, gr.Slider(20, 200, value=60, label="Motion length (fps 20):")],
    outputs=gr.Video(label="Video:"),
    examples=[
        ["A person walks in a circle.", "examples/placeholder.m4a", 120],
        ["A person jumps forward.", "examples/placeholder.m4a", 100],
        ["A person is stretching arms.", "examples/placeholder.m4a", 80],
        ["", "examples/surprise.m4a", 200],
        ["", "examples/angry.m4a", 200],
    ],
    title="LMM: Large Motion Model for Unified Multi-Modal Motion Generation",
    description="\nThis is an interactive demo for LMM. For more information, feel free to visit our project page(https://github.com/mingyuan-zhang/LMM).")

demo.queue()
demo.launch()