import contextlib
import logging
import os
import sys

import cv2
import gradio as gr
import numpy as np
import torch

# Make the local sam2 checkout importable before importing from it
sys.path.append(os.getcwd())
sys.path.append(os.path.join(os.getcwd(), "sam2"))

from sam2.build_sam import build_sam2_video_predictor

print(f"Current working directory: {os.getcwd()}")
# Ensure device setup matches the official code
force_cpu_device = os.environ.get("SAM2_DEMO_FORCE_CPU_DEVICE", "0") == "1"
if force_cpu_device:
    logging.info("forcing CPU device for SAM 2 demo")
if torch.cuda.is_available() and not force_cpu_device:
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available() and not force_cpu_device:
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
logging.info(f"using device: {DEVICE}")
if DEVICE.type == "cuda":
    # TF32 trades a little precision for speed on Ampere (compute capability 8)
    # and newer GPUs; this mirrors the official SAM 2 demo setup
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
elif DEVICE.type == "mps":
    logging.warning(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )

def load_model_paths(checkpoint_name):
    """Get model checkpoint and config paths"""
    if checkpoint_name == "SAM2-T":
        sam2_checkpoint = "models/sam2.1_hiera_tiny.pt"
        model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
    elif checkpoint_name == "SAM2-S":
        sam2_checkpoint = "models/sam2.1_hiera_small.pt"
        model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
    elif checkpoint_name == "SAM2-B_PLUS":
        sam2_checkpoint = "models/sam2.1_hiera_base_plus.pt"
        model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
    else:
        raise ValueError(f"Invalid checkpoint name: {checkpoint_name}")
    return sam2_checkpoint, model_cfg
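
# Example usage (assumes the checkpoints were downloaded into ./models/, e.g.
# with the download script from the SAM 2 repo; the config paths are resolved
# by Hydra relative to the sam2 package):
#   ckpt, cfg = load_model_paths("SAM2-T")
#   predictor = build_sam2_video_predictor(cfg, ckpt, DEVICE)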

# Available checkpoints
CHECKPOINTS = {
    "SAM2-B_PLUS": "Base Plus Model",
    "SAM2-S": "Small Model",
    "SAM2-T": "Tiny Model",
}

class GolfTracker:
    def __init__(self, checkpoint="SAM2-T"):
        """Initialize with specified checkpoint model"""
        self.current_checkpoint = checkpoint
        self.predictor = None
        self.points = []
        self.frames = []
        self.current_frame_idx = 0
        self.video_info = None
        self.state = None
        self.obj_id = 1  # Track a single object (the golf ball)
        self.device = DEVICE
        self.out_mask_logits = None
        self.load_model(checkpoint)

    def load_model(self, checkpoint_name):
        """Load specified checkpoint model"""
        if checkpoint_name not in CHECKPOINTS:
            raise ValueError(f"Invalid checkpoint: {checkpoint_name}")
        print(f"Loading checkpoint: {checkpoint_name}")
        sam2_checkpoint, model_cfg = load_model_paths(checkpoint_name)
        # Build predictor with model config and checkpoint
        self.predictor = build_sam2_video_predictor(
            model_cfg, sam2_checkpoint, self.device
        )
        print(f"Model loaded successfully: {CHECKPOINTS[checkpoint_name]}")
        self.current_checkpoint = checkpoint_name

    def process_video(self, video_path):
        """Process the video and initialize tracking"""
        if not os.path.exists(video_path):
            return None, None, None, "Video file not found"
        # Reset state
        self.points = []
        self.frames = []
        self.current_frame_idx = 0
        self.state = None
        self.out_mask_logits = None
        # Read video frames, converting BGR -> RGB since Gradio displays RGB
        cap = cv2.VideoCapture(video_path)
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            self.frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if not self.frames:
            cap.release()
            return None, None, None, "Failed to read video"
        # Store video info
        self.video_info = {
            "path": video_path,
            "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
            "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            "fps": cap.get(cv2.CAP_PROP_FPS),
            "total_frames": len(self.frames),
        }
        cap.release()
        # Initialize SAM2 state
        with self.autocast_context(), torch.inference_mode():
            self.state = self.predictor.init_state(video_path)
        return (
            self.frames[0],  # First frame
            self.current_checkpoint,
            gr.Slider(minimum=0, maximum=len(self.frames) - 1, step=1, value=0),
            "Navigate through frames and click on the golf ball to track",
        )
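
    # Note: init_state loads and preprocesses the video again inside SAM 2, so
    # frames end up held twice (once by the predictor, once in self.frames for
    # drawing). Fine for short clips; worth streamlining for long ones.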

    def update_frame(self, frame_idx):
        """Update displayed frame"""
        if not self.frames or frame_idx >= len(self.frames):
            return None
        self.current_frame_idx = int(frame_idx)
        frame = self.frames[self.current_frame_idx].copy()
        # Draw existing points and the current mask, if any
        frame = self._draw_tracking(frame)
        return frame

    def add_point(self, frame, evt: gr.SelectData):
        """Add a point and get ball prediction with enhanced mask visualization"""
        if self.state is None:
            return frame
        x, y = evt.index[0], evt.index[1]
        self.points.append((self.current_frame_idx, x, y))
        frame_with_points = frame.copy()
        # Get ball prediction using SAM2.1
        with self.autocast_context(), torch.inference_mode():
            # Convert the click to the arrays SAM 2 expects
            points = np.array([(x, y)], dtype=np.float32)
            labels = np.array([1], dtype=np.int32)  # 1 = positive click
            # Add point and get the mask for the clicked frame
            _, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
                inference_state=self.state,
                frame_idx=self.current_frame_idx,
                obj_id=self.obj_id,
                points=points,
                labels=labels,
            )
            if out_mask_logits is not None and len(out_mask_logits) > 0:
                self.out_mask_logits = out_mask_logits
        # Draw tracking visualization
        frame_with_points = self._draw_tracking(frame_with_points)
        return frame_with_points
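
    # A negative click (label 0) tells SAM 2 a pixel is background, useful when
    # the mask bleeds onto the grass. A minimal sketch (not wired into the UI):
    #   points = np.array([(x, y)], dtype=np.float32)
    #   labels = np.array([0], dtype=np.int32)  # 0 = negative click
    #   self.predictor.add_new_points(
    #       inference_state=self.state, frame_idx=self.current_frame_idx,
    #       obj_id=self.obj_id, points=points, labels=labels,
    #   )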

    def propagate_masks(self):
        """Propagate masks to the entire video after user selection"""
        if self.state is None:
            return "No state initialized"
        logging.info("Propagating masks across the video")
        # propagate_in_video is a generator that yields one result per frame,
        # so collect the masks into a dict keyed by frame index
        video_segments = {}
        with self.autocast_context(), torch.inference_mode():
            for frame_idx, obj_ids, video_res_masks in self.predictor.propagate_in_video(
                inference_state=self.state,
                start_frame_idx=0,
                reverse=False,
            ):
                video_segments[frame_idx] = video_res_masks
        self.out_mask_logits = video_segments
        return "Propagation complete"
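
    # After propagation, self.out_mask_logits maps frame_idx -> a tensor of
    # video-resolution mask logits (num_objects, 1, H, W); _draw_tracking
    # handles both this dict and the single-frame tensor from add_new_points.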

    def autocast_context(self):
        # bfloat16 autocast on CUDA mirrors the official SAM 2 demo; other
        # devices run in full precision
        if self.device.type == "cuda":
            return torch.autocast("cuda", dtype=torch.bfloat16)
        else:
            return contextlib.nullcontext()

    def _draw_tracking(self, frame):
        """Draw clicked points and the object mask on the frame"""
        # Draw the clicks that belong to the current frame
        for f, x, y in self.points:
            if f == self.current_frame_idx:
                cv2.circle(frame, (int(x), int(y)), 5, (0, 0, 255), -1)  # blue dot
        if self.out_mask_logits is None:
            return frame
        # Masks are either a dict (after propagation) or a single-frame tensor
        mask = self.out_mask_logits
        if isinstance(mask, dict):
            mask = mask.get(self.current_frame_idx)
            if mask is None:
                return frame
        mask_np = (mask[0] > 0.0).squeeze().cpu().numpy()
        if mask_np.shape == frame.shape[:2]:
            overlay = frame.copy()
            overlay[mask_np] = [255, 0, 0]  # red (frames are stored as RGB)
            alpha = 0.5  # Transparency factor
            frame = cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0)
        return frame

    def clear_points(self):
        """Clear all tracked points and the mask overlay"""
        self.points = []
        self.out_mask_logits = None
        if self.state is not None:
            # Drop the clicks from the SAM 2 inference state as well
            self.predictor.reset_state(self.state)
        if self.frames:
            return self.frames[self.current_frame_idx].copy()
        return None

    def change_model(self, checkpoint_name):
        """Change the current model checkpoint"""
        if checkpoint_name != self.current_checkpoint:
            self.load_model(checkpoint_name)
            # The old inference state belongs to the previous model, so the
            # video must be processed again
            self.state = None
        return f"Loaded {CHECKPOINTS[checkpoint_name]}"

    def save_output_video(self):
        """Save the processed video with tracking visualization"""
        if not self.frames or not self.video_info:
            return None, "No video loaded"
        output_path = "output_tracked.mp4"
        # Initialize video writer
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(
            output_path,
            fourcc,
            self.video_info["fps"],
            (self.video_info["width"], self.video_info["height"]),
        )
        # Process each frame
        for frame_idx in range(len(self.frames)):
            frame = self.frames[frame_idx].copy()
            # Draw tracking for this frame
            frame_points = [(x, y) for f, x, y in self.points if f == frame_idx]
            if frame_points:
                # Draw points (frames are RGB here, so (0, 0, 255) is blue)
                for x, y in frame_points:
                    cv2.circle(frame, (int(x), int(y)), 5, (0, 0, 255), -1)
                # Fit and draw trajectory if enough points
                if len(frame_points) >= 3:
                    points_arr = np.array(frame_points)
                    # fit_results = self.trajectory_fitter.fit_trajectory(points_arr)
                    # if fit_results is not None:
                    #     trajectory = fit_results["trajectory"]
                    #     points = trajectory.astype(np.int32)
                    #     for i in range(len(points) - 1):
                    #         cv2.line(
                    #             frame,
                    #             tuple(points[i]),
                    #             tuple(points[i + 1]),
                    #             (0, 255, 0),
                    #             2,
                    #         )
                    #     # Calculate and display metrics
                    #     metrics = self.trajectory_fitter.calculate_metrics(fit_results)
                    #     cv2.putText(
                    #         frame,
                    #         f"Speed: {metrics['initial_velocity_mph']:.1f} mph",
                    #         (10, 30),
                    #         cv2.FONT_HERSHEY_SIMPLEX,
                    #         1,
                    #         (255, 255, 255),
                    #         2,
                    #     )
                    #     cv2.putText(
                    #         frame,
                    #         f"Height: {metrics['max_height']:.1f} m",
                    #         (10, 70),
                    #         cv2.FONT_HERSHEY_SIMPLEX,
                    #         1,
                    #         (255, 255, 255),
                    #         2,
                    #     )
            # Frames are stored as RGB, but VideoWriter expects BGR
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        out.release()
        return output_path, "Video saved successfully!"
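
    # Note: "mp4v" output does not play inline in every browser; if the saved
    # video shows up blank in the Gradio player, re-encoding to H.264 (e.g.
    # with ffmpeg) is a common workaround.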

def create_ui():
    tracker = GolfTracker()
    with gr.Blocks() as app:
        gr.Markdown("# Golf Ball Trajectory Tracker")
        gr.Markdown(
            "Upload a video and click on the golf ball positions to track its trajectory"
        )
        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Input Video")
                model_dropdown = gr.Dropdown(
                    choices=list(CHECKPOINTS.keys()),
                    value="SAM2-T",
                    label="Select Model",
                )
                upload_button = gr.Button("Process Video")
                clear_button = gr.Button("Clear Points")
                save_button = gr.Button("Save Output Video")
                propagate_button = gr.Button("Propagate Masks")
            with gr.Column():
                image_output = gr.Image(label="Click on golf ball positions")
                frame_slider = gr.Slider(
                    minimum=0,
                    maximum=0,
                    step=1,
                    value=0,
                    label="Frame",
                    interactive=True,
                )
                current_model = gr.Textbox(label="Current Model", interactive=False)
                status_text = gr.Textbox(label="Status", interactive=False)
                output_video = gr.Video(label="Output Video")

        # Event handlers
        model_dropdown.change(
            fn=tracker.change_model, inputs=[model_dropdown], outputs=[status_text]
        )
        video_input.change(
            fn=tracker.process_video,
            inputs=[video_input],
            outputs=[image_output, current_model, frame_slider, status_text],
        )
        upload_button.click(
            fn=tracker.process_video,
            inputs=[video_input],
            outputs=[image_output, current_model, frame_slider, status_text],
        )
        clear_button.click(fn=tracker.clear_points, inputs=[], outputs=[image_output])
        frame_slider.change(
            fn=tracker.update_frame, inputs=[frame_slider], outputs=[image_output]
        )
        image_output.select(
            fn=tracker.add_point, inputs=[image_output], outputs=[image_output]
        )
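        # Gradio's .select event supplies a gr.SelectData whose .index holds
        # the (x, y) pixel coordinates of the click; add_point reads them there.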
        save_button.click(
            fn=tracker.save_output_video, inputs=[], outputs=[output_video, status_text]
        )
        propagate_button.click(
            fn=tracker.propagate_masks, inputs=[], outputs=[status_text]
        )
    return app

if __name__ == "__main__":
    app = create_ui()
    app.launch()
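
# To run locally (assumes SAM 2 is installed and the checkpoints are present
# in ./models/): python app.py
# Gradio serves the UI at http://127.0.0.1:7860 by default.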