import datetime |
import gc |
import hashlib |
import math |
import multiprocessing as mp |
import os |
import threading |
import time |
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1" |
import shutil |
import ffmpeg |
from moviepy.editor import ImageSequenceClip |
import zipfile |
import torch |
import numpy as np |
import matplotlib.pyplot as plt |
from PIL import Image |
from sam2.build_sam import build_sam2 |
from sam2.sam2_image_predictor import SAM2ImagePredictor |
from sam2.build_sam import build_sam2_video_predictor |
import cv2 |
import uuid |
# session_id -> {"process", "queue", "result_queue", "last_active"} for every worker started by start_process.
user_processes = dict()
# Idle time after which monitor_and_cleanup_processes shuts a session's worker down.
PROCESS_TIMEOUT = datetime.timedelta(seconds=4 * 60)
def reset(seg_tracker):
    """Drop the session's SAM2 tracker and return the app's initial state.

    When a tracker is given, its video predictor's state is reset, the local
    references are released and Python/CUDA caches are flushed.

    Returns the 9-tuple unpacked by process_video for the "reset" command:
    (seg_tracker, click_stack, input_first_frame, drawing_board, frame_per,
    output_video, output_mp4, output_mask, ann_obj_id).
    """
    if seg_tracker is not None:
        video_predictor, state, img_predictor = seg_tracker
        video_predictor.reset_state(state)
        del video_predictor, state, img_predictor, seg_tracker
        gc.collect()
        torch.cuda.empty_cache()
    empty_click_stack = ({}, {})
    return None, empty_click_stack, None, None, 0, None, None, None, 0
def extract_video_info(input_video):
    """Read the frame rate and frame count of `input_video` with OpenCV.

    Returns (fps, total_frames, None, None, None, None, None). The trailing
    Nones are forwarded by process_video as input_first_frame, drawing_board,
    output_video, output_mp4 and output_mask. Without a video the placeholder
    values fps=4 and total_frames=4 are returned.
    """
    blank_outputs = (None,) * 5
    if input_video is None:
        return (4, 4) + blank_outputs
    capture = cv2.VideoCapture(input_video)
    fps = capture.get(cv2.CAP_PROP_FPS)
    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    capture.release()
    return (fps, frame_count) + blank_outputs
# (checkpoint path, model config) for each "Model Size" dropdown choice.
_SAM2_MODEL_FILES = {
    "tiny": ("segment-anything-2/checkpoints/sam2_hiera_tiny.pt", "sam2_hiera_t.yaml"),
    "small": ("segment-anything-2/checkpoints/sam2_hiera_small.pt", "sam2_hiera_s.yaml"),
    "base-plus": ("segment-anything-2/checkpoints/sam2_hiera_base_plus.pt", "sam2_hiera_b+.yaml"),
    "large": ("segment-anything-2/checkpoints/sam2_hiera_large.pt", "sam2_hiera_l.yaml"),
}


def _extract_frames(input_video, output_dir, frame_interval):
    """Write every `frame_interval`-th frame of `input_video` to `output_dir` as 0000000.jpg, 0000001.jpg, ...

    CUDA-accelerated decoding is tried first. If it fails, the folder is emptied
    (so no partial output from the failed run survives) and decoding is retried
    without hardware acceleration.
    """
    def run(**input_kwargs):
        ffmpeg.input(input_video, **input_kwargs).output(
            os.path.join(output_dir, '%07d.jpg'), q=2, start_number=0,
            vf=rf'select=not(mod(n\,{frame_interval}))', vsync='vfr'
        ).run()

    try:
        run(hwaccel='cuda')
    except Exception:  # ffmpeg.Error on a non-zero exit, or an environment error
        print("ffmpeg cuda err")
        clear_folder(output_dir)
        run()


