# golf_tracking / app.py
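"""Gradio app for tracking a golf ball through a video with SAM 2.

Upload a clip, scrub to a frame, click the ball to seed the tracker,
propagate masks across the whole video, and export an annotated MP4.
"""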
import contextlib
import logging
import os
import sys

import cv2
import gradio as gr
import numpy as np
import torch

# Make a local sam2 checkout importable before importing from it
sys.path.append(os.getcwd())
sys.path.append(os.path.join(os.getcwd(), "sam2"))

from sam2.build_sam import build_sam2_video_predictor

print(f"current dir is {os.getcwd()}")
# Ensure device setup matches the official code
force_cpu_device = os.environ.get("SAM2_DEMO_FORCE_CPU_DEVICE", "0") == "1"
if force_cpu_device:
    logging.info("forcing CPU device for SAM 2 demo")
if torch.cuda.is_available() and not force_cpu_device:
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available() and not force_cpu_device:
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
logging.info(f"using device: {DEVICE}")

if DEVICE.type == "cuda":
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
elif DEVICE.type == "mps":
    logging.warning(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )
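# Tip: set SAM2_DEMO_FORCE_CPU_DEVICE=1 in the environment (read above) to force
# CPU inference, e.g. `SAM2_DEMO_FORCE_CPU_DEVICE=1 python app.py`.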
def load_model_paths(checkpoint_name):
    """Get model checkpoint and config paths"""
    if checkpoint_name == "SAM2-T":
        sam2_checkpoint = "models/sam2.1_hiera_tiny.pt"
        model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
    elif checkpoint_name == "SAM2-S":
        sam2_checkpoint = "models/sam2.1_hiera_small.pt"
        model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
    elif checkpoint_name == "SAM2-B_PLUS":
        sam2_checkpoint = "models/sam2.1_hiera_base_plus.pt"
        model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
    else:
        raise ValueError(f"Invalid checkpoint name: {checkpoint_name}")
    return sam2_checkpoint, model_cfg
# Available checkpoints
CHECKPOINTS = {
    "SAM2-B_PLUS": "Base Plus Model",
    "SAM2-S": "Small Model",
    "SAM2-T": "Tiny Model",
}
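# NOTE: the "models/" checkpoints are not bundled here. Assuming the standard
# SAM 2.1 release layout, they can be fetched with the official download script
# (checkpoints/download_ckpts.sh in the sam2 repo) or directly, e.g.:
#   wget -P models https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt
# (URL pattern assumed from the SAM 2 repo; verify against the current release.)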
class GolfTracker:
    def __init__(self, checkpoint="SAM2-T"):
        """Initialize with specified checkpoint model"""
        self.current_checkpoint = checkpoint
        self.predictor = None
        self.points = []
        self.frames = []
        self.current_frame_idx = 0
        self.video_info = None
        self.state = None
        self.obj_id = 1  # Track single object (golf ball)
        self.device = DEVICE
        # Mask logits per frame: {frame_idx: (1, H, W) tensor for the tracked object}
        self.out_mask_logits = {}
        self.load_model(checkpoint)
    def load_model(self, checkpoint_name):
        """Load specified checkpoint model"""
        if checkpoint_name not in CHECKPOINTS:
            raise ValueError(f"Invalid checkpoint: {checkpoint_name}")
        print(f"Loading checkpoint: {checkpoint_name}")
        sam2_checkpoint, model_cfg = load_model_paths(checkpoint_name)
        # Build predictor with model config and checkpoint
        self.predictor = build_sam2_video_predictor(
            model_cfg, sam2_checkpoint, self.device
        )
        print(f"Model loaded successfully: {CHECKPOINTS[checkpoint_name]}")
        self.current_checkpoint = checkpoint_name
    def process_video(self, video_path):
        """Process the video and initialize tracking"""
        if not video_path or not os.path.exists(video_path):
            return None, None, None, "Video file not found"
        # Reset state
        self.points = []
        self.frames = []
        self.current_frame_idx = 0
        self.state = None
        self.out_mask_logits = {}
        # Read video frames (convert BGR -> RGB so Gradio displays them correctly)
        cap = cv2.VideoCapture(video_path)
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            self.frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if not self.frames:
            cap.release()
            return None, None, None, "Failed to read video"
        # Store video info
        self.video_info = {
            "path": video_path,
            "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
            "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            "fps": cap.get(cv2.CAP_PROP_FPS),
            "total_frames": len(self.frames),
        }
        cap.release()
        # Initialize SAM2 state
        with self.autocast_context(), torch.inference_mode():
            self.state = self.predictor.init_state(video_path)
        return (
            self.frames[0],  # First frame
            self.current_checkpoint,
            gr.Slider(minimum=0, maximum=len(self.frames) - 1, step=1, value=0),
            "Navigate through frames and click on the golf ball to track",
        )
    def update_frame(self, frame_idx):
        """Update displayed frame"""
        if not self.frames or frame_idx >= len(self.frames):
            return None
        self.current_frame_idx = int(frame_idx)
        frame = self.frames[self.current_frame_idx].copy()
        # Overlay the existing mask, if any (the helper returns a new frame)
        return self._draw_tracking(frame)
    def add_point(self, frame, evt: gr.SelectData):
        """Add a point and get ball prediction with enhanced mask visualization"""
        if self.state is None:
            return frame
        x, y = evt.index[0], evt.index[1]
        self.points.append((self.current_frame_idx, x, y))
        frame_with_points = frame.copy()
        # Get ball prediction using SAM2.1
        with self.autocast_context(), torch.inference_mode():
            # Convert points and labels to numpy arrays
            points = np.array([(x, y)], dtype=np.float32)
            labels = np.array([1], dtype=np.int32)  # 1 for positive click
            # Add point and get mask
            _, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
                inference_state=self.state,
                frame_idx=self.current_frame_idx,
                obj_id=self.obj_id,
                points=points,
                labels=labels,
            )
        if out_mask_logits is not None and len(out_mask_logits) > 0:
            # add_new_points returns logits for the clicked frame only
            self.out_mask_logits[self.current_frame_idx] = out_mask_logits[0]
        # Draw tracking visualization
        return self._draw_tracking(frame_with_points)
    def propagate_masks(self):
        """Propagate masks to the entire video after user selection"""
        if self.state is None:
            return "No state initialized"
        logging.info(f"Propagating masks in video with state: {self.state}")
        # propagate_in_video is a generator that yields one result per frame
        with self.autocast_context(), torch.inference_mode():
            for frame_idx, obj_ids, video_res_masks in self.predictor.propagate_in_video(
                inference_state=self.state,
                start_frame_idx=0,
                reverse=False,
            ):
                self.out_mask_logits[frame_idx] = video_res_masks[0]
        return "Propagation complete"
    def autocast_context(self):
        if self.device.type == "cuda":
            return torch.autocast("cuda", dtype=torch.bfloat16)
        else:
            return contextlib.nullcontext()
    def _draw_tracking(self, frame, frame_idx=None):
        """Overlay the predicted mask for the given frame (defaults to current)"""
        if frame_idx is None:
            frame_idx = self.current_frame_idx
        mask_logits = self.out_mask_logits.get(frame_idx)
        if mask_logits is None:
            return frame
        mask_np = (mask_logits > 0.0).squeeze().cpu().numpy()
        if mask_np.shape[:2] == frame.shape[:2]:
            overlay = frame.copy()
            overlay[mask_np > 0] = [255, 0, 0]  # Red color for mask (RGB)
            alpha = 0.5  # Transparency factor
            frame = cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0)
        return frame
    def clear_points(self):
        """Clear all tracked points and masks"""
        self.points = []
        self.out_mask_logits = {}
        if self.state is not None:
            self.predictor.reset_state(self.state)
        if self.frames:
            return self.frames[self.current_frame_idx].copy()
        return None
    def change_model(self, checkpoint_name):
        """Change the current model checkpoint"""
        if checkpoint_name != self.current_checkpoint:
            self.load_model(checkpoint_name)
        return f"Loaded {CHECKPOINTS[checkpoint_name]}"
    def save_output_video(self):
        """Save the processed video with tracking visualization"""
        if not self.frames or not self.video_info:
            return None, "No video loaded"
        output_path = "output_tracked.mp4"
        # Initialize video writer
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(
            output_path,
            fourcc,
            self.video_info["fps"],
            (self.video_info["width"], self.video_info["height"]),
        )
        # Process each frame
        for frame_idx in range(len(self.frames)):
            frame = self.frames[frame_idx].copy()
            # Overlay the propagated mask for this frame, if any
            frame = self._draw_tracking(frame, frame_idx)
            # Draw clicked points for this frame
            frame_points = [(x, y) for f, x, y in self.points if f == frame_idx]
            if frame_points:
                # Draw points
                for x, y in frame_points:
                    cv2.circle(frame, (int(x), int(y)), 5, (255, 0, 0), -1)
                # Fit and draw trajectory if enough points
                if len(frame_points) >= 3:
                    points_arr = np.array(frame_points)
                    # fit_results = self.trajectory_fitter.fit_trajectory(points_arr)
                    # if fit_results is not None:
                    #     trajectory = fit_results["trajectory"]
                    #     points = trajectory.astype(np.int32)
                    #     for i in range(len(points) - 1):
                    #         cv2.line(
                    #             frame,
                    #             tuple(points[i]),
                    #             tuple(points[i + 1]),
                    #             (0, 255, 0),
                    #             2,
                    #         )
                    #     # Calculate and display metrics
                    #     metrics = self.trajectory_fitter.calculate_metrics(fit_results)
                    #     cv2.putText(
                    #         frame,
                    #         f"Speed: {metrics['initial_velocity_mph']:.1f} mph",
                    #         (10, 30),
                    #         cv2.FONT_HERSHEY_SIMPLEX,
                    #         1,
                    #         (255, 255, 255),
                    #         2,
                    #     )
                    #     cv2.putText(
                    #         frame,
                    #         f"Height: {metrics['max_height']:.1f} m",
                    #         (10, 70),
                    #         cv2.FONT_HERSHEY_SIMPLEX,
                    #         1,
                    #         (255, 255, 255),
                    #         2,
                    #     )
            # Frames are stored as RGB; convert back to BGR for OpenCV writing
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        out.release()
        return output_path, "Video saved successfully!"
def create_ui():
    tracker = GolfTracker()
    with gr.Blocks() as app:
        gr.Markdown("# Golf Ball Trajectory Tracker")
        gr.Markdown(
            "Upload a video and click on the golf ball positions to track its trajectory"
        )
        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Input Video")
                model_dropdown = gr.Dropdown(
                    choices=list(CHECKPOINTS.keys()),
                    value="SAM2-T",
                    label="Select Model",
                )
                upload_button = gr.Button("Process Video")
                clear_button = gr.Button("Clear Points")
                save_button = gr.Button("Save Output Video")
                propagate_button = gr.Button("Propagate Masks")
            with gr.Column():
                image_output = gr.Image(label="Click on golf ball positions")
                frame_slider = gr.Slider(
                    minimum=0,
                    maximum=0,
                    step=1,
                    value=0,
                    label="Frame",
                    interactive=True,
                )
                current_model = gr.Textbox(label="Current Model", interactive=False)
                status_text = gr.Textbox(label="Status", interactive=False)
                output_video = gr.Video(label="Output Video")

        # Event handlers
        model_dropdown.change(
            fn=tracker.change_model, inputs=[model_dropdown], outputs=[status_text]
        )
        video_input.change(
            fn=tracker.process_video,
            inputs=[video_input],
            outputs=[image_output, current_model, frame_slider, status_text],
        )
        upload_button.click(
            fn=tracker.process_video,
            inputs=[video_input],
            outputs=[image_output, current_model, frame_slider, status_text],
        )
        clear_button.click(fn=tracker.clear_points, inputs=[], outputs=[image_output])
        frame_slider.change(
            fn=tracker.update_frame, inputs=[frame_slider], outputs=[image_output]
        )
        image_output.select(
            fn=tracker.add_point, inputs=[image_output], outputs=[image_output]
        )
        save_button.click(
            fn=tracker.save_output_video, inputs=[], outputs=[output_video, status_text]
        )
        propagate_button.click(
            fn=tracker.propagate_masks, inputs=[], outputs=[status_text]
        )
    return app
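# Headless usage sketch (assumes a local clip named "swing.mp4"; in the UI the
# seed click comes through add_point via gr.SelectData, which has no simple
# headless equivalent):
#   tracker = GolfTracker(checkpoint="SAM2-T")
#   tracker.process_video("swing.mp4")
#   ... seed the ball via the UI click handler (add_point), then:
#   tracker.propagate_masks()
#   tracker.save_output_video()  # writes output_tracked.mp4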
if __name__ == "__main__":
    app = create_ui()
    app.launch()