import gradio as gr
import torch
import numpy as np
import cv2
from sam2.build_sam import build_sam2_video_predictor
import tempfile
import os


class VideoTracker:
    def __init__(self):
        self.checkpoint = "./models/sam2.1_hiera_tiny.pt"
        self.model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
        self.predictor = build_sam2_video_predictor(
            self.model_cfg, self.checkpoint, device="cpu", mode="eval"
        )
        self.state = None
        self.video_frames = None
        self.current_frame_idx = 0
        self.masks = []
        self.points = []
        self.frame_count = 0
        self.video_info = None
        self.obj_id = 1
        self.out_mask_logits = None
        self.frame_masks = {}  # Store masks for each frame

    def load_video(self, video_path):
        if video_path is None:
            return None, gr.Slider(minimum=0, maximum=0, step=1, value=0)

        # Create a temporary file for the video
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        temp_file.close()

        # Copy the uploaded video to the temporary file
        with open(video_path, "rb") as f_src, open(temp_file.name, "wb") as f_dst:
            f_dst.write(f_src.read())

        # Load video frames using OpenCV
        cap = cv2.VideoCapture(temp_file.name)
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)

        if not frames:
            cap.release()
            os.unlink(temp_file.name)
            return None, gr.Slider(minimum=0, maximum=0, step=1, value=0)

        # Store video info
        self.video_info = {
            "path": temp_file.name,
            "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
            "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            "fps": cap.get(cv2.CAP_PROP_FPS),
            "total_frames": len(frames),
        }
        cap.release()

        self.video_frames = frames
        self.frame_count = len(frames)

        # Reset any annotations left over from a previously loaded video
        self.points = []
        self.frame_masks = {}
        self.out_mask_logits = None

        # Initialize SAM2 state with video path
        with torch.inference_mode():
            self.state = self.predictor.init_state(temp_file.name)

        # Now we can remove the temp file
        os.unlink(temp_file.name)

        return frames[0], gr.Slider(
            minimum=0, maximum=len(frames) - 1, step=1, value=0
        )

    def update_frame(self, frame_number):
        if self.video_frames is None:
            return None

        # The slider may deliver a float; list indexing needs an int
        frame_number = int(frame_number)
        self.current_frame_idx = frame_number
        frame = self.video_frames[frame_number].copy()

        # Apply any existing mask for this frame
        if frame_number in self.frame_masks:
            self.out_mask_logits = self.frame_masks[frame_number]
            frame = self._draw_tracking(frame)

        # Draw points
        for point in self.points:
            if point[0] == frame_number:
                cv2.circle(
                    frame, (int(point[1]), int(point[2])), 5, (255, 255, 0), -1
                )  # Yellow dot
                cv2.circle(
                    frame, (int(point[1]), int(point[2])), 7, (0, 0, 0), 1
                )  # Black border

        return frame

    def add_point(self, frame, evt: gr.SelectData):
        """Add a point and get an object prediction with enhanced mask visualization"""
        if self.state is None:
            return frame

        x, y = evt.index[0], evt.index[1]
        self.points.append((self.current_frame_idx, x, y))

        frame_with_points = frame.copy()

        # Get object prediction using SAM2.1
        with torch.inference_mode():
            # Convert points and labels to numpy arrays
            points = np.array([(x, y)], dtype=np.float32)
            labels = np.array([1], dtype=np.int32)  # 1 for positive click

            # Add point and get mask
            _, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
                inference_state=self.state,
                frame_idx=self.current_frame_idx,
                obj_id=self.obj_id,
                points=points,
                labels=labels,
            )

            if out_mask_logits is not None and len(out_mask_logits) > 0:
                self.out_mask_logits = (
                    out_mask_logits[0]
                    if isinstance(out_mask_logits, list)
                    else out_mask_logits
                )
                # Store mask for this frame
                self.frame_masks[self.current_frame_idx] = self.out_mask_logits

                # Draw tracking visualization with enhanced mask
                frame_with_points = self._draw_tracking(frame_with_points)

        # Draw the clicked point on top of the mask
        cv2.circle(
            frame_with_points, (int(x), int(y)), 5, (255, 255, 0), -1
        )  # Yellow dot
        cv2.circle(frame_with_points, (int(x), int(y)), 7, (0, 0, 0), 1)  # Black border

        return frame_with_points

    def propagate_video(self):
        if self.state is None:
            return None

        output_frames = self.video_frames.copy()

        # Store all masks for smoother visualization
        all_masks = []

        # First pass: collect all masks
        with torch.inference_mode():
            for frame_idx, obj_ids, masks in self.predictor.propagate_in_video(
                self.state,
                start_frame_idx=0,
                reverse=False,
            ):
                if masks is not None and len(masks) > 0:
                    mask = masks[0] if isinstance(masks, list) else masks
                    all_masks.append((frame_idx, mask))
                    # Store mask for each frame
                    self.frame_masks[frame_idx] = mask

        # Second pass: draw the mask visualization on each frame
        for i, frame in enumerate(output_frames):
            frame = frame.copy()

            # Find masks for this frame
            current_masks = [m[1] for m in all_masks if m[0] == i]
            if current_masks:
                self.out_mask_logits = current_masks[0]

                # Get binary mask and ensure correct dimensions
                mask_np = (current_masks[0] > 0.0).cpu().numpy()
                mask_np = self._handle_mask_dimensions(mask_np)

                # Convert to proper format for OpenCV
                mask_np = mask_np.astype(np.uint8)

                # Enhanced visualization for video
                frame = self._draw_tracking(frame, alpha=0.6)

                # Create glowing effect
                try:
                    # Create kernel for dilation
                    kernel = np.ones((5, 5), np.uint8)
                    # Dilate mask for glow effect
                    dilated_mask = cv2.dilate(mask_np, kernel, iterations=2)
                    # Create glow overlay
                    glow = frame.copy()
                    glow[dilated_mask > 0] = [0, 255, 255]  # Yellow glow
                    # Blend glow with frame
                    frame = cv2.addWeighted(frame, 0.7, glow, 0.3, 0)
                except cv2.error:
                    # Continue without the glow effect if the shapes do not line up
                    print(
                        "Warning: Could not apply glow effect. "
                        f"Mask shape: {mask_np.shape}, Frame shape: {frame.shape}"
                    )

            output_frames[i] = frame

        # Save the annotated frames as a video
        temp_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
        height, width = output_frames[0].shape[:2]

        # Write at the source frame rate, falling back to 30 fps if unknown
        fps = self.video_info["fps"] if self.video_info and self.video_info["fps"] else 30
        writer = cv2.VideoWriter(
            temp_output, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height), True
        )
        for frame in output_frames:
            writer.write(frame)
        writer.release()

        return temp_output

    def _handle_mask_dimensions(self, mask_np):
        """Helper function to handle various mask dimensions"""
        # Handle 4D tensor (1, 1, H, W)
        if len(mask_np.shape) == 4:
            mask_np = mask_np[0, 0]
        # Handle 3D tensor (1, H, W) or (H, W, 1)
        elif len(mask_np.shape) == 3:
            if mask_np.shape[0] == 1:  # (1, H, W) format
                mask_np = mask_np[0]
            elif mask_np.shape[2] == 1:  # (H, W, 1) format
                mask_np = mask_np[:, :, 0]
        return mask_np

    def _draw_tracking(self, frame, alpha=0.5):
        """Draw object mask on frame with enhanced visualization"""
        if self.out_mask_logits is not None:
            # Convert logits to binary mask
            if isinstance(self.out_mask_logits, list):
                mask = self.out_mask_logits[0]
            else:
                mask = self.out_mask_logits

            # Get binary mask and handle dimensions
            mask_np = (mask > 0.0).cpu().numpy()
            mask_np = self._handle_mask_dimensions(mask_np)

            if mask_np.shape[:2] == frame.shape[:2]:
                # Create a red overlay for the mask
                overlay = frame.copy()
                overlay[mask_np > 0] = [0, 0, 255]  # BGR format: Red color

                # Add a border around the mask for better visibility
                contours, _ = cv2.findContours(
                    mask_np.astype(np.uint8),
                    cv2.RETR_EXTERNAL,
                    cv2.CHAIN_APPROX_SIMPLE,
                )

                # Draw thicker contours for better visibility
                cv2.drawContours(
                    overlay, contours, -1, (0, 255, 255), 3
                )  # Thicker yellow border
                # Add a second contour for emphasis
                cv2.drawContours(
                    frame, contours, -1, (255, 255, 0), 1
                )  # Thin bright border

                # Blend the overlay with original frame
                frame = cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0)

        return frame


