import gradio as gr
import numpy as np
import cv2
import torch
import os
import logging
import contextlib
import sys

# Make a SAM 2 checkout in the working directory importable. This must run
# BEFORE the `sam2` import below: extending sys.path after an import has
# already been attempted cannot help that import succeed.
sys.path.append(os.getcwd())
sys.path.append(os.path.join(os.getcwd(), "sam2"))  # Add sam2 directory to path
print(f"current dir is {os.getcwd()}")

from sam2.build_sam import build_sam2_video_predictor  # noqa: E402

# Device selection follows the official SAM 2 demo. Exporting
# SAM2_DEMO_FORCE_CPU_DEVICE=1 skips both accelerators and pins the CPU.
force_cpu_device = os.environ.get("SAM2_DEMO_FORCE_CPU_DEVICE", "0") == "1"
if force_cpu_device:
    logging.info("forcing CPU device for SAM 2 demo")

# Preference order: CUDA, then Apple MPS, then CPU.
DEVICE = torch.device(
    "cuda"
    if torch.cuda.is_available() and not force_cpu_device
    else "mps"
    if torch.backends.mps.is_available() and not force_cpu_device
    else "cpu"
)
logging.info(f"using device: {DEVICE}")

# TF32 kernels are only enabled on compute capability 8.x (Ampere) or newer.
if DEVICE.type == "cuda" and torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
elif DEVICE.type == "mps":
    logging.warning(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )


def load_model_paths(checkpoint_name):
    """Resolve a checkpoint key to its ``(checkpoint_path, config_path)`` pair.

    Raises:
        ValueError: if *checkpoint_name* is not one of the known keys.
    """
    known_models = (
        (
            "SAM2-T",
            ("models/sam2.1_hiera_tiny.pt", "configs/sam2.1/sam2.1_hiera_t.yaml"),
        ),
        (
            "SAM2-S",
            ("models/sam2.1_hiera_small.pt", "configs/sam2.1/sam2.1_hiera_s.yaml"),
        ),
        (
            "SAM2-B_PLUS",
            ("models/sam2.1_hiera_base_plus.pt", "configs/sam2.1/sam2.1_hiera_b+.yaml"),
        ),
    )
    for key, paths in known_models:
        if checkpoint_name == key:
            return paths
    raise ValueError(f"Invalid checkpoint name: {checkpoint_name}")


# Selectable checkpoints, mapping key -> human-readable label.
# Insertion order is the order the model dropdown lists them in.
CHECKPOINTS = dict(
    [
        ("SAM2-B_PLUS", "Base Plus Model"),
        ("SAM2-S", "Small Model"),
        ("SAM2-T", "Tiny Model"),
    ]
)


