rexma commited on
Commit
2b243c6
·
1 Parent(s): 3d4535a

Work in process

Browse files
Files changed (1) hide show
  1. app.py +562 -101
app.py CHANGED
@@ -9,23 +9,240 @@
9
  # print("Command executed successfully.")
10
  # else:
11
  # print("Command failed with return code:", result.returncode)
 
12
  import gc
 
13
  import math
14
- # import multiprocessing as mp
15
- import torch.multiprocessing as mp
16
  import os
17
- from process_wrappers import clear_folder, draw_markers, sam_click_wrapper1, sam_stroke_process, tracking_objects_process
 
18
  os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
 
19
  import ffmpeg
 
 
 
 
 
 
 
 
 
 
20
  import cv2
 
21
 
 
 
22
 
23
- def clean():
24
- return ({}, {}, {}), None, None, 0, None, None, None, 0
 
 
 
 
 
 
 
 
 
25
 
26
- def show_res_by_slider(frame_per, click_stack):
27
- image_path = '/tmp/output_frames'
28
- output_combined_dir = '/tmp/output_combined'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  combined_frames = sorted([os.path.join(output_combined_dir, img_name) for img_name in os.listdir(output_combined_dir)])
31
  if combined_frames:
@@ -37,7 +254,7 @@ def show_res_by_slider(frame_per, click_stack):
37
  total_frames_num = len(output_masked_frame_path)
38
  if total_frames_num == 0:
39
  print("No output results found")
40
- return None, None
41
  else:
42
  frame_num = math.floor(total_frames_num * frame_per / 100)
43
  if frame_per == 100:
@@ -46,11 +263,87 @@ def show_res_by_slider(frame_per, click_stack):
46
  print(f"{chosen_frame_path}")
47
  chosen_frame_show = cv2.imread(chosen_frame_path)
48
  chosen_frame_show = cv2.cvtColor(chosen_frame_show, cv2.COLOR_BGR2RGB)
49
- points_dict, labels_dict, masks_dict = click_stack
50
  if frame_num in points_dict and frame_num in labels_dict:
51
  chosen_frame_show = draw_markers(chosen_frame_show, points_dict[frame_num], labels_dict[frame_num])
52
  return chosen_frame_show, chosen_frame_show, frame_num
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  def increment_ann_obj_id(ann_obj_id):
55
  ann_obj_id += 1
56
  return ann_obj_id
@@ -58,40 +351,141 @@ def increment_ann_obj_id(ann_obj_id):
58
  def drawing_board_get_input_first_frame(input_first_frame):
59
  return input_first_frame
60
 
61
- def sam_stroke_wrapper(click_stack, checkpoint, drawing_board, last_draw, frame_num, ann_obj_id):
62
- queue = mp.Queue()
63
- p = mp.Process(target=sam_stroke_process, args=(queue, click_stack, checkpoint, drawing_board, last_draw, frame_num, ann_obj_id))
64
- p.start()
65
- error, result = queue.get()
66
- p.join()
67
- if error:
68
- raise Exception(f"Error in sam_stroke_process: {error}")
69
- return result
70
-
71
- def tracking_objects_wrapper(click_stack, checkpoint, frame_num, input_video):
72
- queue = mp.Queue()
73
- p = mp.Process(target=tracking_objects_process, args=(queue, click_stack, checkpoint, frame_num, input_video))
74
- p.start()
75
- error, result = queue.get()
76
- p.join()
77
- if error:
78
- raise Exception(f"Error in sam_stroke_process: {error}")
79
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  def seg_track_app():
82
- import gradio as gr
83
 
84
- def sam_click_wrapper(checkpoint, frame_num, point_mode, click_stack, ann_obj_id, evt: gr.SelectData):
85
- return sam_click_wrapper1(checkpoint, frame_num, point_mode, click_stack, ann_obj_id, [evt.index[0], evt.index[1]])
86
 
87
- def change_video(input_video):
88
- import gradio as gr
89
- if input_video is None:
90
- return 0, 0
91
- cap = cv2.VideoCapture(input_video)
92
- fps = cap.get(cv2.CAP_PROP_FPS)
93
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
94
- cap.release()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  scale_slider = gr.Slider.update(minimum=1.0,
96
  maximum=fps,
97
  step=1.0,
@@ -100,45 +494,107 @@ def seg_track_app():
100
  maximum= total_frames / fps,
101
  step=1.0/fps,
102
  value=0.0,)