def get_meta_from_video(session_id, input_video, scale_slider, checkpoint):
    """Extract frames from `input_video` and build a SAM2 tracker for them.

    Frames are written to /tmp/output_frames/<session_id>. The per-session mask
    and combined output folders are emptied first.

    Args:
        session_id: key of the per-session /tmp folders.
        input_video: path of the uploaded video, or None.
        scale_slider: target frames per second to keep. Every
            ``fps // scale_slider``-th frame is extracted; a missing or
            non-positive value keeps every frame.
        checkpoint: "tiny", "small", "base-plus" or "large". Unknown values
            fall back to "tiny".

    Returns the 9-tuple unpacked by process_video:
    ((predictor, inference_state, image_predictor), click_stack, first_frame_rgb,
    first_frame_rgb, (fps, frame_interval, total_frames), None, None, None, 0).
    Without a video, the tracker and frames are None and the meta is (4, 1, 4).

    Raises:
        RuntimeError: if no frame could be extracted from the video.
    """
    output_dir = f'/tmp/output_frames/{session_id}'
    output_masks_dir = f'/tmp/output_masks/{session_id}'
    output_combined_dir = f'/tmp/output_combined/{session_id}'
    clear_folder(output_dir)
    clear_folder(output_masks_dir)
    clear_folder(output_combined_dir)
    if input_video is None:
        return None, ({}, {}), None, None, (4, 1, 4), None, None, None, 0

    cap = cv2.VideoCapture(input_video)
    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    # Guard the division: a 0/None rate would raise ZeroDivisionError/TypeError.
    if scale_slider and scale_slider > 0:
        frame_interval = max(1, int(fps // scale_slider))
    else:
        frame_interval = 1
    print(f"frame_interval: {frame_interval}")
    _extract_frames(input_video, output_dir, frame_interval)

    first_frame = cv2.imread(os.path.join(output_dir, '0000000.jpg'))
    if first_frame is None:
        raise RuntimeError(f"No frames could be extracted from {input_video}")
    first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)

    # Enter bfloat16 autocast for the rest of this worker process; it is never exited.
    # NOTE(review): every Preprocess call enters another (nested) autocast context.
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    if torch.cuda.get_device_properties(0).major >= 8:
        # TF32 matmuls are available on compute capability 8.x (Ampere) and newer.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    sam2_checkpoint, model_cfg = _SAM2_MODEL_FILES.get(checkpoint, _SAM2_MODEL_FILES["tiny"])
    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda")
    sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
    image_predictor = SAM2ImagePredictor(sam2_model)
    inference_state = predictor.init_state(video_path=output_dir)
    predictor.reset_state(inference_state)
    return (predictor, inference_state, image_predictor), ({}, {}), first_frame_rgb, first_frame_rgb, (fps, frame_interval, total_frames), None, None, None, 0
def mask2bbox(mask):
    """Return the tight bounding box of a 2-D mask as ([x0, y0, x1, y1], found).

    Coordinates are inclusive column/row indices of the outermost columns and
    rows whose sums are non-zero, as an int64 array. If no element is > 0, a
    zero box and False are returned.
    """
    if not np.where(mask > 0)[0].size:
        print('not mask')
        return np.array([0, 0, 0, 0]).astype(np.int64), False
    cols = np.nonzero(np.sum(mask, axis=0))[0]
    rows = np.nonzero(np.sum(mask, axis=1))[0]
    box = np.array([cols.min(), rows.min(), cols.max(), rows.max()])
    return box.astype(np.int64), True
def sam_stroke(session_id, seg_tracker, drawing_board, last_draw, frame_num, ann_obj_id):
    """Turn the newest sketch stroke into a box prompt for `ann_obj_id` on frame `frame_num`.

    `drawing_board` is the sketch component's value, read here as a dict with
    "image" (the displayed frame) and "mask" (the sketch mask). Only the part of
    the mask that differs from `last_draw` (the mask stored by the previous call)
    is used. The box around it is segmented with the image predictor, and the
    resulting mask is registered with the video predictor.

    Returns (seg_tracker, rendered_frame, rendered_frame, last_draw). When there
    is no new stroke, the displayed image is returned unchanged and the incoming
    `last_draw` is passed through.
    """
    predictor, inference_state, image_predictor = seg_tracker
    image_path = f'/tmp/output_frames/{session_id}/{frame_num:07d}.jpg'
    image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    display_image = drawing_board["image"]
    image_predictor.set_image(image)

    input_mask = drawing_board["mask"]
    # Binarise in place; drawing_board["mask"] is what becomes the next last_draw.
    input_mask[input_mask != 0] = 255
    if last_draw is not None:
        # Keep only what changed since the previous Segment click.
        input_mask = cv2.absdiff(input_mask, last_draw)
    bbox, has_mask = mask2bbox(input_mask[:, :, 0])
    if not has_mask:
        # Nothing new was drawn. Keep the previous reference mask; returning None
        # would make the next box cover every stroke on the board.
        return seg_tracker, display_image, display_image, last_draw

    masks, _scores, _logits = image_predictor.predict(
        point_coords=None, point_labels=None, box=bbox[None, :], multimask_output=False,)
    mask = masks > 0.0
    masked_frame = show_mask(mask, display_image, ann_obj_id)
    masked_with_rect = draw_rect(masked_frame, bbox, ann_obj_id)
    predictor.add_new_mask(inference_state, frame_idx=frame_num, obj_id=ann_obj_id, mask=mask[0])
    last_draw = drawing_board["mask"]
    return seg_tracker, masked_with_rect, masked_with_rect, last_draw
def draw_rect(image, bbox, obj_id):
    """Return a copy of `image` with `bbox` = (x0, y0, x1, y1) outlined 2 px thick.

    The outline uses the object's tab10 color (``cmap("tab10")(obj_id)`` scaled
    to 0-255). The input image is not modified.
    """
    cmap = plt.get_cmap("tab10")
    color = np.array(cmap(obj_id)[:3])
    rgb_color = tuple(map(int, (color * 255).astype(np.uint8)))
    # Plain Python ints: bbox arrives as an int64 NumPy array (see mask2bbox).
    x0, y0, x1, y1 = (int(v) for v in bbox)
    return cv2.rectangle(image.copy(), (x0, y0), (x1, y1), rgb_color, thickness=2)
def sam_click(session_id, seg_tracker, frame_num, point_mode, click_stack, ann_obj_id, point):
    """Add a point prompt for `ann_obj_id` on frame `frame_num` and render the result.

    `click_stack` is (points_by_frame, labels_by_frame), nested dicts keyed by
    frame index and then object id. The new point and its label (1 for
    "Positive", 0 otherwise) are appended in place. All points gathered so far
    for this frame/object are sent to the video predictor. The returned frame
    overlays every object mask the predictor reports plus the click markers.

    Returns (seg_tracker, rendered_frame, rendered_frame, click_stack).
    """
    points_by_frame, labels_by_frame = click_stack
    video_predictor, state, _image_predictor = seg_tracker
    print(f'ann_frame_idx: {frame_num}')
    label = np.array([1 if point_mode == "Positive" else 0], np.int32)

    frame_points = points_by_frame.setdefault(frame_num, {})
    frame_labels = labels_by_frame.setdefault(frame_num, {})
    frame_points[ann_obj_id] = np.append(
        frame_points.setdefault(ann_obj_id, np.empty((0, 2), dtype=np.float32)), point, axis=0)
    frame_labels[ann_obj_id] = np.append(
        frame_labels.setdefault(ann_obj_id, np.empty((0,), dtype=np.int32)), label, axis=0)

    _frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
        inference_state=state,
        frame_idx=frame_num,
        obj_id=ann_obj_id,
        points=frame_points[ann_obj_id],
        labels=frame_labels[ann_obj_id],
    )

    frame_bgr = cv2.imread(f'/tmp/output_frames/{session_id}/{frame_num:07d}.jpg')
    rendered = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).copy()
    for idx, obj_id in enumerate(out_obj_ids):
        obj_mask = (out_mask_logits[idx] > 0.0).cpu().numpy()
        rendered = show_mask(obj_mask, image=rendered, obj_id=obj_id)
    rendered = draw_markers(rendered, frame_points, frame_labels)
    return seg_tracker, rendered, rendered, (points_by_frame, labels_by_frame)
