File size: 4,669 Bytes
c786020
 
6a0ef93
c786020
0f6c01e
d4d3a7d
c786020
4dff9c4
 
 
 
 
 
 
 
 
 
 
c786020
6a0ef93
 
 
 
 
 
 
 
 
 
 
 
 
9932847
6a0ef93
c786020
c48dbc5
c786020
0f6c01e
4dff9c4
0f6c01e
6a0ef93
0f6c01e
d4d3a7d
6a0ef93
0f6c01e
 
 
 
6a0ef93
0f6c01e
 
6a0ef93
0f6c01e
 
c786020
0f6c01e
 
 
 
c786020
0f6c01e
 
 
4dff9c4
c786020
0f6c01e
 
 
4dff9c4
 
0f6c01e
 
 
 
 
 
6a0ef93
d4d3a7d
4dff9c4
d4d3a7d
 
 
 
 
 
 
 
4dff9c4
d4d3a7d
0f6c01e
 
6a0ef93
 
c786020
6a0ef93
0f6c01e
 
 
 
d4d3a7d
 
 
 
 
4dff9c4
 
 
 
d4d3a7d
 
 
 
 
 
 
 
 
 
 
 
4dff9c4
d4d3a7d
 
6a0ef93
 
4dff9c4
 
 
0f6c01e
d4d3a7d
 
0f6c01e
 
c786020
4dff9c4
20d4bc8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import tempfile

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from ultralytics import YOLO

# Insect order names, indexed by the integer class ID the detection model
# predicts (position in this tuple == class ID).
_ORDER_NAMES = (
    'Hymenoptera',
    'Mantodea',
    'Odonata',
    'Orthoptera',
    'Coleoptera',
    'Lepidoptera',
    'Hemiptera',
)

# Class ID -> insect order name.
label_mapping = dict(enumerate(_ORDER_NAMES))