class GolfTracker:
    """Interactive single-object (golf ball) tracker on top of a SAM 2 video predictor.

    Decoded frames are kept in OpenCV's BGR channel order (as returned by
    ``cv2.VideoCapture.read``) and converted to RGB only when handed to Gradio
    for display. Per-frame boolean ball masks are kept in ``self.frame_masks``.
    """

    def __init__(self, checkpoint="SAM2-T"):
        """Initialize with the given checkpoint key (a key of CHECKPOINTS)."""
        self.current_checkpoint = checkpoint
        self.predictor = None
        self.points = []  # (frame_idx, x, y) clicks in image pixel coordinates
        self.frames = []  # decoded frames, BGR
        self.current_frame_idx = 0  # frame currently shown in the UI
        self.video_info = None  # path / size / fps of the loaded video
        self.state = None  # SAM 2 inference state for the loaded video
        self.obj_id = 1  # Track single object (golf ball)
        self.device = DEVICE
        self.out_mask_logits = None  # most recent raw mask logits from SAM 2
        self.frame_masks = {}  # frame_idx -> boolean (H, W) ball mask
        self.load_model(checkpoint)

    def load_model(self, checkpoint_name):
        """Build the SAM 2 video predictor for *checkpoint_name* on self.device.

        Raises:
            ValueError: if *checkpoint_name* is not a key of CHECKPOINTS.
        """
        if checkpoint_name not in CHECKPOINTS:
            raise ValueError(f"Invalid checkpoint: {checkpoint_name}")

        print(f"Loading checkpoint: {checkpoint_name}")
        sam2_checkpoint, model_cfg = load_model_paths(checkpoint_name)

        # Build predictor with model config and checkpoint
        self.predictor = build_sam2_video_predictor(
            model_cfg, sam2_checkpoint, self.device
        )
        print(f"Model loaded successfully: {CHECKPOINTS[checkpoint_name]}")
        self.current_checkpoint = checkpoint_name

    @staticmethod
    def _to_display(frame):
        """Convert a stored BGR frame to the RGB layout Gradio displays."""
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    def _reset_tracking(self):
        """Forget all clicks and masks (the SAM 2 state is handled by callers)."""
        self.points = []
        self.frame_masks = {}
        self.out_mask_logits = None

    def _extract_mask(self, obj_ids, mask_logits):
        """Return the boolean (H, W) mask for ``self.obj_id``, or None if absent.

        *mask_logits* is indexed in the same order as *obj_ids*; a logit > 0
        marks a foreground pixel.
        """
        obj_ids = list(obj_ids)
        if self.obj_id not in obj_ids:
            return None
        mask = (mask_logits[obj_ids.index(self.obj_id)] > 0.0).cpu().numpy()
        # Per-object logits are assumed to carry a leading singleton channel
        # dim, e.g. (1, H, W); keep only the spatial (H, W) part.
        return mask.reshape(mask.shape[-2:])

    def process_video(self, video_path):
        """Decode *video_path*, initialise SAM 2 on it and show the first frame.

        Returns a 4-tuple for the Gradio outputs
        ``(image, current model, frame slider, status message)``; on failure
        the first three entries are None.
        """
        # Gradio fires `change` with None when the video input is cleared.
        if not video_path or not os.path.exists(video_path):
            return None, None, None, "Video file not found"

        # Reset state
        self._reset_tracking()
        self.frames = []
        self.current_frame_idx = 0
        self.state = None
        self.video_info = None

        cap = cv2.VideoCapture(video_path)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                self.frames.append(frame)
        finally:
            # Release on every path, including the "Failed to read" return.
            cap.release()

        if not self.frames:
            return None, None, None, "Failed to read video"

        # Take the size from the decoded frames: the VideoWriter in
        # save_output_video must be given the size of the frames it writes.
        height, width = self.frames[0].shape[:2]
        self.video_info = {
            "path": video_path,
            "height": height,
            "width": width,
            "fps": fps,
            "total_frames": len(self.frames),
        }

        # Initialize SAM2 state
        with self.autocast_context(), torch.inference_mode():
            self.state = self.predictor.init_state(video_path)

        return (
            self._to_display(self.frames[0]),  # First frame
            self.current_checkpoint,
            gr.Slider(minimum=0, maximum=len(self.frames) - 1, step=1, value=0),
            "Navigate through frames and click on the golf ball to track",
        )

    def update_frame(self, frame_idx):
        """Return frame *frame_idx* (RGB) with its mask and clicks drawn on it.

        Returns None when no video is loaded or the index is out of range.
        """
        if not self.frames:
            return None
        frame_idx = int(frame_idx)
        if not 0 <= frame_idx < len(self.frames):
            return None

        self.current_frame_idx = frame_idx
        return self._to_display(self._draw_tracking(self.frames[frame_idx]))

    def add_point(self, frame, evt: gr.SelectData):
        """Register a positive click on the current frame and refresh its mask.

        *frame* is the image Gradio currently displays; it is returned
        unchanged when no video is loaded. Otherwise the stored frame is
        redrawn from scratch, so overlays never accumulate.
        """
        if self.state is None:
            return frame

        x, y = evt.index[0], evt.index[1]
        frame_idx = self.current_frame_idx
        self.points.append((frame_idx, x, y))

        # SAM 2's add_new_points defaults to clear_old_points=True, which drops
        # this object's earlier prompts on the frame, so resend every click
        # made on this frame rather than only the newest one.
        frame_points = [(px, py) for f, px, py in self.points if f == frame_idx]
        points = np.array(frame_points, dtype=np.float32)
        labels = np.ones(len(frame_points), dtype=np.int32)  # 1 = positive click

        with self.autocast_context(), torch.inference_mode():
            _, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
                inference_state=self.state,
                frame_idx=frame_idx,
                obj_id=self.obj_id,
                points=points,
                labels=labels,
            )

            if out_mask_logits is not None and len(out_mask_logits) > 0:
                self.out_mask_logits = out_mask_logits
                mask = self._extract_mask(out_obj_ids, out_mask_logits)
                if mask is not None:
                    self.frame_masks[frame_idx] = mask

        return self._to_display(self._draw_tracking(self.frames[frame_idx]))

    def propagate_masks(self):
        """Propagate the clicked mask through the video, storing one mask per frame.

        Returns a status string for the UI.
        """
        if self.state is None:
            return "No state initialized"
        if not self.points:
            # Without any prompt SAM 2 has nothing to propagate (and raises).
            return "Add at least one point before propagating"

        # Log a summary only; formatting the whole inference state is costly.
        logging.info("Propagating masks over %d frames", len(self.frames))

        with self.autocast_context(), torch.inference_mode():
            # propagate_in_video is a generator yielding one
            # (frame_idx, obj_ids, mask_logits) tuple per processed frame.
            for frame_idx, obj_ids, video_res_masks in self.predictor.propagate_in_video(
                inference_state=self.state,
                start_frame_idx=0,
                reverse=False,
            ):
                self.out_mask_logits = video_res_masks
                mask = self._extract_mask(obj_ids, video_res_masks)
                if mask is not None:
                    self.frame_masks[frame_idx] = mask

        return "Propagation complete"

    def autocast_context(self):
        """bfloat16 autocast on CUDA; a no-op context on other devices."""
        if self.device.type == "cuda":
            return torch.autocast("cuda", dtype=torch.bfloat16)
        else:
            return contextlib.nullcontext()

    def _draw_tracking(self, frame, frame_idx=None):
        """Return a copy of BGR *frame* with the ball mask and clicks overlaid.

        *frame_idx* selects whose mask/clicks are drawn and defaults to the
        frame currently shown in the UI. *frame* itself is never modified.
        """
        if frame_idx is None:
            frame_idx = self.current_frame_idx

        mask = self.frame_masks.get(frame_idx)
        if mask is not None and mask.shape[:2] == frame.shape[:2]:
            overlay = frame.copy()
            overlay[mask] = (0, 0, 255)  # red in BGR
            alpha = 0.5  # Transparency factor
            frame = cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0)
        else:
            frame = frame.copy()

        for f, x, y in self.points:
            if f == frame_idx:
                cv2.circle(frame, (int(x), int(y)), 5, (255, 0, 0), -1)
        return frame

    def clear_points(self):
        """Remove all clicks and masks, including the prompts held by SAM 2.

        Returns the current frame (RGB) without overlays, or None if no video
        is loaded.
        """
        self._reset_tracking()
        if self.state is not None:
            with self.autocast_context(), torch.inference_mode():
                self.predictor.reset_state(self.state)
        if self.frames:
            return self._to_display(self.frames[self.current_frame_idx])
        return None

    def change_model(self, checkpoint_name):
        """Switch to *checkpoint_name* and re-initialise SAM 2 on the loaded video.

        The existing inference state holds prompts and outputs produced by the
        previous model, so it is rebuilt (and clicks/masks cleared) on a switch.
        """
        if checkpoint_name != self.current_checkpoint:
            self.load_model(checkpoint_name)
            if self.video_info is not None:
                self._reset_tracking()
                with self.autocast_context(), torch.inference_mode():
                    self.state = self.predictor.init_state(self.video_info["path"])
        return f"Loaded {CHECKPOINTS[checkpoint_name]}"

    def save_output_video(self):
        """Write every frame, with mask and clicks overlaid, to an MP4 file.

        Returns ``(output_path, status)``; ``output_path`` is None on failure.
        """
        if not self.frames or not self.video_info:
            return None, "No video loaded"

        output_path = "output_tracked.mp4"

        fps = self.video_info["fps"]
        if not fps or fps <= 0:
            # CAP_PROP_FPS can come back as 0 when the container reports no
            # rate; fall back to a common default instead of a broken writer.
            fps = 30.0

        # Initialize video writer
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(
            output_path,
            fourcc,
            fps,
            (self.video_info["width"], self.video_info["height"]),
        )
        if not out.isOpened():
            return None, "Failed to open video writer"

        # TODO: trajectory fitting and speed/height metrics are not implemented.
        try:
            for frame_idx, frame in enumerate(self.frames):
                out.write(self._draw_tracking(frame, frame_idx))
        finally:
            out.release()
        return output_path, "Video saved successfully!"