def draw_markers(image, points_dict, labels_dict):
    """Draw point prompts onto `image` in place and return it.

    `points_dict` / `labels_dict` map obj_id -> a sequence of (x, y) points /
    their labels. Points labelled 1 get a cross. All others get a tilted cross
    whose size is divided by sqrt(2). Markers use the inverse of the object's
    tab10 color, and their size is 5% of the shorter image side (at least 1 px).
    """
    cmap = plt.get_cmap("tab10")
    image_h, image_w = image.shape[:2]
    marker_size = max(1, int(min(image_h, image_w) * 0.05))
    tilted_size = int(marker_size / np.sqrt(2))
    for obj_id in points_dict:
        color = np.array(cmap(obj_id)[:3])
        inv_color = tuple(map(int, (255 - color * 255).astype(np.uint8)))
        for point, label in zip(points_dict[obj_id], labels_dict[obj_id]):
            x, y = int(point[0]), int(point[1])
            if label == 1:
                cv2.drawMarker(image, (x, y), inv_color, markerType=cv2.MARKER_CROSS, markerSize=marker_size, thickness=2)
            else:
                cv2.drawMarker(image, (x, y), inv_color, markerType=cv2.MARKER_TILTED_CROSS, markerSize=tilted_size, thickness=2)
    return image
def show_mask(mask, image=None, obj_id=None):
    """Color `mask` with the tab10 color of `obj_id` (index 0 when None) at 0.6 alpha.

    `mask` is reshaped to its last two dims (h, w), so any leading dims must be
    singleton.

    Without `image`: return the (h, w, 4) uint8 RGBA rendering of the mask.
    With `image`: alpha-blend that rendering into the first three channels of
    `image` *in place* and return `image`. Pixels where the rendering's alpha is
    zero keep their value.

    Raises:
        ValueError: if the image and mask sizes differ.
    """
    palette = plt.get_cmap("tab10")
    rgba = np.array([*palette(0 if obj_id is None else obj_id)[:3], 0.6])
    height, width = mask.shape[-2:]
    mask_rgba = (mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1) * 255).astype(np.uint8)
    if image is None:
        return mask_rgba

    img_h, img_w = image.shape[:2]
    if (img_h, img_w) != (height, width):
        raise ValueError(f"Image dimensions ({img_h}, {img_w}) and mask dimensions ({height}, {width}) do not match")

    tint = np.zeros_like(image, dtype=np.uint8)
    tint[..., :3] = mask_rgba[..., :3]
    alpha = (mask_rgba[..., 3] / 255.0)[..., np.newaxis]
    blended = (1 - alpha) * image[..., :3] + alpha * tint[..., :3]
    image[..., :3] = np.where(alpha > 0, blended, image[..., :3])
    return image
def show_res_by_slider(session_id, frame_per, click_stack):
    """Return the frame selected by the slider, with that frame's click markers.

    Tracked frames (/tmp/output_combined/<session_id>) are preferred. The raw
    extracted frames (/tmp/output_frames/<session_id>) are the fallback. A
    missing folder counts as empty, so moving the slider before Preprocess no
    longer raises. The index is ``floor(N * frame_per / 100)``, clamped to
    [0, N - 1].

    NOTE(review): the UI slider feeding `frame_per` is labelled "Time (seconds)"
    and its range is set in seconds by the handlers, but it is interpreted here
    as a percentage. Confirm the intended units.

    Returns (frame_rgb, frame_rgb, frame_index), or (None, None, 0) when no
    frames exist.
    """
    image_path = f'/tmp/output_frames/{session_id}'
    output_combined_dir = f'/tmp/output_combined/{session_id}'

    def list_frames(folder):
        if not os.path.isdir(folder):
            return []
        return sorted(os.path.join(folder, img_name) for img_name in os.listdir(folder))

    frame_paths = list_frames(output_combined_dir) or list_frames(image_path)
    total_frames_num = len(frame_paths)
    if total_frames_num == 0:
        print("No output results found")
        return None, None, 0

    frame_num = math.floor(total_frames_num * frame_per / 100)
    # frame_per == 100 (or anything out of range) would index past the end or wrap.
    frame_num = min(max(frame_num, 0), total_frames_num - 1)
    chosen_frame_path = frame_paths[frame_num]
    print(f"{chosen_frame_path}")
    chosen_frame_show = cv2.cvtColor(cv2.imread(chosen_frame_path), cv2.COLOR_BGR2RGB)
    points_dict, labels_dict = click_stack
    if frame_num in points_dict and frame_num in labels_dict:
        chosen_frame_show = draw_markers(chosen_frame_show, points_dict[frame_num], labels_dict[frame_num])
    return chosen_frame_show, chosen_frame_show, frame_num
def clear_folder(folder_path):
    """Make `folder_path` an empty directory.

    Any existing tree at that path is deleted first. Missing parent
    directories are created.
    """
    target = folder_path
    if not os.path.exists(target):
        os.makedirs(target)
        return
    shutil.rmtree(target)
    os.makedirs(target)
