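"""Gradio demo: interactive video object tracking with SAM2.

Upload a video, click on an object in any frame to get a segmentation
mask, then propagate that mask through the rest of the video.
"""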
import gradio as gr
import torch
import numpy as np
import cv2
from sam2.build_sam import build_sam2_video_predictor
import tempfile
import os


class VideoTracker:
    def __init__(self):
        self.checkpoint = "./models/sam2.1_hiera_tiny.pt"
        self.model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
        self.predictor = build_sam2_video_predictor(
            self.model_cfg, self.checkpoint, device="cpu", mode="eval"
        )
        self.state = None
        self.video_frames = None
        self.current_frame_idx = 0
        self.points = []
        self.frame_count = 0
        self.video_info = None
        self.obj_id = 1
        self.out_mask_logits = None
        self.frame_masks = {}  # store the latest mask for each frame index

    def load_video(self, video_path):
        if video_path is None:
            return None, gr.Slider(minimum=0, maximum=0, step=1, value=0)
        # Copy the uploaded video to a temporary file so SAM2 can read it by path
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        temp_file.close()
        with open(video_path, "rb") as f_src, open(temp_file.name, "wb") as f_dst:
            f_dst.write(f_src.read())
        # Load video frames with OpenCV; convert BGR -> RGB so frames match
        # what Gradio's Image component displays
        cap = cv2.VideoCapture(temp_file.name)
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if not frames:
            cap.release()
            os.unlink(temp_file.name)
            return None, gr.Slider(minimum=0, maximum=0, step=1, value=0)
        # Store video info (fall back to 30 fps if the container reports none)
        self.video_info = {
            "path": temp_file.name,
            "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
            "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            "fps": cap.get(cv2.CAP_PROP_FPS) or 30.0,
            "total_frames": len(frames),
        }
        cap.release()
        self.video_frames = frames
        self.frame_count = len(frames)
        # Initialize SAM2 state from the video path; init_state loads the
        # frames into memory, so the temp file can be removed afterwards
        with torch.inference_mode():
            self.state = self.predictor.init_state(temp_file.name)
        os.unlink(temp_file.name)
        return frames[0], gr.Slider(minimum=0, maximum=len(frames) - 1, step=1, value=0)

    def update_frame(self, frame_number):
        if self.video_frames is None:
            return None
        frame_number = int(frame_number)  # slider values may arrive as floats
        self.current_frame_idx = frame_number
        frame = self.video_frames[frame_number].copy()
        # Re-apply any existing mask for this frame
        if frame_number in self.frame_masks:
            self.out_mask_logits = self.frame_masks[frame_number]
            frame = self._draw_tracking(frame)
        # Draw the click points recorded for this frame
        for point in self.points:
            if point[0] == frame_number:
                cv2.circle(frame, (int(point[1]), int(point[2])), 5, (255, 255, 0), -1)  # yellow dot
                cv2.circle(frame, (int(point[1]), int(point[2])), 7, (0, 0, 0), 1)  # black border
        return frame

    def add_point(self, frame, evt: gr.SelectData):
        """Add a positive click and predict an object mask with SAM2"""
        if self.state is None:
            return frame
        x, y = evt.index[0], evt.index[1]
        self.points.append((self.current_frame_idx, x, y))
        frame_with_points = frame.copy()
        # Run SAM2 on the clicked point
        with torch.inference_mode():
            points = np.array([(x, y)], dtype=np.float32)
            labels = np.array([1], dtype=np.int32)  # 1 = positive click
            _, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
                inference_state=self.state,
                frame_idx=self.current_frame_idx,
                obj_id=self.obj_id,
                points=points,
                labels=labels,
            )
        if out_mask_logits is not None and len(out_mask_logits) > 0:
            self.out_mask_logits = (
                out_mask_logits[0]
                if isinstance(out_mask_logits, list)
                else out_mask_logits
            )
            # Cache the mask so it is re-drawn when the slider revisits this frame
            self.frame_masks[self.current_frame_idx] = self.out_mask_logits
            # Draw the mask overlay
            frame_with_points = self._draw_tracking(frame_with_points)
        # Draw the click point on top of the mask
        cv2.circle(frame_with_points, (int(x), int(y)), 5, (255, 255, 0), -1)  # yellow dot
        cv2.circle(frame_with_points, (int(x), int(y)), 7, (0, 0, 0), 1)  # black border
        return frame_with_points
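
    # SAM2's propagate_in_video() is a generator that yields
    # (frame_idx, obj_ids, mask_logits) for each frame in order, carrying the
    # clicked object's mask forward through the whole video.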

    def propagate_video(self):
        if self.state is None:
            return None
        output_frames = self.video_frames.copy()
        # First pass: collect the predicted mask for every frame
        all_masks = []
        with torch.inference_mode():
            for frame_idx, obj_ids, masks in self.predictor.propagate_in_video(
                self.state,
                start_frame_idx=0,
                reverse=False,
            ):
                if masks is not None and len(masks) > 0:
                    mask = masks[0] if isinstance(masks, list) else masks
                    all_masks.append((frame_idx, mask))
                    self.frame_masks[frame_idx] = mask
        # Second pass: render the mask overlay onto each frame
        for i, frame in enumerate(output_frames):
            frame = frame.copy()
            current_masks = [m[1] for m in all_masks if m[0] == i]
            if current_masks:
                self.out_mask_logits = current_masks[0]
                # Binary mask in the right shape and dtype for OpenCV
                mask_np = (current_masks[0] > 0.0).cpu().numpy()
                mask_np = self._handle_mask_dimensions(mask_np)
                mask_np = mask_np.astype(np.uint8)
                frame = self._draw_tracking(frame, alpha=0.6)
                # Glow effect: dilate the mask and blend a highlight over it
                try:
                    kernel = np.ones((5, 5), np.uint8)
                    dilated_mask = cv2.dilate(mask_np, kernel, iterations=2)
                    glow = frame.copy()
                    glow[dilated_mask > 0] = [255, 255, 0]  # yellow glow (RGB)
                    frame = cv2.addWeighted(frame, 0.7, glow, 0.3, 0)
                except cv2.error as e:
                    # Continue without the glow effect on shape mismatches
                    print(
                        f"Warning: could not apply glow effect ({e}). "
                        f"Mask shape: {mask_np.shape}, frame shape: {frame.shape}"
                    )
            output_frames[i] = frame
        # Write the result; frames are RGB, so convert back to BGR for OpenCV
        temp_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
        height, width = output_frames[0].shape[:2]
        fps = self.video_info["fps"] if self.video_info else 30.0
        writer = cv2.VideoWriter(
            temp_output, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height), True
        )
        for frame in output_frames:
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        writer.release()
        return temp_output
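
    # Depending on how SAM2 returns logits, a single-object mask can arrive as
    # (1, 1, H, W), (1, H, W), (H, W, 1), or plain (H, W); normalize to (H, W).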

    def _handle_mask_dimensions(self, mask_np):
        """Squeeze a mask array of varying rank down to (H, W)"""
        if len(mask_np.shape) == 4:  # (1, 1, H, W)
            mask_np = mask_np[0, 0]
        elif len(mask_np.shape) == 3:
            if mask_np.shape[0] == 1:  # (1, H, W)
                mask_np = mask_np[0]
            elif mask_np.shape[2] == 1:  # (H, W, 1)
                mask_np = mask_np[:, :, 0]
        return mask_np

    def _draw_tracking(self, frame, alpha=0.5):
        """Draw the current object mask on a frame with a highlighted border"""
        if self.out_mask_logits is not None:
            mask = (
                self.out_mask_logits[0]
                if isinstance(self.out_mask_logits, list)
                else self.out_mask_logits
            )
            # Threshold logits into a binary mask and normalize its dimensions
            mask_np = (mask > 0.0).cpu().numpy()
            mask_np = self._handle_mask_dimensions(mask_np)
            if mask_np.shape[:2] == frame.shape[:2]:
                # Red overlay for the mask interior (frames are RGB)
                overlay = frame.copy()
                overlay[mask_np > 0] = [255, 0, 0]
                # Outline the mask for better visibility
                contours, _ = cv2.findContours(
                    mask_np.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
                )
                cv2.drawContours(overlay, contours, -1, (255, 255, 0), 3)  # thick yellow border
                cv2.drawContours(frame, contours, -1, (0, 255, 255), 1)  # thin cyan border
                # Blend the overlay with the original frame
                frame = cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0)
        return frame
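

# UI wiring: uploading a video calls load_video, moving the slider re-renders
# the selected frame, clicking the frame adds a SAM2 point prompt, and the
# button renders the propagated masks into a new output video.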
def create_interface():
    tracker = VideoTracker()
    with gr.Blocks() as interface:
        gr.Markdown("# Object Tracking with SAM2")
        gr.Markdown("Upload a video and click on objects to track them")
        with gr.Row():
            with gr.Column(scale=2):
                video_input = gr.Video(label="Input Video")
                image_output = gr.Image(label="Current Frame", interactive=True)
                frame_slider = gr.Slider(
                    minimum=0,
                    maximum=0,
                    step=1,
                    value=0,
                    label="Frame Selection",
                    interactive=True,
                )
            with gr.Column(scale=1):
                propagate_btn = gr.Button("Propagate Through Video", variant="primary")
                video_output = gr.Video(label="Output Video")
        video_input.change(
            fn=tracker.load_video,
            inputs=[video_input],
            outputs=[image_output, frame_slider],
        )
        frame_slider.change(
            fn=tracker.update_frame,
            inputs=[frame_slider],
            outputs=[image_output],
        )
        image_output.select(
            fn=tracker.add_point,
            inputs=[image_output],
            outputs=[image_output],
        )
        propagate_btn.click(
            fn=tracker.propagate_video,
            inputs=[],
            outputs=[video_output],
        )
    return interface


if __name__ == "__main__":
    interface = create_interface()
    interface.launch(share=True)