# Fall-Detection / app.py
# Last commit: 0db98b9 "Update app.py" (itsTomLie, verified)
import os
import numpy as np
import gradio as gr
import supervision as sv
from ultralytics import YOLO
# --- Configuration ------------------------------------------------------------
# Directory the app was launched from.
HOME = os.getcwd()
# Fine-tuned YOLO weights, resolved relative to the working directory.
MODEL_PATH = "./best.pt"
# Minimum detection score (exclusive) a box needs to be drawn.
CONFIDENCE_THRESHOLD = 0.6

# --- Shared objects, created once at import time ------------------------------
model = YOLO(MODEL_PATH)
box_annotator = sv.BoxAnnotator()
label_annotator = sv.LabelAnnotator()
# Per-frame callback for sv.process_video: detect, filter, then annotate.
def callback(frame: np.ndarray, _: int) -> np.ndarray:
    """Run the fall-detection model on one frame and draw the results.

    Args:
        frame: One decoded video frame handed over by ``sv.process_video``
            (presumably an OpenCV BGR image -- confirm against supervision).
        _: Frame index supplied by ``sv.process_video``; unused here.

    Returns:
        A copy of ``frame`` with a bounding box and a "<class> <confidence>"
        label for every detection whose confidence is strictly greater than
        ``CONFIDENCE_THRESHOLD``. The input frame itself is not modified.
    """
    # verbose=False stops Ultralytics from printing a log line for every frame.
    results = model(frame, verbose=False)[0]
    detections = sv.Detections.from_ultralytics(results)

    # Keep only confident detections (strict ">" comparison).
    confident = detections[detections.confidence > CONFIDENCE_THRESHOLD]

    # `model.names` is the public class-id -> class-name mapping.
    labels = [
        f"{model.names[class_id]} {confidence:.2f}"
        for class_id, confidence in zip(confident.class_id, confident.confidence)
    ]

    # Draw on a copy so the caller's frame buffer is left untouched.
    annotated_frame = box_annotator.annotate(
        scene=frame.copy(),
        detections=confident,
    )
    return label_annotator.annotate(
        scene=annotated_frame,
        detections=confident,
        labels=labels,
    )
# Gradio handler: annotate the uploaded video and return the result's path.
def process_video_gradio(input_video):
    """Annotate an uploaded video with fall detections.

    Args:
        input_video: Filesystem path of the uploaded video as delivered by
            ``gr.Video()``; falsy (``None``) when the user submits without
            uploading anything.

    Returns:
        Path of the annotated video, written as
        ``output_fall_detection.mp4`` inside ``HOME``.

    Raises:
        gr.Error: If no video was provided, so the UI shows a readable
            message instead of a low-level video-decoding failure.
    """
    if not input_video:
        raise gr.Error("Please upload a video before submitting.")

    # NOTE(review): the output path is fixed, so if this handler is ever
    # allowed to run concurrently, requests would overwrite each other's
    # result. Consider a per-request unique filename in that case.
    target_video_path = os.path.join(HOME, "output_fall_detection.mp4")
    sv.process_video(
        source_path=input_video,
        target_path=target_video_path,
        callback=callback,
    )
    return target_video_path
# --- Gradio UI ----------------------------------------------------------------
# A single uploaded video goes in; the annotated video produced by
# process_video_gradio comes back out.
_video_in = gr.Video()
_video_out = gr.Video()
interface = gr.Interface(
    fn=process_video_gradio,
    inputs=_video_in,
    outputs=_video_out,
    title="Fall Detection Video Annotator",
    description="Upload a video, and the model will annotate it with fall detection using Fine-Tuned YOLO model.",
)

if __name__ == "__main__":
    # Start the web UI only when run as a script, not when imported.
    interface.launch()