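"""Video classification with ONNX Runtime's TensorRT execution provider.

Runs an ImageNet classifier over a video file or webcam stream, draws the
top-5 predictions plus FPS on each frame, and optionally writes the result
to disk. Example invocations (the script name `infer_video.py` is a
placeholder; use whatever this file is saved as):

    python infer_video.py --input_video in.mp4 --output_video out.mp4
    python infer_video.py --webcam --live
"""
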
import argparse
import time

import cv2
import numpy as np
import onnxruntime as ort

from imagenet_classes import IMAGENET2012_CLASSES


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Video inference with ONNX Runtime's TensorRT execution provider"
    )
    parser.add_argument("--input_video", type=str, help="Path to input video file")
    parser.add_argument("--output_video", type=str, help="Path to output video file")
    parser.add_argument("--webcam", action="store_true", help="Use webcam as input")
    parser.add_argument(
        "--live", action="store_true", help="View video live during inference"
    )
    return parser.parse_args()


def get_ort_session(model_path):
    providers = [
        (
            "TensorrtExecutionProvider",
            {
                "device_id": 0,
                # 8 GiB workspace for the TensorRT builder.
                "trt_max_workspace_size": 8589934592,
                # Build FP16 engines for higher throughput.
                "trt_fp16_enable": True,
                # Cache built engines on disk so later runs skip the
                # (potentially long) engine build.
                "trt_engine_cache_enable": True,
                "trt_engine_cache_path": "./trt_cache",
                "trt_force_sequential_engine_build": False,
                "trt_max_partition_iterations": 10000,
                "trt_min_subgraph_size": 1,
                # Highest builder optimization level (0-5).
                "trt_builder_optimization_level": 5,
                "trt_timing_cache_enable": True,
            },
        ),
    ]
    return ort.InferenceSession(model_path, providers=providers)
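
# Only the TensorRT EP is requested above, so if TensorRT is unavailable the
# session may fail to build rather than fall back (behavior depends on the
# ONNX Runtime version). A common pattern, sketched here and not part of the
# original setup, is to append fallbacks:
#
#     providers += ["CUDAExecutionProvider", "CPUExecutionProvider"]
#
# ONNX Runtime then assigns any nodes the TensorRT EP cannot handle to the
# next provider in the list. Also note the first run pays the TensorRT engine
# build cost; later runs reuse ./trt_cache.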


def preprocess_frame(frame):
    # Resize to the model's expected 448x448 input, convert BGR -> RGB,
    # move channels first (HWC -> CHW), and add a batch dimension.
    resized = cv2.resize(frame, (448, 448), interpolation=cv2.INTER_LINEAR)
    img_numpy = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB).astype(np.float32)
    img_numpy = img_numpy.transpose(2, 0, 1)
    img_numpy = np.expand_dims(img_numpy, axis=0)
    return img_numpy
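
# No mean/std normalization happens here. The model file name
# ("merged_model_compose.onnx") suggests preprocessing is composed into the
# ONNX graph itself; if your export does not include it, apply the model's
# ImageNet normalization before inference (an assumption, not something this
# script verifies).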


def get_top_predictions(output, top_k=5):
    # Softmax over the logits; subtracting the row max first keeps exp() from
    # overflowing (numerically stable softmax).
    exp_output = np.exp(output - np.max(output, axis=1, keepdims=True))
    probabilities = exp_output / np.sum(exp_output, axis=1, keepdims=True)

    # Indices of the top-k classes, highest probability first.
    top_indices = np.argsort(probabilities[0])[-top_k:][::-1]
    top_probs = probabilities[0][top_indices] * 100

    im_classes = list(IMAGENET2012_CLASSES.values())
    class_names = [im_classes[i] for i in top_indices]

    return list(zip(class_names, top_probs.tolist()))
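
# For a (1, 1000) logit array this returns pairs like
# [("tabby", 42.31), ("tiger cat", 17.05), ...]: class name plus probability
# as a percentage, highest first (the values shown are illustrative).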


def draw_predictions(frame, predictions, fps):
    # FPS counter: white text on a dark blue (BGR) filled box, top-right.
    fps_text = f"FPS: {fps:.2f}"
    (text_width, text_height), _ = cv2.getTextSize(
        fps_text, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2
    )
    text_offset_x = frame.shape[1] - text_width - 10
    text_offset_y = 30
    box_coords = (
        (text_offset_x - 5, text_offset_y + 5),
        (text_offset_x + text_width + 5, text_offset_y - text_height - 5),
    )
    cv2.rectangle(frame, box_coords[0], box_coords[1], (139, 0, 0), cv2.FILLED)
    cv2.putText(
        frame,
        fps_text,
        (text_offset_x, text_offset_y),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        (255, 255, 255),
        2,
    )

    # Top-k predictions in green, stacked down the left edge.
    for i, (name, prob) in enumerate(predictions):
        text = f"{name}: {prob:.2f}%"
        cv2.putText(
            frame,
            text,
            (10, 30 + i * 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.7,
            (0, 255, 0),
            2,
        )

    # Model banner: white text on a red (BGR) filled box, bottom-center.
    model_name = "Model: eva02_large_patch14_448"
    (text_width, text_height), _ = cv2.getTextSize(
        model_name, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2
    )
    text_x = (frame.shape[1] - text_width) // 2
    text_y = frame.shape[0] - 20
    box_coords = (
        (text_x - 5, text_y + 5),
        (text_x + text_width + 5, text_y - text_height - 5),
    )
    cv2.rectangle(frame, box_coords[0], box_coords[1], (0, 0, 255), cv2.FILLED)
    cv2.putText(
        frame,
        model_name,
        (text_x, text_y),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        (255, 255, 255),
        2,
    )

    return frame


def process_video(input_path, output_path, session, live_view=False, use_webcam=False):
    cap = cv2.VideoCapture(0) if use_webcam else cv2.VideoCapture(input_path)
    if not cap.isOpened():
        print("Error: could not open video source")
        return

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Webcams often report 0 FPS; fall back to 30 so VideoWriter gets a sane
    # rate (an assumed default, adjust to taste).
    fps = cap.get(cv2.CAP_PROP_FPS)
    fps = int(fps) if fps > 0 else 30

    out = None
    if output_path:
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    frame_count = 0
    total_time = 0

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        start_time = time.time()

        preprocessed = preprocess_frame(frame)
        output = session.run([output_name], {input_name: preprocessed})
        predictions = get_top_predictions(output[0])

        end_time = time.time()
        frame_time = end_time - start_time
        # Guard against a zero reading from the timer's resolution.
        current_fps = 1 / max(frame_time, 1e-6)

        frame_with_predictions = draw_predictions(frame, predictions, current_fps)

        if out:
            out.write(frame_with_predictions)

        if live_view:
            cv2.imshow("Inference", frame_with_predictions)
            if cv2.waitKey(1) & 0xFF == ord("q"):
                break

        total_time += frame_time
        frame_count += 1

        print(
            f"Processed frame {frame_count}, Time: {frame_time:.3f}s, FPS: {current_fps:.2f}"
        )

    cap.release()
    if out:
        out.release()
    cv2.destroyAllWindows()

    # Only report averages if at least one frame was processed.
    if frame_count:
        avg_time = total_time / frame_count
        print(f"Average processing time per frame: {avg_time:.3f}s")
        print(f"Average FPS: {1 / avg_time:.2f}")


def main():
    args = parse_arguments()
    session = get_ort_session("merged_model_compose.onnx")

    if args.webcam:
        process_video(None, args.output_video, session, args.live, use_webcam=True)
    elif args.input_video:
        process_video(args.input_video, args.output_video, session, args.live)
    else:
        print("Error: Please specify either --input_video or --webcam")
        return


if __name__ == "__main__":
    main()