def create_interface():
    tracker = VideoTracker()

    with gr.Blocks() as interface:
        gr.Markdown("# Object Tracking with SAM2")
        gr.Markdown("Upload a video and click on objects to track them")

        with gr.Row():
            with gr.Column(scale=2):
                video_input = gr.Video(label="Input Video")
                image_output = gr.Image(label="Current Frame", interactive=True)
                frame_slider = gr.Slider(
                    minimum=0,
                    maximum=0,
                    step=1,
                    value=0,
                    label="Frame Selection",
                    interactive=True,
                )

            with gr.Column(scale=1):
                propagate_btn = gr.Button("Propagate Through Video", variant="primary")
                video_output = gr.Video(label="Output Video")

        video_input.change(
            fn=tracker.load_video,
            inputs=[video_input],
            outputs=[image_output, frame_slider],
        )

        frame_slider.change(
            fn=tracker.update_frame,
            inputs=[frame_slider],
            outputs=[image_output],
        )

        image_output.select(
            fn=tracker.add_point,
            inputs=[image_output],
            outputs=[image_output],
        )

        propagate_btn.click(
            fn=tracker.propagate_video,
            inputs=[],
            outputs=[video_output],
        )

    return interface


if __name__ == "__main__":
    interface = create_interface()
    interface.launch(share=True)