InsectSpy / app.py
ElodieA's picture
Update app.py
4dff9c4 verified
raw
history blame
4.67 kB
import os
import tempfile

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from ultralytics import YOLO
# Map each class index predicted by the detection model to the name of the
# insect order it stands for. The list position is the class index.
label_mapping = dict(
    enumerate(
        [
            'Hymenoptera',
            'Mantodea',
            'Odonata',
            'Orthoptera',
            'Coleoptera',
            'Lepidoptera',
            'Hemiptera',
        ]
    )
)
def _temp_path(suffix):
    """Create an empty, uniquely named temporary file and return its path."""
    # Replaces tempfile.mktemp, which only picks a name and is race-prone.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        return handle.name


def _class_name(class_id):
    """Return the order name for ``class_id``, or ``"class <id>"`` if unmapped."""
    class_id = int(class_id)
    return label_mapping.get(class_id, f"class {class_id}")


def _save_crop(frame, x_center, y_center, width, height):
    """Save the box region of ``frame`` to a temporary PNG.

    Returns the PNG path, or ``None`` when the box clipped to the frame is
    empty (nothing is written in that case).
    """
    frame_height, frame_width = frame.shape[:2]
    # Clamp to the frame. A negative start index would wrap around in numpy
    # slicing and give an empty or wrong crop, which cv2.imwrite rejects.
    x1 = max(0, int(x_center - width / 2))
    y1 = max(0, int(y_center - height / 2))
    x2 = min(frame_width, int(x_center + width / 2))
    y2 = min(frame_height, int(y_center + height / 2))
    if x2 <= x1 or y2 <= y1:
        return None
    crop_path = _temp_path(".png")
    cv2.imwrite(crop_path, frame[y1:y2, x1:x2])
    return crop_path


def _plot_timeline(df):
    """Plot the frames in which each track ID was detected; return the PNG path."""
    fig, ax = plt.subplots(figsize=(10, 6))
    # sort=False keeps track IDs in order of first appearance.
    for insect_id, track in df.groupby("insect_id", sort=False):
        insect_class = _class_name(track["class"].iloc[0])
        ax.plot(
            track["frame"],
            [insect_id] * len(track),
            "o-",
            label=f"{insect_class} {insect_id}",
        )
    ax.set_xlabel("Frame")
    ax.set_ylabel("Insect ID")
    ax.set_title("Temporal distribution of insects")
    if not df.empty:
        # Calling legend() with no labelled artists only emits a warning.
        ax.legend()
    ax.grid(True)
    plot_path = _temp_path(".png")
    fig.savefig(plot_path)
    plt.close(fig)
    return plot_path


def process_video(video_file):
    """Detect and track insects in a video and summarise the detections.

    Every frame is run through the YOLO model with tracking enabled. Each
    tracked box becomes one row of a detection table. The first non-empty
    crop of every track ID is saved as a PNG for the gallery.

    Args:
        video_file: Path of the video to process, as handed over by the
            Gradio ``Video`` input.

    Returns:
        A tuple ``(plot_path, gallery_items, csv_path)``:

        - ``plot_path``: PNG plotting, per track ID, the frames it was seen in.
        - ``gallery_items``: list of ``(crop_path, caption)`` pairs, one per
          track ID, with caption ``"<class name> <track id>"``.
        - ``csv_path``: CSV with columns ``frame, insect_id, class, x, y, w,
          h`` (box centre and size as reported by ``xywh``), one row per
          tracked box.
    """
    columns = ["frame", "insect_id", "class", "x", "y", "w", "h"]
    # Load a fresh model on every call. With persist=True the tracker state is
    # kept between track() calls, so a shared model would presumably carry
    # track IDs over from the previous video.
    model = YOLO("insect_detection4.pt")
    cap = cv2.VideoCapture(video_file)
    rows = []  # built into a DataFrame once; per-row pd.concat is O(n^2)
    unique_insect_crops = {}
    frame_id = 0
    try:
        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                break
            frame_id += 1
            # Run YOLO tracking, persisting tracks between frames.
            results = model.track(frame, persist=True, tracker="insect_tracker.yaml")
            for result in results:
                boxes = result.boxes.cpu().numpy()
                # Boxes without an assigned track have no ID. Skip them
                # instead of crashing on int(None[0]).
                if boxes.id is None:
                    continue
                for track_id, cls, xywh in zip(boxes.id, boxes.cls, boxes.xywh):
                    insect_id = int(track_id)
                    class_id = int(cls)
                    x_center, y_center, width, height = (float(v) for v in xywh)
                    rows.append({
                        "frame": frame_id,
                        "insect_id": insect_id,
                        "class": class_id,
                        "x": x_center,
                        "y": y_center,
                        "w": width,
                        "h": height,
                    })
                    if insect_id not in unique_insect_crops:
                        crop_path = _save_crop(frame, x_center, y_center, width, height)
                        # An empty crop is not recorded, so a later frame of
                        # the same track gets another chance.
                        if crop_path is not None:
                            unique_insect_crops[insect_id] = (crop_path, _class_name(class_id))
    finally:
        cap.release()

    df = pd.DataFrame(rows, columns=columns)
    csv_path = _temp_path(".csv")
    df.to_csv(csv_path, index=False)

    plot_path = _plot_timeline(df)
    gallery_items = [
        (crop_path, f"{label} {insect_id}")
        for insect_id, (crop_path, label) in unique_insect_crops.items()
    ]
    return plot_path, gallery_items, csv_path
# Create a Gradio interface.
# Replace with the actual path to your example video.
example_video = "path_to_example_video.mp4"
# Only wire up the example when the file actually exists. Gradio processes the
# input's default value and the examples while building the interface, so a
# placeholder path that does not exist may break startup or show a broken
# example.
has_example = os.path.isfile(example_video)
inputs = gr.Video(
    label="Input Insect Trap Video",
    value=example_video if has_example else None,
)
outputs = [
    gr.Image(label="Insect Detection Plot"),
    gr.Gallery(label="Unique Insect Crops"),  # insect crops captioned with class and track ID
    gr.File(label="Download CSV"),
]
demo = gr.Interface(
    fn=process_video,
    inputs=inputs,
    outputs=outputs,
    examples=[example_video] if has_example else None,
)
demo.launch()