def create_ui():
    """Build the Gradio Blocks app and wire its events to one GolfTracker.

    NOTE: the tracker is created once here, so every browser session served by
    the returned app shares the same loaded video, clicks and masks.
    """
    tracker = GolfTracker()

    def on_model_change(checkpoint_name):
        # Report the switch and refresh the "Current Model" box, which would
        # otherwise keep showing the previous checkpoint until the next video
        # is processed.
        status = tracker.change_model(checkpoint_name)
        return status, tracker.current_checkpoint

    with gr.Blocks() as app:
        gr.Markdown("# Golf Ball Trajectory Tracker")
        gr.Markdown(
            "Upload a video and click on the golf ball positions to track its trajectory"
        )

        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Input Video")
                model_dropdown = gr.Dropdown(
                    choices=list(CHECKPOINTS.keys()),
                    value=tracker.current_checkpoint,
                    label="Select Model",
                )
                upload_button = gr.Button("Process Video")
                clear_button = gr.Button("Clear Points")
                save_button = gr.Button("Save Output Video")
                propagate_button = gr.Button("Propagate Masks")

            with gr.Column():
                image_output = gr.Image(label="Click on golf ball positions")
                frame_slider = gr.Slider(
                    minimum=0,
                    maximum=0,
                    step=1,
                    value=0,
                    label="Frame",
                    interactive=True,
                )
                current_model = gr.Textbox(
                    label="Current Model",
                    value=tracker.current_checkpoint,
                    interactive=False,
                )
                status_text = gr.Textbox(label="Status", interactive=False)
                output_video = gr.Video(label="Output Video")

        # Event handlers
        model_dropdown.change(
            fn=on_model_change,
            inputs=[model_dropdown],
            outputs=[status_text, current_model],
        )

        video_input.change(
            fn=tracker.process_video,
            inputs=[video_input],
            outputs=[image_output, current_model, frame_slider, status_text],
        )

        upload_button.click(
            fn=tracker.process_video,
            inputs=[video_input],
            outputs=[image_output, current_model, frame_slider, status_text],
        )

        clear_button.click(fn=tracker.clear_points, inputs=[], outputs=[image_output])

        frame_slider.change(
            fn=tracker.update_frame, inputs=[frame_slider], outputs=[image_output]
        )

        image_output.select(
            fn=tracker.add_point, inputs=[image_output], outputs=[image_output]
        )

        save_button.click(
            fn=tracker.save_output_video, inputs=[], outputs=[output_video, status_text]
        )

        propagate_button.click(
            fn=tracker.propagate_masks, inputs=[], outputs=[status_text]
        )

    return app


if __name__ == "__main__":
    # Script entry point: build the UI and start the Gradio server.
    create_ui().launch()