103
- return scale_slider, frame_per
104
-
105
- def get_meta_from_video(input_video, scale_slider):
106
- import gradio as gr
107
- output_dir = '/tmp/output_frames'
108
- output_masks_dir = '/tmp/output_masks'
109
- output_combined_dir = '/tmp/`output_combined`'
110
- clear_folder(output_dir)
111
- clear_folder(output_masks_dir)
112
- clear_folder(output_combined_dir)
113
- if input_video is None:
114
- return ({}, {}, {}), None, None, 0, None, None, None, 0
115
- cap = cv2.VideoCapture(input_video)
116
- fps = cap.get(cv2.CAP_PROP_FPS)
117
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
118
- cap.release()
119
- frame_interval = max(1, int(fps // scale_slider))
120
- print(f"frame_interval: {frame_interval}")
121
- try:
122
- ffmpeg.input(input_video, hwaccel='cuda').output(
123
- os.path.join(output_dir, '%07d.jpg'), q=2, start_number=0,
124
- vf=rf'select=not(mod(n\,{frame_interval}))', vsync='vfr'
125
- ).run()
126
- except:
127
- print(f"ffmpeg cuda err")
128
- ffmpeg.input(input_video).output(
129
- os.path.join(output_dir, '%07d.jpg'), q=2, start_number=0,
130
- vf=rf'select=not(mod(n\,{frame_interval}))', vsync='vfr'
131
- ).run()
132
-
133
- first_frame_path = os.path.join(output_dir, '0000000.jpg')
134
- first_frame = cv2.imread(first_frame_path)
135
- first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  frame_per = gr.Slider.update(minimum= 0.0,
138
  maximum= total_frames / fps,
139
  step=frame_interval / fps,
140
  value=0.0,)
141
- return ({}, {}, {}), first_frame_rgb, first_frame_rgb, frame_per, None, None, None, 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  ##########################################################
144
  ###################### Front-end ########################
@@ -152,8 +608,9 @@ def seg_track_app():
152
  """
153
 
154
  app = gr.Blocks(css=css)
155
-
156
  with app:
 
 
157
  gr.Markdown(
158
  '''
159
  <div style="text-align:center; margin-bottom:20px;">
@@ -216,7 +673,7 @@ def seg_track_app():
216
  '''
217
  )
218
 
219
- click_stack = gr.State(({}, {}, {}))
220
  frame_num = gr.State(value=(int(0)))
221
  ann_obj_id = gr.State(value=(int(0)))
222
  last_draw = gr.State(None)
@@ -226,7 +683,7 @@ def seg_track_app():
226
  with gr.Row():
227
  tab_video_input = gr.Tab(label="Video input")
228
  with tab_video_input:
229
- input_video = gr.Video(label='Input video', elem_id="input_output_video")
230
  with gr.Row():
231
  checkpoint = gr.Dropdown(label="Model Size", choices=["tiny", "small", "base-plus", "large"], value="tiny")
232
  scale_slider = gr.Slider(
@@ -251,7 +708,7 @@ def seg_track_app():
251
 
252
  tab_click = gr.Tab(label="Point Prompt")
253
  with tab_click:
254
- input_first_frame = gr.Image(label='Segment result of first frame',interactive=True, height=550)
255
  with gr.Row():
256
  point_mode = gr.Radio(
257
  choices=["Positive", "Negative"],
@@ -324,20 +781,22 @@ def seg_track_app():
324
 
325
  # listen to the preprocess button click to get the first frame of video with scaling
326
  preprocess_button.click(
327
- fn=get_meta_from_video,
328
  inputs=[
 
329
  input_video,
330
  scale_slider,
 
331
  ],
332
  outputs=[
333
- click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id
334
  ]
335
  )
336
 
337
  frame_per.release(
338
- fn= show_res_by_slider,
339
  inputs=[
340
- frame_per, click_stack
341
  ],
342
  outputs=[
343
  input_first_frame, drawing_board, frame_num
@@ -346,9 +805,9 @@ def seg_track_app():
346
 
347
  # Interactively modify the mask acc click
348
  input_first_frame.select(
349
- fn=sam_click_wrapper,
350
  inputs=[
351
- checkpoint, frame_num, point_mode, click_stack, ann_obj_id
352
  ],
353
  outputs=[
354
  input_first_frame, drawing_board, click_stack
@@ -357,10 +816,9 @@ def seg_track_app():
357
 
358
  # Track object in video
359
  track_for_video.click(
360
- fn=tracking_objects_wrapper,
361
  inputs=[
362
- click_stack,
363
- checkpoint,
364
  frame_num,
365
  input_video,
366
  ],
@@ -374,17 +832,17 @@ def seg_track_app():
374
  )
375
 
376
  reset_button.click(
377
- fn=clean,
378
- inputs=[],
379
  outputs=[
380
  click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id
381
  ]
382
  )
383
 
384
  new_object_button.click(
385
- fn=increment_ann_obj_id,
386
  inputs=[
387
- ann_obj_id
388
  ],
389
  outputs=[
390
  ann_obj_id
@@ -392,30 +850,33 @@ def seg_track_app():
392
  )
393
 
394
  tab_stroke.select(
395
- fn=drawing_board_get_input_first_frame,
396
- inputs=[input_first_frame,],
397
  outputs=[drawing_board,],
398
  )
399
 
400
  seg_acc_stroke.click(
401
- fn=sam_stroke_wrapper,
402
  inputs=[
403
- click_stack, checkpoint, drawing_board, last_draw, frame_num, ann_obj_id
404
  ],
405
  outputs=[
406
- click_stack, input_first_frame, drawing_board, last_draw
407
  ]
408
  )
409
 
410
  input_video.change(
411
- fn=change_video,
412
- inputs=[input_video],
413
- outputs=[scale_slider, frame_per]
414
  )
415
 
416
  app.queue(concurrency_count=1)
417
- app.launch(debug=True, share=False)
418
 
419
  if __name__ == "__main__":
420
- mp.set_start_method('spawn', force=True)
 
 
 
421
  seg_track_app()
 
9
  # print("Command executed successfully.")
10
  # else:
11
  # print("Command failed with return code:", result.returncode)
12
+ import datetime
13
  import gc
14
+ import hashlib
15
  import math
16
+ import multiprocessing as mp
 
17
  import os
18
+ import threading
19
+ import time
20
  os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
21
+ import shutil
22
  import ffmpeg
23
+ from moviepy.editor import ImageSequenceClip
24
+ import zipfile
25
+ # import gradio as gr
26
+ import torch
27
+ import numpy as np
28
+ import matplotlib.pyplot as plt
29
+ from PIL import Image
30
+ from sam2.build_sam import build_sam2
31
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
32
+ from sam2.build_sam import build_sam2_video_predictor
33
  import cv2
34
+ import uuid
35
 
36
+ user_processes = {}
37
+ PROCESS_TIMEOUT = datetime.timedelta(minutes=4)
38
 
39
+ def reset(seg_tracker):
40
+ if seg_tracker is not None:
41
+ predictor, inference_state, image_predictor = seg_tracker
42
+ predictor.reset_state(inference_state)
43
+ del predictor
44
+ del inference_state
45
+ del image_predictor
46
+ del seg_tracker
47
+ gc.collect()
48
+ torch.cuda.empty_cache()
49
+ return None, ({}, {}), None, None, 0, None, None, None, 0
50
 
51
+ def extract_video_info(input_video):
52
+ if input_video is None:
53
+ return 4, 4, None, None, None, None, None
54
+ cap = cv2.VideoCapture(input_video)
55
+ fps = cap.get(cv2.CAP_PROP_FPS)
56
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
57
+ cap.release()
58
+ return fps, total_frames, None, None, None, None, None
59
+
60
+ def get_meta_from_video(session_id, input_video, scale_slider, checkpoint):
61
+ output_dir = f'/tmp/output_frames/{session_id}'
62
+ output_masks_dir = f'/tmp/output_masks/{session_id}'
63
+ output_combined_dir = f'/tmp/output_combined/{session_id}'
64
+ clear_folder(output_dir)
65
+ clear_folder(output_masks_dir)
66
+ clear_folder(output_combined_dir)
67
+ if input_video is None:
68
+ return None, ({}, {}), None, None, (4, 1, 4), None, None, None, 0
69
+ cap = cv2.VideoCapture(input_video)
70
+ fps = cap.get(cv2.CAP_PROP_FPS)
71
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
72
+ cap.release()
73
+ frame_interval = max(1, int(fps // scale_slider))
74
+ print(f"frame_interval: {frame_interval}")
75
+ try:
76
+ ffmpeg.input(input_video, hwaccel='cuda').output(
77
+ os.path.join(output_dir, '%07d.jpg'), q=2, start_number=0,
78
+ vf=rf'select=not(mod(n\,{frame_interval}))', vsync='vfr'
79
+ ).run()
80
+ except:
81
+ print(f"ffmpeg cuda err")
82
+ ffmpeg.input(input_video).output(
83
+ os.path.join(output_dir, '%07d.jpg'), q=2, start_number=0,
84
+ vf=rf'select=not(mod(n\,{frame_interval}))', vsync='vfr'
85
+ ).run()
86
+
87
+ first_frame_path = os.path.join(output_dir, '0000000.jpg')
88
+ first_frame = cv2.imread(first_frame_path)
89
+ first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
90
+
91
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
92
+ if torch.cuda.get_device_properties(0).major >= 8:
93
+ torch.backends.cuda.matmul.allow_tf32 = True
94
+ torch.backends.cudnn.allow_tf32 = True
95
+
96
+ sam2_checkpoint = "segment-anything-2/checkpoints/sam2_hiera_tiny.pt"
97
+ model_cfg = "sam2_hiera_t.yaml"
98
+ if checkpoint == "samll":
99
+ sam2_checkpoint = "segment-anything-2/checkpoints/sam2_hiera_small.pt"
100
+ model_cfg = "sam2_hiera_s.yaml"
101
+ elif checkpoint == "base-plus":
102
+ sam2_checkpoint = "segment-anything-2/checkpoints/sam2_hiera_base_plus.pt"
103
+ model_cfg = "sam2_hiera_b+.yaml"
104
+ elif checkpoint == "large":
105
+ sam2_checkpoint = "segment-anything-2/checkpoints/sam2_hiera_large.pt"
106
+ model_cfg = "sam2_hiera_l.yaml"
107
+
108
+ predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda")
109
+ sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
110
+ image_predictor = SAM2ImagePredictor(sam2_model)
111
+ inference_state = predictor.init_state(video_path=output_dir)
112
+ predictor.reset_state(inference_state)
113
+ return (predictor, inference_state, image_predictor), ({}, {}), first_frame_rgb, first_frame_rgb, (fps, frame_interval, total_frames), None, None, None, 0
114
+
115
+ def mask2bbox(mask):
116
+ if len(np.where(mask > 0)[0]) == 0:
117
+ print(f'not mask')
118
+ return np.array([0, 0, 0, 0]).astype(np.int64), False
119
+ x_ = np.sum(mask, axis=0)
120
+ y_ = np.sum(mask, axis=1)
121
+ x0 = np.min(np.nonzero(x_)[0])
122
+ x1 = np.max(np.nonzero(x_)[0])
123
+ y0 = np.min(np.nonzero(y_)[0])
124
+ y1 = np.max(np.nonzero(y_)[0])
125
+ return np.array([x0, y0, x1, y1]).astype(np.int64), True
126
+
127
+ def sam_stroke(session_id, seg_tracker, drawing_board, last_draw, frame_num, ann_obj_id):
128
+ predictor, inference_state, image_predictor = seg_tracker
129
+ image_path = f'/tmp/output_frames/{session_id}/{frame_num:07d}.jpg'
130
+ image = cv2.imread(image_path)
131
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
132
+ display_image = drawing_board["image"]
133
+ image_predictor.set_image(image)
134
+ input_mask = drawing_board["mask"]
135
+ input_mask[input_mask != 0] = 255
136
+ if last_draw is not None:
137
+ diff_mask = cv2.absdiff(input_mask, last_draw)
138
+ input_mask = diff_mask
139
+ bbox, hasMask = mask2bbox(input_mask[:, :, 0])
140
+ if not hasMask :
141
+ return seg_tracker, display_image, display_image, None
142
+ masks, scores, logits = image_predictor.predict( point_coords=None, point_labels=None, box=bbox[None, :], multimask_output=False,)
143
+ mask = masks > 0.0
144
+ masked_frame = show_mask(mask, display_image, ann_obj_id)
145
+ masked_with_rect = draw_rect(masked_frame, bbox, ann_obj_id)
146
+ frame_idx, object_ids, masks = predictor.add_new_mask(inference_state, frame_idx=frame_num, obj_id=ann_obj_id, mask=mask[0])
147
+ last_draw = drawing_board["mask"]
148
+ return seg_tracker, masked_with_rect, masked_with_rect, last_draw
149
+
150
+ def draw_rect(image, bbox, obj_id):
151
+ cmap = plt.get_cmap("tab10")
152
+ color = np.array(cmap(obj_id)[:3])
153
+ rgb_color = tuple(map(int, (color[:3] * 255).astype(np.uint8)))
154
+ inv_color = tuple(map(int, (255 - color[:3] * 255).astype(np.uint8)))
155
+ x0, y0, x1, y1 = bbox
156
+ image_with_rect = cv2.rectangle(image.copy(), (x0, y0), (x1, y1), rgb_color, thickness=2)
157
+ return image_with_rect
158
+
159
+ def sam_click(session_id, seg_tracker, frame_num, point_mode, click_stack, ann_obj_id, point):
160
+ points_dict, labels_dict = click_stack
161
+ predictor, inference_state, image_predictor = seg_tracker
162
+ ann_frame_idx = frame_num # the frame index we interact with
163
+ print(f'ann_frame_idx: {ann_frame_idx}')
164
+ if point_mode == "Positive":
165
+ label = np.array([1], np.int32)
166
+ else:
167
+ label = np.array([0], np.int32)
168
+
169
+ if ann_frame_idx not in points_dict:
170
+ points_dict[ann_frame_idx] = {}
171
+ if ann_frame_idx not in labels_dict:
172
+ labels_dict[ann_frame_idx] = {}
173
+
174
+ if ann_obj_id not in points_dict[ann_frame_idx]:
175
+ points_dict[ann_frame_idx][ann_obj_id] = np.empty((0, 2), dtype=np.float32)
176
+ if ann_obj_id not in labels_dict[ann_frame_idx]:
177
+ labels_dict[ann_frame_idx][ann_obj_id] = np.empty((0,), dtype=np.int32)
178
+
179
+ points_dict[ann_frame_idx][ann_obj_id] = np.append(points_dict[ann_frame_idx][ann_obj_id], point, axis=0)
180
+ labels_dict[ann_frame_idx][ann_obj_id] = np.append(labels_dict[ann_frame_idx][ann_obj_id], label, axis=0)
181
+
182
+ click_stack = (points_dict, labels_dict)
183
+
184
+ frame_idx, out_obj_ids, out_mask_logits = predictor.add_new_points(
185
+ inference_state=inference_state,
186
+ frame_idx=ann_frame_idx,
187
+ obj_id=ann_obj_id,
188
+ points=points_dict[ann_frame_idx][ann_obj_id],
189
+ labels=labels_dict[ann_frame_idx][ann_obj_id],
190
+ )
191
+
192
+ image_path = f'/tmp/output_frames/{session_id}/{ann_frame_idx:07d}.jpg'
193
+ image = cv2.imread(image_path)
194
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
195
+
196
+ masked_frame = image.copy()
197
+ for i, obj_id in enumerate(out_obj_ids):
198
+ mask = (out_mask_logits[i] > 0.0).cpu().numpy()
199
+ masked_frame = show_mask(mask, image=masked_frame, obj_id=obj_id)
200
+ masked_frame_with_markers = draw_markers(masked_frame, points_dict[ann_frame_idx], labels_dict[ann_frame_idx])
201
+
202
+ return seg_tracker, masked_frame_with_markers, masked_frame_with_markers, click_stack
203
+
204
+ def draw_markers(image, points_dict, labels_dict):
205
+ cmap = plt.get_cmap("tab10")
206
+ image_h, image_w = image.shape[:2]
207
+ marker_size = max(1, int(min(image_h, image_w) * 0.05))
208
+
209
+ for obj_id in points_dict:
210
+ color = np.array(cmap(obj_id)[:3])
211
+ rgb_color = tuple(map(int, (color[:3] * 255).astype(np.uint8)))
212
+ inv_color = tuple(map(int, (255 - color[:3] * 255).astype(np.uint8)))
213
+ for point, label in zip(points_dict[obj_id], labels_dict[obj_id]):
214
+ x, y = int(point[0]), int(point[1])
215
+ if label == 1:
216
+ cv2.drawMarker(image, (x, y), inv_color, markerType=cv2.MARKER_CROSS, markerSize=marker_size, thickness=2)
217
+ else:
218
+ cv2.drawMarker(image, (x, y), inv_color, markerType=cv2.MARKER_TILTED_CROSS, markerSize=int(marker_size / np.sqrt(2)), thickness=2)
219
+
220
+ return image
221
+
222
+ def show_mask(mask, image=None, obj_id=None):
223
+ cmap = plt.get_cmap("tab10")
224
+ cmap_idx = 0 if obj_id is None else obj_id
225
+ color = np.array([*cmap(cmap_idx)[:3], 0.6])
226
+
227
+ h, w = mask.shape[-2:]
228
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
229
+ mask_image = (mask_image * 255).astype(np.uint8)
230
+ if image is not None:
231
+ image_h, image_w = image.shape[:2]
232
+ if (image_h, image_w) != (h, w):
233
+ raise ValueError(f"Image dimensions ({image_h}, {image_w}) and mask dimensions ({h}, {w}) do not match")
234
+ colored_mask = np.zeros_like(image, dtype=np.uint8)
235
+ for c in range(3):
236
+ colored_mask[..., c] = mask_image[..., c]
237
+ alpha_mask = mask_image[..., 3] / 255.0
238
+ for c in range(3):
239
+ image[..., c] = np.where(alpha_mask > 0, (1 - alpha_mask) * image[..., c] + alpha_mask * colored_mask[..., c], image[..., c])
240
+ return image
241
+ return mask_image
242
+
243
+ def show_res_by_slider(session_id, frame_per, click_stack):
244
+ image_path = f'/tmp/output_frames/{session_id}'
245
+ output_combined_dir = f'/tmp/output_combined/{session_id}'
246
 
247
  combined_frames = sorted([os.path.join(output_combined_dir, img_name) for img_name in os.listdir(output_combined_dir)])
248
  if combined_frames:
 
254
  total_frames_num = len(output_masked_frame_path)
255
  if total_frames_num == 0:
256
  print("No output results found")
257
+ return None, None, 0
258
  else:
259
  frame_num = math.floor(total_frames_num * frame_per / 100)
260
  if frame_per == 100:
 
263
  print(f"{chosen_frame_path}")
264
  chosen_frame_show = cv2.imread(chosen_frame_path)
265
  chosen_frame_show = cv2.cvtColor(chosen_frame_show, cv2.COLOR_BGR2RGB)
266
+ points_dict, labels_dict = click_stack
267
  if frame_num in points_dict and frame_num in labels_dict:
268
  chosen_frame_show = draw_markers(chosen_frame_show, points_dict[frame_num], labels_dict[frame_num])
269
  return chosen_frame_show, chosen_frame_show, frame_num
270
 
271
+ def clear_folder(folder_path):
272
+ if os.path.exists(folder_path):
273
+ shutil.rmtree(folder_path)
274
+ os.makedirs(folder_path)
275
+
276
+ def zip_folder(folder_path, output_zip_path):
277
+ with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_STORED) as zipf:
278
+ for root, _, files in os.walk(folder_path):
279
+ for file in files:
280
+ file_path = os.path.join(root, file)
281
+ zipf.write(file_path, os.path.relpath(file_path, folder_path))
282
+
283
+ def tracking_objects(session_id, seg_tracker, frame_num, input_video):
284
+ output_dir = f'/tmp/output_frames/{session_id}'
285
+ output_masks_dir = f'/tmp/output_masks/{session_id}'
286
+ output_combined_dir = f'/tmp/output_combined/{session_id}'
287
+ output_files_dir = f'/tmp/output_files/{session_id}'
288
+ output_video_path = f'{output_files_dir}/output_video.mp4'
289
+ output_zip_path = f'{output_files_dir}/output_masks.zip'
290
+ clear_folder(output_masks_dir)
291
+ clear_folder(output_combined_dir)
292
+ clear_folder(output_files_dir)
293
+ video_segments = {}
294
+ predictor, inference_state, image_predictor = seg_tracker
295
+ for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
296
+ video_segments[out_frame_idx] = {
297
+ out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
298
+ for i, out_obj_id in enumerate(out_obj_ids)
299
+ }
300
+ frame_files = sorted([f for f in os.listdir(output_dir) if f.endswith('.jpg')])
301
+ # for frame_idx in sorted(video_segments.keys()):
302
+ for frame_file in frame_files:
303
+ frame_idx = int(os.path.splitext(frame_file)[0])
304
+ frame_path = os.path.join(output_dir, frame_file)
305
+ image = cv2.imread(frame_path)
306
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
307
+ masked_frame = image.copy()
308
+ if frame_idx in video_segments:
309
+ for obj_id, mask in video_segments[frame_idx].items():
310
+ masked_frame = show_mask(mask, image=masked_frame, obj_id=obj_id)
311
+ mask_output_path = os.path.join(output_masks_dir, f'{obj_id}_{frame_idx:07d}.png')
312
+ cv2.imwrite(mask_output_path, show_mask(mask))
313
+ combined_output_path = os.path.join(output_combined_dir, f'{frame_idx:07d}.png')
314
+ combined_image_bgr = cv2.cvtColor(masked_frame, cv2.COLOR_RGB2BGR)
315
+ cv2.imwrite(combined_output_path, combined_image_bgr)
316
+ if frame_idx == frame_num:
317
+ final_masked_frame = masked_frame
318
+
319
+ cap = cv2.VideoCapture(input_video)
320
+ fps = cap.get(cv2.CAP_PROP_FPS)
321
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
322
+ frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
323
+ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
324
+ cap.release()
325
+ # output_frames = int(total_frames * scale_slider)
326
+ output_frames = len([name for name in os.listdir(output_combined_dir) if os.path.isfile(os.path.join(output_combined_dir, name)) and name.endswith('.png')])
327
+ out_fps = fps * output_frames / total_frames
328
+
329
+ # ffmpeg.input(os.path.join(output_combined_dir, '%07d.png'), framerate=out_fps).output(output_video_path, vcodec='h264_nvenc', pix_fmt='yuv420p').run()
330
+
331
+ # fourcc = cv2.VideoWriter_fourcc(*"mp4v")
332
+ # out = cv2.VideoWriter(output_video_path, fourcc, out_fps, (frame_width, frame_height))
333
+ # for i in range(output_frames):
334
+ # frame_path = os.path.join(output_combined_dir, f'{i:07d}.png')
335
+ # frame = cv2.imread(frame_path)
336
+ # out.write(frame)
337
+ # out.release()
338
+
339
+ image_files = [os.path.join(output_combined_dir, f'{i:07d}.png') for i in range(output_frames)]
340
+ clip = ImageSequenceClip(image_files, fps=out_fps)
341
+ clip.write_videofile(output_video_path, codec="libx264", fps=out_fps)
342
+
343
+ zip_folder(output_masks_dir, output_zip_path)
344
+ print("done")
345
+ return final_masked_frame, final_masked_frame, output_video_path, output_video_path, output_zip_path
346
+
347
  def increment_ann_obj_id(ann_obj_id):
348
  ann_obj_id += 1
349
  return ann_obj_id
 
351
  def drawing_board_get_input_first_frame(input_first_frame):
352
  return input_first_frame
353
 
354
+ def process_video(queue, result_queue, session_id):
355
+ seg_tracker = None
356
+ click_stack = ({}, {})
357
+ frame_num = int(0)
358
+ ann_obj_id =int(0)
359
+ last_draw = None
360
+ while True:
361
+ task = queue.get()
362
+ if task["command"] == "exit":
363
+ print(f"Process for {session_id} exiting.")
364
+ break
365
+ elif task["command"] == "extract_video_info":
366
+ input_video = task["input_video"]
367
+ fps, total_frames, input_first_frame, drawing_board, output_video, output_mp4, output_mask = extract_video_info(input_video)
368
+ result_queue.put({"fps": fps, "total_frames": total_frames, "input_first_frame": input_first_frame, "drawing_board": drawing_board, "output_video": output_video, "output_mp4": output_mp4, "output_mask": output_mask})
369
+ elif task["command"] == "get_meta_from_video":
370
+ input_video = task["input_video"]
371
+ scale_slider = task["scale_slider"]
372
+ checkpoint = task["checkpoint"]
373
+ seg_tracker, click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id = get_meta_from_video(session_id, input_video, scale_slider, checkpoint)
374
+ result_queue.put({"input_first_frame": input_first_frame, "drawing_board": drawing_board, "frame_per": frame_per, "output_video": output_video, "output_mp4": output_mp4, "output_mask": output_mask, "ann_obj_id": ann_obj_id})
375
+ elif task["command"] == "sam_stroke":
376
+ drawing_board = task["drawing_board"]
377
+ last_draw = task["last_draw"]
378
+ frame_num = task["frame_num"]
379
+ ann_obj_id = task["ann_obj_id"]
380
+ seg_tracker, input_first_frame, drawing_board, last_draw = sam_stroke(session_id, seg_tracker, drawing_board, last_draw, frame_num, ann_obj_id)
381
+ result_queue.put({"input_first_frame": input_first_frame, "drawing_board": drawing_board, "last_draw": last_draw})
382
+ elif task["command"] == "sam_click":
383
+ frame_num = task["frame_num"]
384
+ point_mode = task["point_mode"]
385
+ click_stack = task["click_stack"]
386
+ ann_obj_id = task["ann_obj_id"]
387
+ point = task["point"]
388
+ seg_tracker, input_first_frame, drawing_board, last_draw = sam_click(session_id, seg_tracker, frame_num, point_mode, click_stack, ann_obj_id, point)
389
+ result_queue.put({"input_first_frame": input_first_frame, "drawing_board": drawing_board, "last_draw": last_draw})
390
+ elif task["command"] == "increment_ann_obj_id":
391
+ ann_obj_id = task["ann_obj_id"]
392
+ ann_obj_id = increment_ann_obj_id(ann_obj_id)
393
+ result_queue.put({"ann_obj_id": ann_obj_id})
394
+ elif task["command"] == "drawing_board_get_input_first_frame":
395
+ input_first_frame = task["input_first_frame"]
396
+ input_first_frame = drawing_board_get_input_first_frame(input_first_frame)
397
+ result_queue.put({"input_first_frame": input_first_frame})
398
+ elif task["command"] == "reset":
399
+ seg_tracker, click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id = reset(seg_tracker)
400
+ result_queue.put({"click_stack": click_stack, "input_first_frame": input_first_frame, "drawing_board": drawing_board, "frame_per": frame_per, "output_video": output_video, "output_mp4": output_mp4, "output_mask": output_mask, "ann_obj_id": ann_obj_id})
401
+ elif task["command"] == "show_res_by_slider":
402
+ frame_per = task["frame_per"]
403
+ click_stack = task["click_stack"]
404
+ input_first_frame, drawing_board, frame_num = show_res_by_slider(session_id, frame_per, click_stack)
405
+ result_queue.put({"input_first_frame": input_first_frame, "drawing_board": drawing_board, "frame_num": frame_num})
406
+ elif task["command"] == "tracking_objects":
407
+ frame_num = task["frame_num"]
408
+ input_video = task["input_video"]
409
+ input_first_frame, drawing_board, output_video, output_mp4, output_mask = tracking_objects(session_id, seg_tracker, frame_num, input_video)
410
+ result_queue.put({"input_first_frame": input_first_frame, "drawing_board": drawing_board, "output_video": output_video, "output_mp4": output_mp4, "output_mask": output_mask})
411
+ else:
412
+ print(f"Unknown command {task['command']} for {session_id}")
413
+ result_queue.put("Unknown command")
414
+
415
+ def start_process(session_id):
416
+ if session_id not in user_processes:
417
+ queue = mp.Queue()
418
+ result_queue = mp.Queue()
419
+ process = mp.Process(target=process_video, args=(queue, result_queue, session_id))
420
+ process.start()
421
+ user_processes[session_id] = {
422
+ "process": process,
423
+ "queue": queue,
424
+ "result_queue": result_queue,
425
+ "last_active": datetime.datetime.now()
426
+ }
427
+ else:
428
+ user_processes[session_id]["last_active"] = datetime.datetime.now()
429
+ return user_processes[session_id]["queue"]
430
+
431
+ # def clean_up_processes(session_id, init_clean = False):
432
+ # now = datetime.datetime.now()
433
+ # to_remove = []
434
+ # for s_id, process_info in user_processes.items():
435
+ # if (now - process_info["last_active"] > PROCESS_TIMEOUT) or (s_id == session_id and init_clean):
436
+ # process_info["queue"].put({"command": "exit"})
437
+ # process_info["process"].terminate()
438
+ # process_info["process"].join()
439
+ # to_remove.append(s_id)
440
+ # for s_id in to_remove:
441
+ # del user_processes[s_id]
442
+ # print(f"Cleaned up process for session {s_id}.")
443
+
444
+ def monitor_and_cleanup_processes():
445
+ while True:
446
+ now = datetime.datetime.now()
447
+ to_remove = []
448
+ for session_id, process_info in user_processes.items():
449
+ if now - process_info["last_active"] > PROCESS_TIMEOUT:
450
+ process_info["queue"].put({"command": "exit"})
451
+ process_info["process"].terminate()
452
+ process_info["process"].join()
453
+ to_remove.append(session_id)
454
+ for session_id in to_remove:
455
+ del user_processes[session_id]
456
+ print(f"Automatically cleaned up process for session {session_id}.")
457
+ time.sleep(10)
458
 
459
  def seg_track_app():
 
460
 
461
+ import gradio as gr
 
462
 
463
+ def extract_session_id_from_request(request: gr.Request):
464
+ session_id = hashlib.sha256(f'{request.client.host}:{request.client.port}'.encode('utf-8')).hexdigest()
465
+ # cookies = request.kwargs["headers"].get('cookie', '')
466
+ # session_id = None
467
+ # if '_gid=' in cookies:
468
+ # session_id = cookies.split('_gid=')[1].split(';')[0]
469
+ # else:
470
+ # session_id = str(uuid.uuid4())
471
+ print(f"session_id {session_id}")
472
+ return session_id
473
+
474
+ def handle_extract_video_info(session_id, input_video):
475
+ # clean_up_processes(session_id, init_clean=True)
476
+ if input_video == None:
477
+ return 0, 0, None, None, None, None, None
478
+ queue = start_process(session_id)
479
+ result_queue = user_processes[session_id]["result_queue"]
480
+ queue.put({"command": "extract_video_info", "input_video": input_video})
481
+ result = result_queue.get()
482
+ fps = result.get("fps")
483
+ total_frames = result.get("total_frames")
484
+ input_first_frame = result.get("input_first_frame")
485
+ drawing_board = result.get("drawing_board")
486
+ output_video = result.get("output_video")
487
+ output_mp4 = result.get("output_mp4")
488
+ output_mask = result.get("output_mask")
489
  scale_slider = gr.Slider.update(minimum=1.0,
490
  maximum=fps,
491
  step=1.0,
 
494
  maximum= total_frames / fps,
495
  step=1.0/fps,
496
  value=0.0,)
497
+ return scale_slider, frame_per, input_first_frame, drawing_board, output_video, output_mp4, output_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498
 
499
+ def handle_get_meta_from_video(session_id, input_video, scale_slider, checkpoint):
500
+ # clean_up_processes(session_id)
501
+ queue = start_process(session_id)
502
+ result_queue = user_processes[session_id]["result_queue"]
503
+ queue.put({"command": "get_meta_from_video", "input_video": input_video, "scale_slider": scale_slider, "checkpoint": checkpoint})
504
+ result = result_queue.get()
505
+ input_first_frame = result.get("input_first_frame")
506
+ drawing_board = result.get("drawing_board")
507
+ (fps, frame_interval, total_frames) = result.get("frame_per")
508
+ output_video = result.get("output_video")
509
+ output_mp4 = result.get("output_mp4")
510
+ output_mask = result.get("output_mask")
511
+ ann_obj_id = result.get("ann_obj_id")
512
  frame_per = gr.Slider.update(minimum= 0.0,
513
  maximum= total_frames / fps,
514
  step=frame_interval / fps,
515
  value=0.0,)
516
+ return input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id
517
+
518
+ def handle_sam_stroke(session_id, drawing_board, last_draw, frame_num, ann_obj_id):
519
+ # clean_up_processes(session_id)
520
+ queue = start_process(session_id)
521
+ result_queue = user_processes[session_id]["result_queue"]
522
+ queue.put({"command": "sam_stroke", "drawing_board": drawing_board, "last_draw": last_draw, "frame_num": frame_num, "ann_obj_id": ann_obj_id})
523
+ result = result_queue.get()
524
+ input_first_frame = result.get("input_first_frame")
525
+ drawing_board = result.get("drawing_board")
526
+ last_draw = result.get("last_draw")
527
+ return input_first_frame, drawing_board, last_draw
528
+
529
+ def handle_sam_click(session_id, frame_num, point_mode, click_stack, ann_obj_id, evt: gr.SelectData):
530
+ # clean_up_processes(session_id)
531
+ queue = start_process(session_id)
532
+ result_queue = user_processes[session_id]["result_queue"]
533
+ point = np.array([[evt.index[0], evt.index[1]]], dtype=np.float32)
534
+ queue.put({"command": "sam_click", "frame_num": frame_num, "point_mode": point_mode, "click_stack": click_stack, "ann_obj_id": ann_obj_id, "point": point})
535
+ result = result_queue.get()
536
+ input_first_frame = result.get("input_first_frame")
537
+ drawing_board = result.get("drawing_board")
538
+ last_draw = result.get("last_draw")
539
+ return input_first_frame, drawing_board, last_draw
540
+
541
+ def handle_increment_ann_obj_id(session_id, ann_obj_id):
542
+ # clean_up_processes(session_id)
543
+ queue = start_process(session_id)
544
+ result_queue = user_processes[session_id]["result_queue"]
545
+ queue.put({"command": "increment_ann_obj_id", "ann_obj_id": ann_obj_id})
546
+ result = result_queue.get()
547
+ ann_obj_id = result.get("ann_obj_id")
548
+ return ann_obj_id
549
+
550
+ def handle_drawing_board_get_input_first_frame(session_id, input_first_frame):
551
+ # clean_up_processes(session_id)
552
+ queue = start_process(session_id)
553
+ result_queue = user_processes[session_id]["result_queue"]
554
+ queue.put({"command": "drawing_board_get_input_first_frame", "input_first_frame": input_first_frame})
555
+ result = result_queue.get()
556
+ input_first_frame = result.get("input_first_frame")
557
+ return input_first_frame
558
+
559
+ def handle_reset(session_id):
560
+ # clean_up_processes(session_id)
561
+ queue = start_process(session_id)
562
+ result_queue = user_processes[session_id]["result_queue"]
563
+ queue.put({"command": "reset"})
564
+ result = result_queue.get()
565
+ click_stack = result.get("click_stack")
566
+ input_first_frame = result.get("input_first_frame")
567
+ drawing_board = result.get("drawing_board")
568
+ frame_per = result.get("frame_per")
569
+ output_video = result.get("output_video")
570
+ output_mp4 = result.get("output_mp4")
571
+ output_mask = result.get("output_mask")
572
+ ann_obj_id = result.get("ann_obj_id")
573
+ return click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id
574
+
575
+ def handle_show_res_by_slider(session_id, frame_per, click_stack):
576
+ # clean_up_processes(session_id)
577
+ queue = start_process(session_id)
578
+ result_queue = user_processes[session_id]["result_queue"]
579
+ queue.put({"command": "show_res_by_slider", "frame_per": frame_per, "click_stack": click_stack})
580
+ result = result_queue.get()
581
+ input_first_frame = result.get("input_first_frame")
582
+ drawing_board = result.get("drawing_board")
583
+ frame_num = result.get("frame_num")
584
+ return input_first_frame, drawing_board, frame_num
585
+
586
+ def handle_tracking_objects(session_id, frame_num, input_video):
587
+ # clean_up_processes(session_id)
588
+ queue = start_process(session_id)
589
+ result_queue = user_processes[session_id]["result_queue"]
590
+ queue.put({"command": "tracking_objects", "frame_num": frame_num, "input_video": input_video})
591
+ result = result_queue.get()
592
+ input_first_frame = result.get("input_first_frame")
593
+ drawing_board = result.get("drawing_board")
594
+ output_video = result.get("output_video")
595
+ output_mp4 = result.get("output_mp4")
596
+ output_mask = result.get("output_mask")
597
+ return input_first_frame, drawing_board, output_video, output_mp4, output_mask
598
 
599
  ##########################################################
600
  ###################### Front-end ########################
 
608
  """
609
 
610
  app = gr.Blocks(css=css)
 
611
  with app:
612
+ session_id = gr.State()
613
+ app.load(extract_session_id_from_request, None, session_id)
614
  gr.Markdown(
615
  '''
616
  <div style="text-align:center; margin-bottom:20px;">
 
673
  '''
674
  )
675
 
676
+ click_stack = gr.State(({}, {}))
677
  frame_num = gr.State(value=(int(0)))
678
  ann_obj_id = gr.State(value=(int(0)))
679
  last_draw = gr.State(None)
 
683
  with gr.Row():
684
  tab_video_input = gr.Tab(label="Video input")
685
  with tab_video_input:
686
+ input_video = gr.Video(label='Input video', type=["mp4", "mov", "avi"], elem_id="input_output_video")
687
  with gr.Row():
688
  checkpoint = gr.Dropdown(label="Model Size", choices=["tiny", "small", "base-plus", "large"], value="tiny")
689
  scale_slider = gr.Slider(
 
708
 
709
  tab_click = gr.Tab(label="Point Prompt")
710
  with tab_click:
711
+ input_first_frame = gr.Image(label='Segment result of first frame',interactive=True).style(height=550)
712
  with gr.Row():
713
  point_mode = gr.Radio(
714
  choices=["Positive", "Negative"],
 
781
 
782
  # listen to the preprocess button click to get the first frame of video with scaling
783
  preprocess_button.click(
784
+ fn=handle_get_meta_from_video,
785
  inputs=[
786
+ session_id,
787
  input_video,
788
  scale_slider,
789
+ checkpoint
790
  ],
791
  outputs=[
792
+ input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id
793
  ]
794
  )
795
 
796
  frame_per.release(
797
+ fn=handle_show_res_by_slider,
798
  inputs=[
799
+ session_id, frame_per, click_stack
800
  ],
801
  outputs=[
802
  input_first_frame, drawing_board, frame_num
 
805
 
806
  # Interactively modify the mask acc click
807
  input_first_frame.select(
808
+ fn=handle_sam_click,
809
  inputs=[
810
+ session_id, frame_num, point_mode, click_stack, ann_obj_id
811
  ],
812
  outputs=[
813
  input_first_frame, drawing_board, click_stack
 
816
 
817
  # Track object in video
818
  track_for_video.click(
819
+ fn=handle_tracking_objects,
820
  inputs=[
821
+ session_id,
 
822
  frame_num,
823
  input_video,
824
  ],
 
832
  )
833
 
834
  reset_button.click(
835
+ fn=handle_reset,
836
+ inputs=[session_id],
837
  outputs=[
838
  click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id
839
  ]
840
  )
841
 
842
  new_object_button.click(
843
+ fn=handle_increment_ann_obj_id,
844
  inputs=[
845
+ session_id, ann_obj_id
846
  ],
847
  outputs=[
848
  ann_obj_id
 
850
  )
851
 
852
  tab_stroke.select(
853
+ fn=handle_drawing_board_get_input_first_frame,
854
+ inputs=[session_id, input_first_frame],
855
  outputs=[drawing_board,],
856
  )
857
 
858
  seg_acc_stroke.click(
859
+ fn=handle_sam_stroke,
860
  inputs=[
861
+ session_id, drawing_board, last_draw, frame_num, ann_obj_id
862
  ],
863
  outputs=[
864
+ input_first_frame, drawing_board, last_draw
865
  ]
866
  )
867
 
868
  input_video.change(
869
+ fn=handle_extract_video_info,
870
+ inputs=[session_id, input_video],
871
+ outputs=[scale_slider, frame_per, input_first_frame, drawing_board, output_video, output_mp4, output_mask]
872
  )
873
 
874
  app.queue(concurrency_count=1)
875
+ app.launch(debug=True, enable_queue=True, share=False)
876
 
877
  if __name__ == "__main__":
878
+ mp.set_start_method("spawn")
879
+ monitor_thread = threading.Thread(target=monitor_and_cleanup_processes)
880
+ monitor_thread.daemon = True
881
+ monitor_thread.start()
882
  seg_track_app()