def zip_folder(folder_path, output_zip_path):
    """Store every file under `folder_path` (recursively) in an uncompressed zip.

    Archive names are relative to `folder_path`.
    """
    file_paths = (
        os.path.join(dirpath, name)
        for dirpath, _dirnames, filenames in os.walk(folder_path)
        for name in filenames
    )
    with zipfile.ZipFile(output_zip_path, mode='w', compression=zipfile.ZIP_STORED) as archive:
        for path in file_paths:
            archive.write(path, arcname=os.path.relpath(path, folder_path))
def tracking_objects(session_id, seg_tracker, frame_num, input_video):
    """Propagate all prompts through the extracted frames and export the results.

    For each frame in /tmp/output_frames/<session_id> this writes:
      * /tmp/output_masks/<session_id>/<obj_id>_<frame>.png: the colored mask of
        each object, converted to BGRA so the saved colors match the overlay;
      * /tmp/output_combined/<session_id>/<frame>.png: the frame with all masks
        overlaid.
    The overlaid frames are then encoded to output_video.mp4, and the mask
    folder is zipped into output_masks.zip (both under /tmp/output_files/<session_id>).

    Returns (overlay, overlay, video_path, video_path, zip_path). `overlay` is
    the frame whose index equals `frame_num`, or the first frame if none
    matches. If there are no frames at all, the overlay and video path are None.
    """
    output_dir = f'/tmp/output_frames/{session_id}'
    output_masks_dir = f'/tmp/output_masks/{session_id}'
    output_combined_dir = f'/tmp/output_combined/{session_id}'
    output_files_dir = f'/tmp/output_files/{session_id}'
    output_video_path = f'{output_files_dir}/output_video.mp4'
    output_zip_path = f'{output_files_dir}/output_masks.zip'
    clear_folder(output_masks_dir)
    clear_folder(output_combined_dir)
    clear_folder(output_files_dir)

    predictor, inference_state, _image_predictor = seg_tracker
    # frame index -> {obj_id: boolean mask}
    video_segments = {}
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }

    final_masked_frame = None
    combined_paths = []
    frame_files = sorted(f for f in os.listdir(output_dir) if f.endswith('.jpg'))
    for frame_file in frame_files:
        frame_idx = int(os.path.splitext(frame_file)[0])
        image = cv2.cvtColor(cv2.imread(os.path.join(output_dir, frame_file)), cv2.COLOR_BGR2RGB)
        masked_frame = image.copy()
        for obj_id, mask in video_segments.get(frame_idx, {}).items():
            masked_frame = show_mask(mask, image=masked_frame, obj_id=obj_id)
            # show_mask renders RGBA; cv2.imwrite expects BGRA channel order.
            mask_bgra = cv2.cvtColor(show_mask(mask), cv2.COLOR_RGBA2BGRA)
            cv2.imwrite(os.path.join(output_masks_dir, f'{obj_id}_{frame_idx:07d}.png'), mask_bgra)
        combined_output_path = os.path.join(output_combined_dir, f'{frame_idx:07d}.png')
        cv2.imwrite(combined_output_path, cv2.cvtColor(masked_frame, cv2.COLOR_RGB2BGR))
        combined_paths.append(combined_output_path)
        # Fall back to the first frame so the result is never unbound.
        if frame_idx == frame_num or final_masked_frame is None:
            final_masked_frame = masked_frame

    cap = cv2.VideoCapture(input_video)
    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    if combined_paths:
        if fps > 0 and total_frames > 0:
            # Spread the extracted frames over the source duration (total_frames / fps).
            out_fps = fps * len(combined_paths) / total_frames
        else:
            out_fps = fps if fps > 0 else 1.0  # metadata unavailable: best-effort rate
        clip = ImageSequenceClip(combined_paths, fps=out_fps)
        clip.write_videofile(output_video_path, codec="libx264", fps=out_fps)
        clip.close()
    else:
        output_video_path = None
    zip_folder(output_masks_dir, output_zip_path)
    print("done")
    return final_masked_frame, final_masked_frame, output_video_path, output_video_path, output_zip_path
def increment_ann_obj_id(ann_obj_id):
    """Return the id to use for the next annotated object (`ann_obj_id` + 1)."""
    return ann_obj_id + 1
def drawing_board_get_input_first_frame(input_first_frame):
    """Copy the current frame onto the drawing board: returns `input_first_frame` unchanged."""
    frame = input_first_frame
    return frame
