# InsectSpy / app.py
# Last commit ed2bad4 (verified) by ElodieA: "Update app.py"
import cv2
import tempfile
import gradio as gr
from ultralytics import YOLO
import pandas as pd
import plotly.graph_objects as go
import numpy as np
# Detector class index -> insect order name. The position of each name in this
# sequence is the class index it stands for, so it must follow the numbering
# the model weights were trained with.
label_mapping = dict(
    enumerate(
        (
            "Hymenoptera",
            "Mantodea",
            "Odonata",
            "Orthoptera",
            "Coleoptera",
            "Lepidoptera",
            "Hemiptera",
        )
    )
)
# Column order of the detections CSV offered for download.
_CSV_COLUMNS = ["frame", "insect_id", "class", "x", "y", "w", "h"]


def _class_name(class_id):
    """Return the label for a model class index.

    Falls back to the index itself (as text) when it is missing from
    ``label_mapping``, instead of raising KeyError halfway through a video.
    """
    class_id = int(class_id)
    return label_mapping.get(class_id, str(class_id))


def _temp_path(suffix):
    """Create an empty temporary file ending in *suffix* and return its path.

    The file is created atomically and kept on disk (``delete=False``) so the
    path can be handed to the Gradio outputs; this replaces the deprecated,
    race-prone ``tempfile.mktemp``.
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        return tmp.name


def _save_crop(frame, xyxy):
    """Save the ``(x1, y1, x2, y2)`` region of *frame* as a PNG; return its path.

    Coordinates are clamped to the frame. Returns None when the clamped region
    is empty or the image could not be written.
    """
    frame_h, frame_w = frame.shape[:2]
    x1, y1, x2, y2 = (int(v) for v in xyxy)
    # Clamp: a negative start index would wrap around in numpy slicing, and an
    # empty crop makes cv2.imwrite raise.
    x1, y1 = max(x1, 0), max(y1, 0)
    x2, y2 = min(x2, frame_w), min(y2, frame_h)
    if x2 <= x1 or y2 <= y1:
        return None
    crop_path = _temp_path(".png")
    if not cv2.imwrite(crop_path, frame[y1:y2, x1:x2]):
        return None
    return crop_path


def _track_insects(model, video_file):
    """Run the tracker over every frame of *video_file*.

    Returns ``(df, crops)``: a DataFrame with one row per detection (columns
    ``_CSV_COLUMNS``; x/y are the box centre, w/h its size) and a dict mapping
    each track ID to ``(crop_path, class_name)`` for its first saved crop.
    Detections without a tracker ID are recorded as insect_id -1 and share
    a single gallery entry.
    """
    rows = []
    crops = {}
    cap = cv2.VideoCapture(video_file)
    try:
        frame_id = 0
        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                break
            frame_id += 1
            # persist=True: treat this frame as the continuation of the previous
            # one so track IDs carry over between frames.
            results = model.track(frame, persist=True, tracker="insect_tracker.yaml")
            for result in results:
                for box in result.boxes.cpu().numpy():
                    class_id = int(box.cls[0])
                    insect_id = int(box.id[0]) if box.id is not None else -1
                    x, y, w, h = (float(v) for v in box.xywh[0])
                    rows.append(
                        {
                            "frame": frame_id,
                            "insect_id": insect_id,
                            "class": class_id,
                            "x": x,
                            "y": y,
                            "w": w,
                            "h": h,
                        }
                    )
                    if insect_id not in crops:
                        crop_path = _save_crop(frame, box.xyxy[0])
                        # Only record a crop that was actually written; otherwise
                        # a later frame of the same track gets another chance.
                        if crop_path is not None:
                            crops[insect_id] = (crop_path, _class_name(class_id))
    finally:
        cap.release()
    # Build the frame once instead of pd.concat per detection (quadratic).
    return pd.DataFrame(rows, columns=_CSV_COLUMNS), crops


def _build_timeline_figure(df):
    """Plot one marker trace per track ID (x = frame number, y = track ID)."""
    fig = go.Figure()
    for insect_id, group in df.groupby("insect_id"):
        class_name = _class_name(group["class"].iloc[0])
        color = "rgb({}, {}, {})".format(*np.random.randint(0, 256, 3))
        hover_text = [
            f"Insect ID: {int(insect_id)}, Class: {class_name}, Frame: {int(frame)}"
            for frame in group["frame"]
        ]
        fig.add_trace(
            go.Scatter(
                x=group["frame"],
                y=group["insect_id"],
                mode="markers",
                marker=dict(color=color),
                name=f"{class_name} {insect_id}",
                hoverinfo="text",
                hovertext=hover_text,
            )
        )
    fig.update_layout(
        title="Temporal distribution of insects",
        xaxis_title="Frame",
        yaxis_title="Insect ID",
        hovermode="closest",
    )
    return fig


def process_video(video_file):
    """Detect and track insects in a video and summarise the results.

    Args:
        video_file: Path of the video to analyse, as passed by the Gradio
            ``gr.Video`` input.

    Returns:
        ``(fig, gallery_items, csv_path)``:
        - ``fig``: plotly figure of detections over time, one trace per track ID;
        - ``gallery_items``: list of ``(png_path, "<class name> <track id>")``;
        - ``csv_path``: CSV with columns frame, insect_id, class, x, y, w, h.
    """
    # Deliberately loaded per call: with persist=True the tracker state lives
    # on the model instance, so sharing one model across requests would carry
    # tracks over from a previous video.
    model = YOLO("insect_detection4.pt")
    df, unique_insect_crops = _track_insects(model, video_file)

    csv_path = _temp_path(".csv")
    df.to_csv(csv_path, index=False)

    fig = _build_timeline_figure(df)
    gallery_items = [
        (crop_path, f"{label} {insect_id}")
        for insect_id, (crop_path, label) in unique_insect_crops.items()
    ]
    return fig, gallery_items, csv_path
# Build and launch the Gradio interface.
example_video = "insect_trap_video_example.mp4"  # Replace with the actual path to your example video
inputs = gr.Video(label="Input Insect Trap Video", value=example_video)
outputs = [
    gr.Plot(label="Insect Detection Plot"),
    gr.Gallery(label="Insect Gallery"),  # One crop per tracked insect, captioned "<class> <id>"
    gr.File(label="Download CSV"),
]
# The emoji below were previously stored mis-encoded (UTF-8 bytes decoded as
# a legacy code page), which showed up as garbage text in the UI.
description = """
Uncover the Secret Lives of Insects of the Amazonian Forest! 🐝🦋🕷️
Upload your video now to track, visualize, and explore insect activity with our cutting-edge detection tool. You can get started with the example video.
"""
demo = gr.Interface(
    fn=process_video,
    inputs=inputs,
    outputs=outputs,
    title="InsectSpy 🕵️‍♂️🦗",
    description=description,
    examples=[example_video],
)
demo.launch()