import cv2
import numpy as np
from scipy.optimize import curve_fit
from segment_anything import SamPredictor, sam_model_registry


class GolfTrajectoryPredictor:
    """Estimate a golf ball's flight from a few user-selected points in a video.

    SAM (Segment Anything) segments an object in the first frame; the
    orientation of that mask seeds the launch-angle guess. A drag-free
    projectile model is then least-squares fitted to the selected positions.
    """

    # Gravitational acceleration used by the projectile model (m/s^2).
    GRAVITY = 9.81

    def __init__(self, sam_checkpoint, model_type="vit_h"):
        """Load a SAM model and wrap it in a ``SamPredictor``.

        Args:
            sam_checkpoint: Path to the SAM checkpoint file.
            model_type: Key into ``sam_model_registry``; must match the
                checkpoint (defaults to ``"vit_h"``).
        """
        self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        self.predictor = SamPredictor(self.sam)

    def get_club_metrics(self, frame, point):
        """Estimate the orientation and position of the object at ``point``.

        Args:
            frame: Image handed to ``SamPredictor.set_image`` (which expects
                RGB by default).
            point: ``(x, y)`` pixel coordinate used as a single foreground
                prompt.

        Returns:
            ``(angle, (x0, y0))``. ``angle`` is the direction of the best-fit
            line through the highest-scoring mask, in radians and in image
            coordinates (x right, y down). The fitted direction's sign is
            arbitrary, so ``angle`` is only meaningful modulo pi.
            ``(x0, y0)`` is a point on that line. Returns ``(None, None)``
            when the mask has fewer than two pixels.
        """
        self.predictor.set_image(frame)
        masks, scores, _ = self.predictor.predict(
            point_coords=np.array([point]),
            point_labels=np.array([1]),
        )
        best_mask = masks[np.argmax(scores)]

        rows, cols = np.nonzero(best_mask)
        if rows.size < 2:
            return None, None

        # np.nonzero yields (row, col); cv2.fitLine wants (x, y) = (col, row)
        # points, stored as float32 or int32.
        coords = np.column_stack((cols, rows)).astype(np.float32)
        vx, vy, x0, y0 = cv2.fitLine(coords, cv2.DIST_L2, 0, 0.01, 0.01).ravel()
        return float(np.arctan2(vy, vx)), (float(x0), float(y0))

    def physics_trajectory(self, t, v0, theta, h0, g=GRAVITY, x0=0.0):
        """Evaluate drag-free projectile motion.

        Args:
            t: Time(s) since launch, scalar or 1-D array (seconds when ``g``
                is in m/s^2).
            v0: Launch speed.
            theta: Launch angle in degrees, measured from the +x axis.
            h0: Vertical position at ``t = 0``; the y axis points up.
            g: Gravitational acceleration (default ``GRAVITY``).
            x0: Horizontal position at ``t = 0``.

        Returns:
            Array of shape ``(len(t), 2)`` holding ``(x, y)`` positions.
        """
        t = np.asarray(t, dtype=float)
        theta_rad = np.radians(theta)
        x = x0 + v0 * np.cos(theta_rad) * t
        y = h0 + v0 * np.sin(theta_rad) * t - 0.5 * g * t**2
        return np.column_stack((x, y))

    def fit_trajectory(self, points, initial_height, club_angle=None, times=None):
        """Fit launch speed and angle to observed positions.

        Args:
            points: ``(N, 2)`` array-like of ``(x, y)`` positions, with y
                pointing up (the convention used by ``physics_trajectory``).
                The first point is taken as the launch position horizontally.
            initial_height: Launch height ``h0``, held fixed during the fit.
            club_angle: Optional initial guess for the launch angle, in
                degrees. If omitted, the direction of the first displacement
                is used (45 degrees if that displacement is zero).
            times: Optional length-N sample times. They are shifted so the
                first sample is ``t = 0``. If omitted, the samples are assumed
                evenly spaced on a normalized interval ``[0, 1]``. Pass real
                times (e.g. ``frame_index / fps``) for the fitted speed to be
                physically meaningful with the default gravity.

        Returns:
            ``array([v0, theta_degrees])``, or ``None`` when fewer than two
            points are given or the least-squares fit fails.

        Raises:
            ValueError: If ``times`` is given but its length does not match
                ``points``.
        """
        pts = np.asarray(points, dtype=float)
        if pts.ndim != 2 or pts.shape[1] != 2 or len(pts) < 2:
            return None

        if times is None:
            t = np.linspace(0.0, 1.0, len(pts))
        else:
            t = np.asarray(times, dtype=float)
            if t.shape != (len(pts),):
                raise ValueError("times must have one entry per point")
            t = t - t[0]

        # Initial guesses are taken from the first observed displacement.
        dx, dy = pts[1] - pts[0]
        dt = t[1] - t[0]
        speed_guess = np.hypot(dx, dy) / dt if dt > 0 else 0.0
        if not np.isfinite(speed_guess) or speed_guess <= 0:
            speed_guess = 50.0  # Fallback when the first step gives no usable speed.

        if club_angle is not None:
            angle_guess = float(club_angle)
        elif dx or dy:
            angle_guess = float(np.degrees(np.arctan2(dy, dx)))
        else:
            angle_guess = 45.0

        x_start = pts[0, 0]

        def flat_model(time, v0, theta):
            # leastsq (behind curve_fit) needs 1-D residuals, so flatten (x, y).
            return self.physics_trajectory(
                time, v0, theta, initial_height, x0=x_start
            ).ravel()

        try:
            params, _ = curve_fit(
                flat_model, t, pts.ravel(), p0=[speed_guess, angle_guess]
            )
        except (RuntimeError, ValueError):
            # RuntimeError: no convergence; ValueError: non-finite inputs.
            return None
        return params

    @staticmethod
    def _launch_angle_guess(club_angle, physical_points):
        """Convert an image-space line angle (radians) into a y-up launch guess in degrees."""
        if club_angle is None:
            return None
        # Image y grows downward; negate the vertical component to get y-up.
        direction = np.array([np.cos(club_angle), -np.sin(club_angle)])
        # A fitted line has no inherent sign, so orient it along the ball's first move.
        if np.dot(direction, physical_points[1] - physical_points[0]) < 0:
            direction = -direction
        return float(np.degrees(np.arctan2(direction[1], direction[0])))

    def predict_full_trajectory(self, video_path, user_selected_points,
                                scale_factor=0.01, num_samples=100, duration=None):
        """Fit the projectile model to clicked ball positions and sample the flight.

        Args:
            video_path: Path to the video. Only its FPS and first frame are used.
            user_selected_points: Sequence of ``(x, y)`` pixel positions of the
                ball, assumed to come from consecutive frames starting at
                frame 0.
            scale_factor: Metres per pixel. This is a placeholder; real use
                needs camera calibration.
            num_samples: Number of points in the returned trajectory.
            duration: Time span to sample, in seconds. By default it is the
                larger of the observed window (``len(points) / fps``) and the
                predicted time for the ball to return to its launch height.

        Returns:
            ``(pixel_trajectory, params)``, where ``pixel_trajectory`` is a
            ``(num_samples, 2)`` array in image coordinates and ``params`` is
            ``array([v0, theta_degrees])``. Returns ``None`` when fewer than
            two points are given, the video cannot be read, its FPS is
            unavailable, or the fit fails.
        """
        pixel_points = np.asarray(user_selected_points, dtype=float)
        if pixel_points.ndim != 2 or pixel_points.shape[1] != 2 or len(pixel_points) < 2:
            return None

        cap = cv2.VideoCapture(video_path)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            ret, first_frame = cap.read()
        finally:
            cap.release()
        if not ret or not (fps > 0):
            return None

        # OpenCV decodes frames as BGR, while SamPredictor.set_image defaults to RGB.
        rgb_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
        # NOTE(review): the prompt is the first *ball* point, so SAM may segment
        # the ball rather than the club. Confirm the intended prompt location.
        club_angle, _ = self.get_club_metrics(rgb_frame, pixel_points[0])

        # Physical frame: origin at the first point, x right, y up.
        # The y flip is needed because the gravity term assumes y points up.
        origin = pixel_points[0]
        offsets = pixel_points - origin
        physical_points = np.column_stack((offsets[:, 0], -offsets[:, 1])) * scale_factor
        times = np.arange(len(physical_points)) / fps

        params = self.fit_trajectory(
            physical_points,
            initial_height=0.0,
            club_angle=self._launch_angle_guess(club_angle, physical_points),
            times=times,
        )
        if params is None:
            return None

        v0, theta = params
        if duration is None:
            observed = len(physical_points) / fps
            # Time to come back to launch height; non-positive if not launched upward.
            flight = 2.0 * v0 * np.sin(np.radians(theta)) / self.GRAVITY
            duration = max(observed, float(flight))

        t = np.linspace(0.0, duration, num_samples)
        physical_traj = self.physics_trajectory(t, v0, theta, 0.0)

        pixel_trajectory = np.column_stack((
            origin[0] + physical_traj[:, 0] / scale_factor,
            origin[1] - physical_traj[:, 1] / scale_factor,
        ))
        return pixel_trajectory, params

    def visualize_trajectory(self, frame, trajectory, user_points):
        """Draw the predicted path and the user-selected points on a copy of ``frame``.

        The path is drawn as a green polyline and the points as filled circles
        of radius 5 in colour (255, 0, 0), which is blue for a BGR frame.
        Non-finite trajectory samples are skipped.
        """
        vis_frame = frame.copy()

        path = np.rint(np.asarray(trajectory, dtype=float))
        path = path[np.all(np.isfinite(path), axis=1)].astype(np.int32)
        if len(path) >= 2:
            cv2.polylines(vis_frame, [path.reshape(-1, 1, 2)], False, (0, 255, 0), 2)

        for x, y in user_points:
            # OpenCV drawing functions need plain Python ints for coordinates.
            cv2.circle(vis_frame, (int(round(x)), int(round(y))), 5, (255, 0, 0), -1)

        return vis_frame