def process_video(queue, result_queue, session_id):
    """Command loop of a per-session worker process.

    Reads task dicts (a "command" key plus its arguments) from `queue`. It runs
    the matching function and keeps the session state (SAM2 tracker, click
    stack, current frame, object id, last stroke mask) between commands. One
    reply is put on `result_queue` per task; "exit" stops the loop without a
    reply.

    An exception raised while handling a task is printed and reported as
    {"error": "<command> failed: <exception>"}. The worker stays alive, and the
    caller blocking on `result_queue` always gets a reply.
    """
    import traceback

    seg_tracker = None
    click_stack = ({}, {})
    frame_num = 0
    ann_obj_id = 0
    last_draw = None
    while True:
        task = queue.get()
        command = task.get("command")
        if command == "exit":
            print(f"Process for {session_id} exiting.")
            break
        try:
            if command == "extract_video_info":
                input_video = task["input_video"]
                fps, total_frames, input_first_frame, drawing_board, output_video, output_mp4, output_mask = extract_video_info(input_video)
                result_queue.put({"fps": fps, "total_frames": total_frames, "input_first_frame": input_first_frame, "drawing_board": drawing_board, "output_video": output_video, "output_mp4": output_mp4, "output_mask": output_mask})
            elif command == "get_meta_from_video":
                input_video = task["input_video"]
                scale_slider = task["scale_slider"]
                checkpoint = task["checkpoint"]
                # Release the previous tracker before loading new weights so two
                # models are not resident at once.
                seg_tracker = None
                seg_tracker, click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id = get_meta_from_video(session_id, input_video, scale_slider, checkpoint)
                result_queue.put({"input_first_frame": input_first_frame, "drawing_board": drawing_board, "frame_per": frame_per, "output_video": output_video, "output_mp4": output_mp4, "output_mask": output_mask, "ann_obj_id": ann_obj_id})
            elif command == "sam_stroke":
                drawing_board = task["drawing_board"]
                last_draw = task["last_draw"]
                frame_num = task["frame_num"]
                ann_obj_id = task["ann_obj_id"]
                seg_tracker, input_first_frame, drawing_board, last_draw = sam_stroke(session_id, seg_tracker, drawing_board, last_draw, frame_num, ann_obj_id)
                result_queue.put({"input_first_frame": input_first_frame, "drawing_board": drawing_board, "last_draw": last_draw})
            elif command == "sam_click":
                frame_num = task["frame_num"]
                point_mode = task["point_mode"]
                click_stack = task["click_stack"]
                ann_obj_id = task["ann_obj_id"]
                point = task["point"]
                # sam_click's 4th return value is the updated click stack.
                seg_tracker, input_first_frame, drawing_board, click_stack = sam_click(session_id, seg_tracker, frame_num, point_mode, click_stack, ann_obj_id, point)
                # Also sent as "last_draw", the key existing UI handlers read it from.
                result_queue.put({"input_first_frame": input_first_frame, "drawing_board": drawing_board, "last_draw": click_stack, "click_stack": click_stack})
            elif command == "increment_ann_obj_id":
                ann_obj_id = increment_ann_obj_id(task["ann_obj_id"])
                result_queue.put({"ann_obj_id": ann_obj_id})
            elif command == "drawing_board_get_input_first_frame":
                input_first_frame = drawing_board_get_input_first_frame(task["input_first_frame"])
                result_queue.put({"input_first_frame": input_first_frame})
            elif command == "reset":
                seg_tracker, click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id = reset(seg_tracker)
                result_queue.put({"click_stack": click_stack, "input_first_frame": input_first_frame, "drawing_board": drawing_board, "frame_per": frame_per, "output_video": output_video, "output_mp4": output_mp4, "output_mask": output_mask, "ann_obj_id": ann_obj_id})
            elif command == "show_res_by_slider":
                frame_per = task["frame_per"]
                click_stack = task["click_stack"]
                input_first_frame, drawing_board, frame_num = show_res_by_slider(session_id, frame_per, click_stack)
                result_queue.put({"input_first_frame": input_first_frame, "drawing_board": drawing_board, "frame_num": frame_num})
            elif command == "tracking_objects":
                frame_num = task["frame_num"]
                input_video = task["input_video"]
                input_first_frame, drawing_board, output_video, output_mp4, output_mask = tracking_objects(session_id, seg_tracker, frame_num, input_video)
                result_queue.put({"input_first_frame": input_first_frame, "drawing_board": drawing_board, "output_video": output_video, "output_mp4": output_mp4, "output_mask": output_mask})
            else:
                print(f"Unknown command {command} for {session_id}")
                result_queue.put("Unknown command")
        except Exception as err:
            traceback.print_exc()
            result_queue.put({"error": f"{command} failed: {err}"})
def start_process(session_id):
    """Return the task queue of `session_id`'s worker, spawning the worker if needed.

    A new worker process (running process_video) and its two queues are
    created when the session has no worker yet, or when its worker is no longer
    alive (crashed or terminated). Without this, requests would be queued to a
    dead process and never answered. Otherwise the session's last_active
    timestamp is refreshed.
    """
    info = user_processes.get(session_id)
    if info is None or not info["process"].is_alive():
        task_queue = mp.Queue()
        result_queue = mp.Queue()
        process = mp.Process(target=process_video, args=(task_queue, result_queue, session_id))
        process.start()
        user_processes[session_id] = {
            "process": process,
            "queue": task_queue,
            "result_queue": result_queue,
            "last_active": datetime.datetime.now(),
        }
    else:
        info["last_active"] = datetime.datetime.now()
    return user_processes[session_id]["queue"]
def _stop_worker(process_info, grace_seconds):
    """Ask a worker to exit, and terminate it if it has not stopped within `grace_seconds`."""
    process = process_info["process"]
    process_info["queue"].put({"command": "exit"})
    process.join(timeout=grace_seconds)
    if process.is_alive():
        process.terminate()
        process.join()


def monitor_and_cleanup_processes(poll_interval=10, shutdown_grace=5):
    """Loop forever, stopping workers idle for longer than PROCESS_TIMEOUT.

    Intended to run in a background thread. Every `poll_interval` seconds, each
    session whose last_active is older than PROCESS_TIMEOUT gets an "exit"
    command. Its worker is given `shutdown_grace` seconds to stop before being
    terminated, and the session's entry is removed from user_processes.
    """
    while True:
        now = datetime.datetime.now()
        # Iterate a snapshot: UI callback threads may add sessions concurrently,
        # and mutating a dict during iteration raises RuntimeError.
        for session_id, process_info in list(user_processes.items()):
            if now - process_info["last_active"] <= PROCESS_TIMEOUT:
                continue
            _stop_worker(process_info, shutdown_grace)
            # Only remove the entry if it was not replaced in the meantime.
            if user_processes.get(session_id) is process_info:
                user_processes.pop(session_id, None)
            print(f"Automatically cleaned up process for session {session_id}.")
        time.sleep(poll_interval)