def _temp_path(suffix):
    """Create an empty temporary file that survives closing and return its path.

    Replaces ``tempfile.mktemp``, which only invents a name and is deprecated
    because another process could create that file before we do.
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        return tmp.name


def _class_name(class_id):
    """Return the insect order for a class index, or ``"class <id>"`` if unmapped."""
    class_id = int(class_id)
    return label_mapping.get(class_id, f"class {class_id}")


def _save_crop(frame, xyxy):
    """Write the ``(x1, y1, x2, y2)`` region of ``frame`` to a temporary PNG.

    The box is clamped to the frame first. Returns the PNG path, or ``None``
    when the clamped region is empty (nothing is written in that case).
    """
    frame_h, frame_w = frame.shape[:2]
    x1, y1, x2, y2 = (int(v) for v in xyxy)
    # Boxes may extend past the frame edge; a negative slice start would wrap
    # around to the opposite side and produce an empty or wrong crop.
    x1, y1 = max(0, x1), max(0, y1)
    x2, y2 = min(frame_w, x2), min(frame_h, y2)
    crop = frame[y1:y2, x1:x2]
    if crop.size == 0:
        return None
    crop_path = _temp_path(".png")
    cv2.imwrite(crop_path, crop)
    return crop_path


def _track_video(video_file):
    """Run detection + tracking over every frame of ``video_file``.

    Returns:
        ``(records, crops)``: ``records`` is a list of per-detection dicts
        keyed by the CSV column names (frames numbered from 1); ``crops`` maps
        each track ID to ``(crop_path, class_name)`` for that track's first
        usable crop.

    Raises:
        gr.Error: If the video cannot be opened.
    """
    cap = cv2.VideoCapture(video_file)
    records = []
    crops = {}
    try:
        if not cap.isOpened():
            raise gr.Error(f"Could not open video: {video_file}")

        # NOTE(review): a fresh model is loaded per call on purpose. With
        # persist=True the tracker state presumably lives on the model
        # instance, so sharing one across videos/requests would carry track
        # IDs over — confirm before caching the model globally.
        model = YOLO("insect_detection4.pt")

        frame_id = 0
        while True:
            success, frame = cap.read()
            if not success:
                break
            frame_id += 1

            # Track on this frame, persisting tracks between frames.
            results = model.track(frame, persist=True, tracker="insect_tracker.yaml")
            for result in results:
                # ultralytics reports boxes.id as None when no track IDs were
                # assigned for this frame; indexing it would raise TypeError,
                # so skip such untracked detections.
                if result.boxes is None or result.boxes.id is None:
                    continue
                boxes = result.boxes.cpu().numpy()
                for cls_value, id_value, xywh, xyxy in zip(
                    boxes.cls, boxes.id, boxes.xywh, boxes.xyxy
                ):
                    class_id = int(cls_value)
                    insect_id = int(id_value)
                    x_center, y_center, width, height = (float(v) for v in xywh)
                    records.append({
                        "frame": frame_id,
                        "insect_id": insect_id,
                        "class": class_id,
                        "x": x_center,
                        "y": y_center,
                        "w": width,
                        "h": height,
                    })

                    # Keep only the first usable crop of each track.
                    if insect_id not in crops:
                        crop_path = _save_crop(frame, xyxy)
                        if crop_path is not None:
                            crops[insect_id] = (crop_path, _class_name(class_id))
    finally:
        # Release the capture even if tracking raises mid-video.
        cap.release()

    return records, crops


def _plot_tracks(df):
    """Plot track ID against the frames it appears in; return the PNG path."""
    # Use an explicit figure (not the implicit pyplot "current figure") so
    # concurrent requests do not draw into each other's plots.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        # sort=False keeps tracks in order of first appearance.
        for insect_id, track in df.groupby("insect_id", sort=False):
            insect_class = _class_name(track["class"].iloc[0])
            ax.plot(
                track["frame"],
                [insect_id] * len(track),
                'o-',
                label=f'{insect_class} {insect_id}',
            )
        ax.set_xlabel('Frame')
        ax.set_ylabel('Insect ID')
        ax.set_title('Temporal distribution of insects')
        if not df.empty:
            # An empty legend only emits a matplotlib warning.
            ax.legend()
        ax.grid(True)
        plot_path = _temp_path(".png")
        fig.savefig(plot_path)
    finally:
        plt.close(fig)
    return plot_path


def process_video(video_file):
    """Detect and track insects in a video and summarise the tracks.

    Runs the YOLO weights ``insect_detection4.pt`` with the tracker config
    ``insect_tracker.yaml`` on every frame. It records one row per tracked
    detection, saves the first crop of each track, and plots which frames
    each track was seen in.

    Args:
        video_file: Path to the input video (as supplied by ``gr.Video``).

    Returns:
        ``(plot_path, gallery_items, csv_path)``:
        ``plot_path`` is a PNG of track ID versus frame number.
        ``gallery_items`` is a list of ``(crop_path, "<order> <track id>")``
        pairs, one per track.
        ``csv_path`` is a CSV with columns ``frame, insect_id, class, x, y,
        w, h``, where x/y are the box centre and w/h its size, in pixels.

    Raises:
        gr.Error: If no video was given or it cannot be opened.
    """
    if not video_file:
        raise gr.Error("Please provide an input video.")

    records, unique_insect_crops = _track_video(video_file)

    # Build the table once; appending with pd.concat per detection is O(n^2).
    columns = ["frame", "insect_id", "class", "x", "y", "w", "h"]
    df = pd.DataFrame(records, columns=columns)

    csv_path = _temp_path(".csv")
    df.to_csv(csv_path, index=False)

    plot_path = _plot_tracks(df)

    gallery_items = [
        (crop_path, f'{label} {insect_id}')
        for insect_id, (crop_path, label) in unique_insect_crops.items()
    ]

    return plot_path, gallery_items, csv_path

# Build the Gradio interface.
# NOTE: replace this placeholder with the path to a real example video. The
# example is only wired into the UI when the file actually exists, so the app
# still starts while the placeholder is in place.
example_video = "path_to_example_video.mp4"
_has_example = os.path.isfile(example_video)

inputs = gr.Video(
    label="Input Insect Trap Video",
    value=example_video if _has_example else None,
)
outputs = [
    gr.Image(label="Insect Detection Plot"),
    # One entry per tracked insect: its crop, captioned "<order> <track id>".
    gr.Gallery(label="Unique Insect Crops"),
    gr.File(label="Download CSV"),
]

demo = gr.Interface(
    fn=process_video,
    inputs=inputs,
    outputs=outputs,
    examples=[example_video] if _has_example else None,
)

# Launch only when run as a script, so importing this module (e.g. by
# `gradio app.py` reload mode or tests) does not start a blocking server.
if __name__ == "__main__":
    demo.launch()