# golf_tracking/working.py
import gradio as gr
import torch
import numpy as np
import cv2
from sam2.build_sam import build_sam2_video_predictor
import tempfile
import os
class VideoTracker:
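    """Interactive SAM2-based video object tracker that backs the Gradio UI below."""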
def __init__(self):
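        # The SAM2.1 "tiny" checkpoint is assumed to be downloaded to ./models; inference runs on CPU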
self.checkpoint = "./models/sam2.1_hiera_tiny.pt"
self.model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
self.predictor = build_sam2_video_predictor(
self.model_cfg, self.checkpoint, device="cpu", mode="eval"
)
self.state = None
self.video_frames = None
self.current_frame_idx = 0
self.masks = []
self.points = []
self.frame_count = 0
self.video_info = None
self.obj_id = 1
self.out_mask_logits = None
def load_video(self, video_path):
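        """Read every frame of the uploaded video, store its metadata, initialize the
        SAM2 inference state, and return the first frame plus a slider update."""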
if video_path is None:
return None, gr.Slider(minimum=0, maximum=0, step=1, value=0)
# Create a temporary file for the video
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
temp_file.close()
# Copy the uploaded video to the temporary file
with open(video_path, "rb") as f_src, open(temp_file.name, "wb") as f_dst:
f_dst.write(f_src.read())
# Load video frames using OpenCV
cap = cv2.VideoCapture(temp_file.name)
frames = []
while True:
ret, frame = cap.read()
if not ret:
break
frames.append(frame)
if not frames:
cap.release()
os.unlink(temp_file.name)
return None, gr.Slider(minimum=0, maximum=0, step=1, value=0)
# Store video info
self.video_info = {
"path": temp_file.name,
"height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
"width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
"fps": cap.get(cv2.CAP_PROP_FPS),
"total_frames": len(frames),
}
cap.release()
self.video_frames = frames
self.frame_count = len(frames)
# Initialize SAM2 state with video path
with torch.inference_mode():
self.state = self.predictor.init_state(temp_file.name)
        # init_state has already loaded the frames, so the temp file can be removed
        os.unlink(temp_file.name)
return frames[0], gr.Slider(minimum=0, maximum=len(frames) - 1, step=1, value=0)
def update_frame(self, frame_number):
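        """Return the selected frame with the current mask and any clicked points drawn on it."""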
if self.video_frames is None:
return None
        self.current_frame_idx = int(frame_number)
        frame = self.video_frames[self.current_frame_idx].copy()
# Draw existing points and masks
if self.out_mask_logits is not None:
frame = self._draw_tracking(frame)
# Draw points
for point in self.points:
if point[0] == frame_number:
cv2.circle(frame, (int(point[1]), int(point[2])), 5, (255, 0, 0), -1)
return frame
def add_point(self, frame, evt: gr.SelectData):
"""Add a point and get ball prediction with enhanced mask visualization"""
if self.state is None:
return frame
x, y = evt.index[0], evt.index[1]
self.points.append((self.current_frame_idx, x, y))
frame_with_points = frame.copy()
# Get ball prediction using SAM2.1
with torch.inference_mode():
# Convert points and labels to numpy arrays
points = np.array([(x, y)], dtype=np.float32)
labels = np.array([1], dtype=np.int32) # 1 for positive click
# Add point and get mask
_, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
inference_state=self.state,
frame_idx=self.current_frame_idx,
obj_id=self.obj_id,
points=points,
labels=labels,
)
if out_mask_logits is not None and len(out_mask_logits) > 0:
self.out_mask_logits = (
out_mask_logits[0]
if isinstance(out_mask_logits, list)
else out_mask_logits
)
# Draw tracking visualization with enhanced mask
frame_with_points = self._draw_tracking(frame_with_points)
# Draw point on top of mask
cv2.circle(
frame_with_points, (int(x), int(y)), 5, (255, 255, 0), -1
) # Yellow dot
cv2.circle(frame_with_points, (int(x), int(y)), 7, (0, 0, 0), 1) # Black border
return frame_with_points
def propagate_video(self):
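        """Propagate the selected object through the whole video and write an annotated MP4."""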
if self.state is None:
return None
output_frames = self.video_frames.copy()
# Store all masks for smoother visualization
all_masks = []
# First pass: collect all masks
with torch.inference_mode():
for frame_idx, obj_ids, masks in self.predictor.propagate_in_video(
self.state,
start_frame_idx=0,
reverse=False,
):
if masks is not None and len(masks) > 0:
mask = masks[0] if isinstance(masks, list) else masks
all_masks.append((frame_idx, mask))
        # Second pass: draw the mask overlay and glow effect on each frame
for i, frame in enumerate(output_frames):
frame = frame.copy()
# Find masks for this frame
current_masks = [m[1] for m in all_masks if m[0] == i]
if current_masks:
self.out_mask_logits = current_masks[0]
# Get binary mask and ensure correct dimensions
mask_np = (current_masks[0] > 0.0).cpu().numpy()
mask_np = self._handle_mask_dimensions(mask_np)
# Convert to proper format for OpenCV
mask_np = mask_np.astype(np.uint8)
# Enhanced visualization for video
frame = self._draw_tracking(frame, alpha=0.6)
# Create glowing effect
try:
# Create kernel for dilation
kernel = np.ones((5, 5), np.uint8)
# Dilate mask for glow effect
dilated_mask = cv2.dilate(mask_np, kernel, iterations=2)
# Create glow overlay
glow = frame.copy()
glow[dilated_mask > 0] = [0, 255, 255] # Yellow glow
# Blend glow with frame
frame = cv2.addWeighted(frame, 0.7, glow, 0.3, 0)
                except cv2.error as e:
                    # Skip the glow effect if the mask and frame shapes are incompatible
                    print(
                        f"Warning: could not apply glow effect ({e}). "
                        f"Mask shape: {mask_np.shape}, frame shape: {frame.shape}"
                    )
output_frames[i] = frame
        # Save the annotated frames as an MP4, reusing the source frame rate when known
        temp_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
        height, width = output_frames[0].shape[:2]
        fps = self.video_info["fps"] or 30  # fall back to 30 fps if OpenCV reports 0
        writer = cv2.VideoWriter(
            temp_output, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height), True
        )
for frame in output_frames:
writer.write(frame)
writer.release()
return temp_output
def _handle_mask_dimensions(self, mask_np):
"""Helper function to handle various mask dimensions"""
# Handle 4D tensor (1, 1, H, W)
if len(mask_np.shape) == 4:
mask_np = mask_np[0, 0]
# Handle 3D tensor (1, H, W) or (H, W, 1)
elif len(mask_np.shape) == 3:
if mask_np.shape[0] == 1: # (1, H, W) format
mask_np = mask_np[0]
elif mask_np.shape[2] == 1: # (H, W, 1) format
mask_np = mask_np[:, :, 0]
return mask_np
def _draw_tracking(self, frame, alpha=0.5):
"""Draw object mask on frame with enhanced visualization"""
if self.out_mask_logits is not None:
# Convert logits to binary mask
if isinstance(self.out_mask_logits, list):
mask = self.out_mask_logits[0]
else:
mask = self.out_mask_logits
# Get binary mask and handle dimensions
mask_np = (mask > 0.0).cpu().numpy()
mask_np = self._handle_mask_dimensions(mask_np)
if mask_np.shape[:2] == frame.shape[:2]:
# Create a red overlay for the mask
overlay = frame.copy()
overlay[mask_np > 0] = [0, 0, 255] # BGR format: Red color
# Add a border around the mask for better visibility
contours, _ = cv2.findContours(
mask_np.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)
# Draw thicker contours for better visibility
cv2.drawContours(
overlay, contours, -1, (0, 255, 255), 3
) # Thicker yellow border
# Add a second contour for emphasis
cv2.drawContours(
frame, contours, -1, (255, 255, 0), 1
) # Thin bright border
# Blend the overlay with original frame
frame = cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0)
return frame
def create_interface():
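    """Build the Gradio Blocks UI and wire it to a VideoTracker instance."""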
tracker = VideoTracker()
with gr.Blocks() as interface:
gr.Markdown("# Object Tracking with SAM2")
gr.Markdown("Upload a video and click on objects to track them")
with gr.Row():
with gr.Column(scale=2):
video_input = gr.Video(label="Input Video")
image_output = gr.Image(label="Current Frame", interactive=True)
frame_slider = gr.Slider(
minimum=0,
maximum=0,
step=1,
value=0,
label="Frame Selection",
interactive=True,
)
with gr.Column(scale=1):
propagate_btn = gr.Button("Propagate Through Video", variant="primary")
video_output = gr.Video(label="Output Video")
video_input.change(
fn=tracker.load_video,
inputs=[video_input],
outputs=[image_output, frame_slider],
)
frame_slider.change(
fn=tracker.update_frame,
inputs=[frame_slider],
outputs=[image_output],
)
image_output.select(
fn=tracker.add_point,
inputs=[image_output],
outputs=[image_output],
)
propagate_btn.click(
fn=tracker.propagate_video,
inputs=[],
outputs=[video_output],
)
return interface
if __name__ == "__main__":
interface = create_interface()
interface.launch(share=True)
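# A minimal headless usage sketch (no Gradio UI), assuming a local "clip.mp4" exists and
# that a click at pixel (320, 240) on frame 0 lands on the ball to track:
#
#     tracker = VideoTracker()
#     first_frame, _ = tracker.load_video("clip.mp4")
#
#     class _Click:  # stand-in for the gr.SelectData event Gradio would pass
#         index = (320, 240)
#
#     annotated = tracker.add_point(first_frame, _Click())
#     output_path = tracker.propagate_video()
#     print("Annotated video written to", output_path)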