def seg_track_app():
    """Build and launch the Gradio web UI.

    Each callback forwards its work to the calling session's worker process
    (started by start_process, running process_video) and waits for that
    worker's reply on the session's result queue.
    """
    import gradio as gr
    from queue import Empty

    # Seconds between worker liveness checks while waiting for a reply.
    worker_poll_seconds = 5.0

    def extract_session_id_from_request(request: gr.Request):
        session_id = hashlib.sha256(f'{request.client.host}:{request.client.port}'.encode('utf-8')).hexdigest()
        print(f"session_id {session_id}")
        return session_id

    def request_worker(session_id, task):
        """Send `task` to the session's worker and return its reply dict.

        Never blocks forever on a dead worker:
        - a dead worker is dropped from user_processes, so the next request
          spawns a fresh one, and the failure is shown to the user;
        - while waiting, last_active is refreshed so idle cleanup does not
          terminate a worker busy with a long task such as tracking;
        - a reply carrying an "error" key is raised as a gr.Error.
        """
        queue = start_process(session_id)
        process_info = user_processes[session_id]
        if not process_info["process"].is_alive():
            # Dead worker left over from an earlier crash: replace it.
            user_processes.pop(session_id, None)
            queue = start_process(session_id)
            process_info = user_processes[session_id]
        queue.put(task)
        while True:
            try:
                result = process_info["result_queue"].get(timeout=worker_poll_seconds)
                break
            except Empty:
                if not process_info["process"].is_alive():
                    if user_processes.get(session_id) is process_info:
                        user_processes.pop(session_id, None)
                    raise gr.Error("The worker process for this session stopped unexpectedly. Please run Preprocess again.")
                process_info["last_active"] = datetime.datetime.now()
        if not isinstance(result, dict):
            raise gr.Error(f"Unexpected reply from worker: {result!r}")
        if "error" in result:
            raise gr.Error(str(result["error"]))
        return result

    def handle_extract_video_info(session_id, input_video):
        if input_video is None:
            return 0, 0, None, None, None, None, None
        result = request_worker(session_id, {"command": "extract_video_info", "input_video": input_video})
        fps = result.get("fps")
        total_frames = result.get("total_frames")
        if not fps or fps <= 0:
            # Unknown frame rate: the slider ranges below would divide by zero.
            raise gr.Error("Could not read the frame rate of the uploaded video.")
        scale_slider = gr.Slider.update(minimum=1.0,
                                        maximum=fps,
                                        step=1.0,
                                        value=fps,)
        frame_per = gr.Slider.update(minimum=0.0,
                                     maximum=total_frames / fps,
                                     step=1.0 / fps,
                                     value=0.0,)
        return (scale_slider, frame_per, result.get("input_first_frame"), result.get("drawing_board"),
                result.get("output_video"), result.get("output_mp4"), result.get("output_mask"))

    def handle_get_meta_from_video(session_id, input_video, scale_slider, checkpoint):
        result = request_worker(session_id, {"command": "get_meta_from_video", "input_video": input_video, "scale_slider": scale_slider, "checkpoint": checkpoint})
        (fps, frame_interval, total_frames) = result.get("frame_per")
        if fps and fps > 0:
            frame_per = gr.Slider.update(minimum=0.0,
                                         maximum=total_frames / fps,
                                         step=frame_interval / fps,
                                         value=0.0,)
        else:
            frame_per = gr.Slider.update(value=0.0)
        return (result.get("input_first_frame"), result.get("drawing_board"), frame_per,
                result.get("output_video"), result.get("output_mp4"), result.get("output_mask"),
                result.get("ann_obj_id"))

    def handle_sam_stroke(session_id, drawing_board, last_draw, frame_num, ann_obj_id):
        result = request_worker(session_id, {"command": "sam_stroke", "drawing_board": drawing_board, "last_draw": last_draw, "frame_num": frame_num, "ann_obj_id": ann_obj_id})
        return result.get("input_first_frame"), result.get("drawing_board"), result.get("last_draw")

    def handle_sam_click(session_id, frame_num, point_mode, click_stack, ann_obj_id, evt: gr.SelectData):
        point = np.array([[evt.index[0], evt.index[1]]], dtype=np.float32)
        result = request_worker(session_id, {"command": "sam_click", "frame_num": frame_num, "point_mode": point_mode, "click_stack": click_stack, "ann_obj_id": ann_obj_id, "point": point})
        # The updated click stack; workers may send it only under "last_draw".
        click_stack = result.get("click_stack", result.get("last_draw"))
        return result.get("input_first_frame"), result.get("drawing_board"), click_stack

    def handle_increment_ann_obj_id(session_id, ann_obj_id):
        result = request_worker(session_id, {"command": "increment_ann_obj_id", "ann_obj_id": ann_obj_id})
        return result.get("ann_obj_id")

    def handle_drawing_board_get_input_first_frame(session_id, input_first_frame):
        result = request_worker(session_id, {"command": "drawing_board_get_input_first_frame", "input_first_frame": input_first_frame})
        return result.get("input_first_frame")

    def handle_reset(session_id):
        result = request_worker(session_id, {"command": "reset"})
        return (result.get("click_stack"), result.get("input_first_frame"), result.get("drawing_board"),
                result.get("frame_per"), result.get("output_video"), result.get("output_mp4"),
                result.get("output_mask"), result.get("ann_obj_id"))

    def handle_show_res_by_slider(session_id, frame_per, click_stack):
        result = request_worker(session_id, {"command": "show_res_by_slider", "frame_per": frame_per, "click_stack": click_stack})
        return result.get("input_first_frame"), result.get("drawing_board"), result.get("frame_num")

    def handle_tracking_objects(session_id, frame_num, input_video):
        result = request_worker(session_id, {"command": "tracking_objects", "frame_num": frame_num, "input_video": input_video})
        return (result.get("input_first_frame"), result.get("drawing_board"), result.get("output_video"),
                result.get("output_mp4"), result.get("output_mask"))

    css = """
#input_output_video video {
max-height: 550px;
max-width: 100%;
height: auto;
}
"""
    app = gr.Blocks(css=css)
    with app:
        session_id = gr.State()
        app.load(extract_session_id_from_request, None, session_id)
        gr.Markdown(
            '''
<div style="text-align:center; margin-bottom:20px;">
<span style="font-size:3em; font-weight:bold;">MedSAM2 for Video Segmentation 🔥</span>
</div>
<div style="text-align:center; margin-bottom:10px;">
<span style="font-size:1.5em; font-weight:bold;">MedSAM2-Segment Anything in Medical Images and Videos: Benchmark and Deployment</span>
</div>
<div style="text-align:center; margin-bottom:20px;">
<a href="https://github.com/bowang-lab/MedSAM/tree/MedSAM2">
<img src="https://badges.aleen42.com/src/github.svg" alt="GitHub" style="display:inline-block; margin-right:10px;">
</a>
<a href="https://arxiv.org/abs/2408.03322">
<img src="https://img.shields.io/badge/arXiv-2408.03322-green?style=plastic" alt="Paper" style="display:inline-block; margin-right:10px;">
</a>
<a href="https://github.com/bowang-lab/MedSAMSlicer/tree/SAM2">
<img src="https://img.shields.io/badge/3D-Slicer-Plugin" alt="3D Slicer Plugin" style="display:inline-block; margin-right:10px;">
</a>
<a href="https://drive.google.com/drive/folders/1EXzRkxZmrXbahCFA8_ImFRM6wQDEpOSe?usp=sharing">
<img src="https://img.shields.io/badge/Video-Tutorial-green?style=plastic" alt="Video Tutorial" style="display:inline-block; margin-right:10px;">
</a>
<a href="https://github.com/bowang-lab/MedSAM/tree/MedSAM2?tab=readme-ov-file#fine-tune-sam2-on-the-abdomen-ct-dataset">
<img src="https://img.shields.io/badge/Fine--tune-SAM2-blue" alt="Fine-tune SAM2" style="display:inline-block; margin-right:10px;">
</a>
</div>
<div style="text-align:left; margin-bottom:20px;">
This API supports using box (generated by scribble) and point prompts for video segmentation with
<a href="https://ai.meta.com/sam2/" target="_blank">SAM2</a>. Welcome to join our <a href="https://forms.gle/hk4Efp6uWnhjUHFP6" target="_blank">mailing list</a> to get updates or send feedback.
</div>
<div style="margin-bottom:20px;">
<ol style="list-style:none; padding-left:0;">
<li>1. Upload video file</li>
<li>2. Select model size and downsample frame rate and run <b>Preprocess</b></li>
<li>3. Use <b>Stroke to Box Prompt</b> to draw box on the first frame or <b>Point Prompt</b> to click on the first frame.</li>
<li> Note: The bounding rectangle of the stroke should be able to cover the segmentation target.</li>
<li>4. Click <b>Segment</b> to get the segmentation result</li>
<li>5. Click <b>Add New Object</b> to add new object</li>
<li>6. Click <b>Start Tracking</b> to track objects in the video</li>
<li>7. Click <b>Reset</b> to reset the app</li>
<li>8. Download the video with segmentation results</li>
</ol>
</div>
<div style="text-align:left; line-height:1.8;">
We designed this API and <a href="https://github.com/bowang-lab/MedSAMSlicer/tree/SAM2" target="_blank">3D Slicer Plugin</a> for medical image and video segmentation where the checkpoints are based on the original SAM2 models (<a href="https://github.com/facebookresearch/segment-anything-2" target="_blank">https://github.com/facebookresearch/segment-anything-2</a>). The image segmentation fine-tune code has been released on <a href="https://github.com/bowang-lab/MedSAM/tree/MedSAM2?tab=readme-ov-file#fine-tune-sam2-on-the-abdomen-ct-dataset" target="_blank">GitHub</a>. The video fine-tuning code is under active development and will be released as well.
</div>
<div style="text-align:left; line-height:1.8;">
If you find these tools useful, please consider citing the following papers:
</div>
<div style="text-align:left; line-height:1.8;">
Ravi, N., Gabeur, V., Hu, Y.T., Hu, R., Ryali, C., Ma, T., Khedr, H., Rädle, R., Rolland, C., Gustafson, L., Mintun, E., Pan, J., Alwala, K.V., Carion, N., Wu, C.Y., Girshick, R., Dollár, P., Feichtenhofer, C.: SAM 2: Segment Anything in Images and Videos. arXiv:2408.00714 (2024)
</div>
<div style="text-align:left; line-height:1.8;">
Ma, J., Kim, S., Li, F., Baharoon, M., Asakereh, R., Lyu, H., Wang, B.: Segment Anything in Medical Images and Videos: Benchmark and Deployment. arXiv preprint arXiv:2408.03322 (2024)
</div>
<div style="text-align:left; line-height:1.8;">
Other useful resources:
<a href="https://ai.meta.com/sam2" target="_blank">Official demo</a> from MetaAI,
<a href="https://www.youtube.com/watch?v=Dv003fTyO-Y" target="_blank">Video tutorial</a> from Piotr Skalski.
</div>
'''
        )
        click_stack = gr.State(({}, {}))
        frame_num = gr.State(value=0)
        ann_obj_id = gr.State(value=0)
        last_draw = gr.State(None)
        with gr.Row():
            with gr.Column(scale=0.5):
                with gr.Row():
                    tab_video_input = gr.Tab(label="Video input")
                    with tab_video_input:
                        input_video = gr.Video(label='Input video', type=["mp4", "mov", "avi"], elem_id="input_output_video")
                        with gr.Row():
                            checkpoint = gr.Dropdown(label="Model Size", choices=["tiny", "small", "base-plus", "large"], value="tiny")
                            scale_slider = gr.Slider(
                                label="Downsample Frame Rate (fps)",
                                minimum=0.0,
                                maximum=1.0,
                                step=0.25,
                                value=1.0,
                                interactive=True
                            )
                            preprocess_button = gr.Button(
                                value="Preprocess",
                                interactive=True,
                            )
                with gr.Row():
                    tab_stroke = gr.Tab(label="Stroke to Box Prompt")
                    with tab_stroke:
                        drawing_board = gr.Image(label='Drawing Board', tool="sketch", brush_radius=10, interactive=True)
                        with gr.Row():
                            seg_acc_stroke = gr.Button(value="Segment", interactive=True)
                    tab_click = gr.Tab(label="Point Prompt")
                    with tab_click:
                        input_first_frame = gr.Image(label='Segment result of first frame', interactive=True).style(height=550)
                        with gr.Row():
                            point_mode = gr.Radio(
                                choices=["Positive", "Negative"],
                                value="Positive",
                                label="Point Prompt",
                                interactive=True)
                with gr.Row():
                    with gr.Column():
                        frame_per = gr.Slider(
                            label="Time (seconds)",
                            minimum=0.0,
                            maximum=100.0,
                            step=0.01,
                            value=0.0,
                        )
                        new_object_button = gr.Button(
                            value="Add New Object",
                            interactive=True
                        )
                        track_for_video = gr.Button(
                            value="Start Tracking",
                            interactive=True,
                        )
                        reset_button = gr.Button(
                            value="Reset",
                            interactive=True,
                        )
            with gr.Column(scale=0.5):
                output_video = gr.Video(label='Visualize Results', elem_id="input_output_video")
                output_mp4 = gr.File(label="Predicted video")
                output_mask = gr.File(label="Predicted masks")
        with gr.Tab(label='Video examples'):
            gr.Examples(
                label="",
                examples=[
                    "assets/12fps_Dancing_cells_trimmed.mp4",
                    "assets/clip_012251_fps5_07_25.mp4",
                    "assets/FLARE22_Tr_0004.mp4",
                    "assets/c_elegans_mov_cut_fps12.mp4",
                ],
                inputs=[input_video],
            )
            gr.Examples(
                label="",
                examples=[
                    "assets/12fps_volvox_microcystis_play_trimmed.mp4",
                    "assets/12fps_macrophages_phagocytosis.mp4",
                    "assets/12fps_worm_eats_organism_5.mp4",
                    "assets/12fps_worm_eats_organism_6.mp4",
                    "assets/12fps_02_cups.mp4",
                ],
                inputs=[input_video],
            )
        gr.Markdown(
            '''
<div style="text-align:center; margin-top: 20px;">
The authors of this work highly appreciate Meta AI for making SAM2 publicly available to the community.
The interface was built on <a href="https://github.com/z-x-yang/Segment-and-Track-Anything/blob/main/tutorial/tutorial%20for%20WebUI-1.0-Version.md" target="_blank">SegTracker</a>, which is also an amazing tool for video segmentation tracking.
<a href="https://docs.google.com/document/d/1idDBV0faOjdjVs-iAHr0uSrw_9_ZzLGrUI2FEdK-lso/edit?usp=sharing" target="_blank">Data source</a>
</div>
'''
        )
        preprocess_button.click(
            fn=handle_get_meta_from_video,
            inputs=[
                session_id,
                input_video,
                scale_slider,
                checkpoint
            ],
            outputs=[
                input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id
            ]
        )
        frame_per.release(
            fn=handle_show_res_by_slider,
            inputs=[
                session_id, frame_per, click_stack
            ],
            outputs=[
                input_first_frame, drawing_board, frame_num
            ]
        )
        input_first_frame.select(
            fn=handle_sam_click,
            inputs=[
                session_id, frame_num, point_mode, click_stack, ann_obj_id
            ],
            outputs=[
                input_first_frame, drawing_board, click_stack
            ]
        )
        track_for_video.click(
            fn=handle_tracking_objects,
            inputs=[
                session_id,
                frame_num,
                input_video,
            ],
            outputs=[
                input_first_frame,
                drawing_board,
                output_video,
                output_mp4,
                output_mask
            ]
        )
        reset_button.click(
            fn=handle_reset,
            inputs=[session_id],
            outputs=[
                click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id
            ]
        )
        new_object_button.click(
            fn=handle_increment_ann_obj_id,
            inputs=[
                session_id, ann_obj_id
            ],
            outputs=[
                ann_obj_id
            ]
        )
        tab_stroke.select(
            fn=handle_drawing_board_get_input_first_frame,
            inputs=[session_id, input_first_frame],
            outputs=[drawing_board,],
        )
        seg_acc_stroke.click(
            fn=handle_sam_stroke,
            inputs=[
                session_id, drawing_board, last_draw, frame_num, ann_obj_id
            ],
            outputs=[
                input_first_frame, drawing_board, last_draw
            ]
        )
        input_video.change(
            fn=handle_extract_video_info,
            inputs=[session_id, input_video],
            outputs=[scale_slider, frame_per, input_first_frame, drawing_board, output_video, output_mp4, output_mask]
        )
    app.queue(concurrency_count=1)
    app.launch(debug=True, enable_queue=True, share=False)
if __name__ == "__main__":
    # Workers are started with "spawn" rather than fork. NOTE(review): presumably
    # because the workers use CUDA, which cannot be used from a forked child; confirm.
    mp.set_start_method("spawn")
    # Daemon thread: it loops forever and must not keep the interpreter alive.
    monitor_thread = threading.Thread(target=monitor_and_cleanup_processes, daemon=True)
    monitor_thread.start()
    seg_track_app()