fffiloni commited on
Commit
bc16e8f
·
verified ·
1 Parent(s): f1609b2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -0
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, subprocess
2
+ import uuid, tempfile
3
+ import gradio as gr
4
+ from huggingface_hub import snapshot_download
5
+
6
+ os.makedirs("pretrained", exist_ok=True)
7
+ snapshot_download(
8
+ repo_id = "jiawei011/L4GM",
9
+ local_dir = "./pretrained"
10
+ )
11
+
12
+ # Folder containing example images
13
+ examples_folder = "data_test"
14
+
15
+ # Retrieve all file paths in the folder
16
+ video_examples = [
17
+ os.path.join(examples_folder, file)
18
+ for file in os.listdir(examples_folder)
19
+ if os.path.isfile(os.path.join(examples_folder, file))
20
+ ]
21
+
22
+
23
+ def generate(input_video):
24
+
25
+ #--test_path data_test/otter-on-surfboard_fg.mp4
26
+ workdir = "results"
27
+ pretrained_model = "pretrained/recon.safetensors"
28
+ num_frames = 1
29
+ test_path = input_video
30
+
31
+ try:
32
+ # Run the inference command
33
+ subprocess.run(
34
+ [
35
+ "python", "infer_3d.py", "big",
36
+ f"workspace={workdir},
37
+ f"resume={pretrained_model}",
38
+ f"num_frames={num_frames}",
39
+ f"test_path={test_path}",
40
+ ],
41
+ check=True
42
+ )
43
+
44
+
45
+ # Retrieve the file name without the extension
46
+ #removed_bg_file_name = os.path.splitext(os.path.basename(removed_bg_path))[0]
47
+ output_videos = glob(os.path.join(f"{workdir}", "*.mp4"))
48
+ return output_videos
49
+ except subprocess.CalledProcessError as e:
50
+ return f"Error during inference: {str(e)}"
51
+
52
+ with gr.Blocks() as demo:
53
+ with gr.Column():
54
+ with gr.Row():
55
+ with gr.Column():
56
+ input_video = gr.Video(label="Input Video")
57
+ submit_btn = gr.Button("Submit")
58
+ with gr.Column():
59
+ output_result = gr.Video(label="Result")
60
+
61
+ gr.Examples(
62
+ examples = video_examples,
63
+ inputs = [input_video]
64
+ )
65
+
66
+ submit_btn.click(
67
+ fn = generate,
68
+ inputs = [input_video],
69
+ outputs = [output_result]
70
+ )
71
+
72
+ demo.queue().launch(show_api=False, show_error=True)