|
import torch |
|
|
|
# Startup diagnostics, one value per line: the installed PyTorch version,
# the CUDA version PyTorch was built against (None for CPU-only builds),
# and whether a CUDA device is usable in this process.
print(torch.__version__, torch.version.cuda, torch.cuda.is_available(), sep="\n")
|
|
|
import os, subprocess |
|
import uuid, tempfile |
|
from glob import glob |
|
|
|
# Make the CUDA toolkit binaries under /usr/local/cuda/bin discoverable on
# PATH. Subprocesses started later (e.g. the inference script) inherit this
# environment. The directory is appended, so any tool of the same name that
# appears earlier on PATH still takes precedence.
# An unset or empty PATH yields an empty list rather than [''], because a
# leading empty entry would silently put the current directory on PATH.
_current_path = os.environ.get("PATH", "")
env_list = _current_path.split(os.pathsep) if _current_path else []
if "/usr/local/cuda/bin" not in env_list:  # avoid duplicate entries
    env_list.append("/usr/local/cuda/bin")
os.environ["PATH"] = os.pathsep.join(env_list)

# Restrict CUDA extension builds to compute capability 8.6.
# NOTE(review): presumably matches the deployment GPU; this unconditionally
# overrides any value already set in the environment.
os.environ["TORCH_CUDA_ARCH_LIST"] = "8.6"
|
|
|
import gradio as gr |
|
from huggingface_hub import snapshot_download |
|
|
|
# Ensure the local checkpoint directory exists, then fetch a snapshot of the
# Hugging Face Hub repository "jiawei011/L4GM" into it. This runs at import
# time and blocks until the download finishes. The checkpoint paths used by
# generate() ("pretrained/*.safetensors") presumably come from this
# snapshot — confirm against the repo contents.
os.makedirs("pretrained", exist_ok=True)
snapshot_download(repo_id="jiawei011/L4GM", local_dir="./pretrained")
|
|
|
|
|
# Directory (relative to the working directory) holding sample input videos
# that are offered as clickable examples in the UI.
examples_folder = "data_test"

# Paths of every regular file directly inside examples_folder (symlinks to
# files included, subdirectories skipped). The result is sorted because
# os.listdir() returns entries in arbitrary order, which would otherwise
# shuffle the example gallery between runs.
# Raises FileNotFoundError at import time if examples_folder is missing.
video_examples = sorted(
    os.path.join(examples_folder, name)
    for name in os.listdir(examples_folder)
    if os.path.isfile(os.path.join(examples_folder, name))
)
|
|
|
|
|
def generate(input_video):
    """Run L4GM inference on one input video and return a rendered result.

    Launches ``infer_3d.py`` (resolved relative to the current working
    directory) in a subprocess using the same Python interpreter as this app.
    The subprocess receives the pretrained reconstruction and interpolation
    checkpoints and writes into a fresh temporary workspace.

    Args:
        input_video: Path of the input video as delivered by the
            ``gr.Video`` component, or ``None``/empty when nothing was
            uploaded.

    Returns:
        Path (str) of the first ``.mp4`` file, in sorted order, that the
        inference script wrote into the workspace.

    Raises:
        gr.Error: If no input video was given, if the inference script exits
            with a non-zero status, or if it produces no ``.mp4`` output.
    """
    import sys  # only needed here, to locate the running interpreter

    if not input_video:
        raise gr.Error("Please upload an input video before submitting.")

    # One fresh workspace per request, so outputs of different runs never mix.
    # NOTE(review): the directory is not removed here. Presumably Gradio reads
    # the returned file after this function returns, so cleanup must happen
    # elsewhere — confirm before adding deletion.
    workdir = tempfile.mkdtemp()
    recon_model = "pretrained/recon.safetensors"
    interp_model = "pretrained/interp.safetensors"
    num_frames = 16

    cmd = [
        # sys.executable, not a bare "python", so the subprocess runs in the
        # same interpreter/venv as this app regardless of PATH contents.
        sys.executable, "infer_3d.py", "big",
        "--workspace", workdir,
        "--resume", recon_model,
        "--interpresume", interp_model,
        "--num_frames", str(num_frames),
        "--test_path", str(input_video),
    ]

    # Only the subprocess call can raise CalledProcessError; keep the try narrow.
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        raise gr.Error(f"Error during inference: {str(e)}") from e

    # Sorted so the chosen file does not depend on filesystem listing order.
    output_videos = sorted(glob(os.path.join(workdir, "*.mp4")))
    if not output_videos:
        raise gr.Error(f"Inference finished but produced no .mp4 output in {workdir}")
    return output_videos[0]
|
|
|
# Build the Gradio page. An outer column holds a row with two side-by-side
# columns (input on the left, result on the right), followed by the example
# videos gallery.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            # Left: the video to process and the button that triggers inference.
            with gr.Column():
                input_video = gr.Video(label="Input Video")
                submit_btn = gr.Button("Submit")
            # Right: the video returned by generate().
            with gr.Column():
                output_result = gr.Video(label="Result")

        # Clicking an example loads it into the input video component.
        gr.Examples(examples=video_examples, inputs=[input_video])

    # Submit runs generate() on the input video and shows its return value.
    submit_btn.click(fn=generate, inputs=[input_video], outputs=[output_result])

# Enable request queuing, then start the server with the API page hidden and
# Python errors surfaced in the UI.
demo.queue()
demo.launch(show_api=False, show_error=True)
|
|