import gradio as gr
import torch
import numpy as np
import cv2
from sam2.build_sam import build_sam2_video_predictor
import tempfile
import os


class VideoTracker:
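    """Wrap a SAM2 video predictor with the state needed by the Gradio UI."""
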
    def __init__(self):
        self.checkpoint = "./models/sam2.1_hiera_tiny.pt"
        self.model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
        self.predictor = build_sam2_video_predictor(
            self.model_cfg, self.checkpoint, device="cpu", mode="eval"
        )
        self.state = None  # SAM2 inference state for the loaded video
        self.video_frames = None  # cached video frames as numpy arrays
        self.current_frame_idx = 0  # index of the frame shown in the UI
        self.masks = []
        self.points = []  # user clicks stored as (frame_idx, x, y)
        self.frame_count = 0
        self.video_info = None  # path, resolution, fps, and frame count
        self.obj_id = 1  # id of the single tracked object
        self.out_mask_logits = None  # latest mask logits used for drawing

    def load_video(self, video_path):
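        """Load the video, cache its frames, and initialize the SAM2 state."""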
        if video_path is None:
            return None, gr.Slider(minimum=0, maximum=0, step=1, value=0)

        # Create a temporary file for the video
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        temp_file.close()

        # Copy the uploaded video to the temporary file
        with open(video_path, "rb") as f_src, open(temp_file.name, "wb") as f_dst:
            f_dst.write(f_src.read())

        # Load frames with OpenCV, converting BGR -> RGB for Gradio display
        cap = cv2.VideoCapture(temp_file.name)
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        if not frames:
            cap.release()
            os.unlink(temp_file.name)
            return None, gr.Slider(minimum=0, maximum=0, step=1, value=0)

        # Store video info
        self.video_info = {
            "path": temp_file.name,
            "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
            "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            "fps": cap.get(cv2.CAP_PROP_FPS),
            "total_frames": len(frames),
        }
        cap.release()

        self.video_frames = frames
        self.frame_count = len(frames)

        # Initialize SAM2 state with video path
        with torch.inference_mode():
            self.state = self.predictor.init_state(temp_file.name)

        # SAM2 has already cached the frames, so the temp file can be removed
        os.unlink(temp_file.name)

        return frames[0], gr.Slider(minimum=0, maximum=len(frames) - 1, step=1, value=0)

    def update_frame(self, frame_number):
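        """Return the requested frame with the mask and clicked points drawn."""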
        if self.video_frames is None:
            return None

        # Slider values can arrive as floats, so cast to an integer index
        self.current_frame_idx = int(frame_number)
        frame = self.video_frames[self.current_frame_idx].copy()

        # Draw existing points and masks
        if self.out_mask_logits is not None:
            frame = self._draw_tracking(frame)

        # Draw the clicked points stored for this frame (red markers)
        for point in self.points:
            if point[0] == self.current_frame_idx:
                cv2.circle(frame, (int(point[1]), int(point[2])), 5, (255, 0, 0), -1)

        return frame

    def add_point(self, frame, evt: gr.SelectData):
        """Add a point and get ball prediction with enhanced mask visualization"""
        if self.state is None:
            return frame

        x, y = evt.index[0], evt.index[1]
        self.points.append((self.current_frame_idx, x, y))

        frame_with_points = frame.copy()

        # Get the object mask prediction from SAM2
        with torch.inference_mode():
            # Convert points and labels to numpy arrays
            points = np.array([(x, y)], dtype=np.float32)
            labels = np.array([1], dtype=np.int32)  # 1 for positive click

            # Add point and get mask
            _, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
                inference_state=self.state,
                frame_idx=self.current_frame_idx,
                obj_id=self.obj_id,
                points=points,
                labels=labels,
            )

            if out_mask_logits is not None and len(out_mask_logits) > 0:
                self.out_mask_logits = (
                    out_mask_logits[0]
                    if isinstance(out_mask_logits, list)
                    else out_mask_logits
                )

        # Draw tracking visualization with enhanced mask
        frame_with_points = self._draw_tracking(frame_with_points)

        # Draw point on top of mask
        cv2.circle(
            frame_with_points, (int(x), int(y)), 5, (255, 255, 0), -1
        )  # Yellow dot
        cv2.circle(frame_with_points, (int(x), int(y)), 7, (0, 0, 0), 1)  # Black border

        return frame_with_points

    def propagate_video(self):
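        """Propagate the mask through the video and write an annotated mp4."""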
        if self.state is None:
            return None

        output_frames = self.video_frames.copy()

        # Collect the propagated mask for every frame before rendering
        all_masks = []

        # First pass: collect all masks
        with torch.inference_mode():
            for frame_idx, obj_ids, masks in self.predictor.propagate_in_video(
                self.state,
                start_frame_idx=0,
                reverse=False,
            ):
                if masks is not None and len(masks) > 0:
                    mask = masks[0] if isinstance(masks, list) else masks
                    all_masks.append((frame_idx, mask))

        # Second pass: draw the mask overlay and glow effect on each frame
        for i, frame in enumerate(output_frames):
            frame = frame.copy()

            # Find masks for this frame
            current_masks = [m[1] for m in all_masks if m[0] == i]

            if current_masks:
                self.out_mask_logits = current_masks[0]

                # Get binary mask and ensure correct dimensions
                mask_np = (current_masks[0] > 0.0).cpu().numpy()
                mask_np = self._handle_mask_dimensions(mask_np)

                # Convert to proper format for OpenCV
                mask_np = mask_np.astype(np.uint8)

                # Enhanced visualization for video
                frame = self._draw_tracking(frame, alpha=0.6)

                # Create glowing effect
                try:
                    # Create kernel for dilation
                    kernel = np.ones((5, 5), np.uint8)

                    # Dilate mask for glow effect
                    dilated_mask = cv2.dilate(mask_np, kernel, iterations=2)

                    # Create glow overlay
                    glow = frame.copy()
                    glow[dilated_mask > 0] = [255, 255, 0]  # yellow glow (RGB)

                    # Blend glow with frame
                    frame = cv2.addWeighted(frame, 0.7, glow, 0.3, 0)
                except cv2.error:
                    # Skip the glow effect if the mask and frame shapes mismatch
                    print(
                        f"Warning: Could not apply glow effect. Mask shape: {mask_np.shape}, Frame shape: {frame.shape}"
                    )

            output_frames[i] = frame

        # Write the annotated frames to an mp4, preserving the source frame rate
        temp_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
        height, width = output_frames[0].shape[:2]
        fps = (self.video_info or {}).get("fps") or 30

        writer = cv2.VideoWriter(
            temp_output, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height), True
        )

        for frame in output_frames:
            # Frames are RGB for display; convert back to BGR for the writer
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        writer.release()

        return temp_output

    def _handle_mask_dimensions(self, mask_np):
        """Helper function to handle various mask dimensions"""
        # Handle 4D tensor (1, 1, H, W)
        if len(mask_np.shape) == 4:
            mask_np = mask_np[0, 0]
        # Handle 3D tensor (1, H, W) or (H, W, 1)
        elif len(mask_np.shape) == 3:
            if mask_np.shape[0] == 1:  # (1, H, W) format
                mask_np = mask_np[0]
            elif mask_np.shape[2] == 1:  # (H, W, 1) format
                mask_np = mask_np[:, :, 0]
        return mask_np

    def _draw_tracking(self, frame, alpha=0.5):
        """Draw object mask on frame with enhanced visualization"""
        if self.out_mask_logits is not None:
            # Convert logits to binary mask
            if isinstance(self.out_mask_logits, list):
                mask = self.out_mask_logits[0]
            else:
                mask = self.out_mask_logits

            # Get binary mask and handle dimensions
            mask_np = (mask > 0.0).cpu().numpy()
            mask_np = self._handle_mask_dimensions(mask_np)

            if mask_np.shape[:2] == frame.shape[:2]:
                # Create a red overlay for the mask
                overlay = frame.copy()
                overlay[mask_np > 0] = [255, 0, 0]  # red overlay (frames are RGB)

                # Add a border around the mask for better visibility
                contours, _ = cv2.findContours(
                    mask_np.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
                )

                # Draw a thick contour for better visibility
                cv2.drawContours(
                    overlay, contours, -1, (255, 255, 0), 3
                )  # thick yellow border (RGB)

                # Add a second, thin contour on the base frame for emphasis
                cv2.drawContours(
                    frame, contours, -1, (0, 255, 255), 1
                )  # thin cyan border (RGB)

                # Blend the overlay with original frame
                frame = cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0)

        return frame


def create_interface():
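    """Build the Gradio Blocks UI and wire it to a VideoTracker instance."""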
    tracker = VideoTracker()

    with gr.Blocks() as interface:
        gr.Markdown("# Object Tracking with SAM2")
        gr.Markdown("Upload a video and click on objects to track them")

        with gr.Row():
            with gr.Column(scale=2):
                video_input = gr.Video(label="Input Video")
                image_output = gr.Image(label="Current Frame", interactive=True)
                frame_slider = gr.Slider(
                    minimum=0,
                    maximum=0,
                    step=1,
                    value=0,
                    label="Frame Selection",
                    interactive=True,
                )

            with gr.Column(scale=1):
                propagate_btn = gr.Button("Propagate Through Video", variant="primary")
                video_output = gr.Video(label="Output Video")

        video_input.change(
            fn=tracker.load_video,
            inputs=[video_input],
            outputs=[image_output, frame_slider],
        )

        frame_slider.change(
            fn=tracker.update_frame,
            inputs=[frame_slider],
            outputs=[image_output],
        )

        image_output.select(
            fn=tracker.add_point,
            inputs=[image_output],
            outputs=[image_output],
        )

        propagate_btn.click(
            fn=tracker.propagate_video,
            inputs=[],
            outputs=[video_output],
        )

    return interface


if __name__ == "__main__":
    interface = create_interface()
    interface.launch(share=True)