import argparse |
import time |
import cv2 |
import numpy as np |
import onnxruntime as ort |
from imagenet_classes import IMAGENET2012_CLASSES |
def parse_arguments():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``output_video`` and ``input_video`` (str or
        None) plus the boolean switches ``webcam`` and ``live``.
    """
    cli = argparse.ArgumentParser(description="Video inference with TensorRT")

    # Options that take a file path as their value.
    path_options = (
        ("--output_video", "Path to output video file"),
        ("--input_video", "Path to input video file"),
    )
    for flag, help_text in path_options:
        cli.add_argument(flag, type=str, help=help_text)

    # Boolean switches: present -> True, absent -> False.
    switch_options = (
        ("--webcam", "Use webcam as input"),
        ("--live", "View video live during inference"),
    )
    for flag, help_text in switch_options:
        cli.add_argument(flag, action="store_true", help=help_text)

    return cli.parse_args()
def get_ort_session(model_path, device_id=0):
    """Create an ONNX Runtime inference session that prefers TensorRT.

    Providers are listed in decreasing priority: TensorRT first, then CUDA for
    any nodes TensorRT does not take, then CPU as the last resort.

    Args:
        model_path: Path to the ``.onnx`` model file.
        device_id: GPU index used by both the TensorRT and CUDA providers.

    Returns:
        An ``onnxruntime.InferenceSession`` for ``model_path``.
    """
    trt_options = {
        "device_id": device_id,
        # 8 GiB builder workspace (value is in bytes).
        "trt_max_workspace_size": 8 * 1024**3,
        "trt_fp16_enable": True,
        # Persist built engines and timing data so later runs can reuse them.
        "trt_engine_cache_enable": True,
        "trt_engine_cache_path": "./trt_cache",
        "trt_force_sequential_engine_build": False,
        "trt_max_partition_iterations": 10000,
        "trt_min_subgraph_size": 1,
        "trt_builder_optimization_level": 5,
        "trt_timing_cache_enable": True,
    }
    providers = [
        ("TensorrtExecutionProvider", trt_options),
        # Without the CUDA EP, nodes TensorRT rejects fall straight back to CPU.
        ("CUDAExecutionProvider", {"device_id": device_id}),
        "CPUExecutionProvider",
    ]
    return ort.InferenceSession(model_path, providers=providers)
def preprocess_frame(frame, size=(448, 448)):
    """Convert a BGR video frame into a batched float32 NCHW model input.

    Args:
        frame: Image as returned by ``cv2.VideoCapture.read`` (BGR channel
            order, height x width x channels layout).
        size: Target ``(width, height)`` passed to ``cv2.resize``; defaults to
            448x448.

    Returns:
        ``float32`` array of shape ``(1, 3, height, width)`` in RGB order.
        Pixel values stay in the 0-255 range; no mean/std normalisation is
        applied here (presumably the ONNX model handles it -- TODO confirm).
    """
    resized = cv2.resize(frame, tuple(size), interpolation=cv2.INTER_LINEAR)
    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB).astype(np.float32)
    # HWC -> CHW; transpose() only returns a strided view, so materialise a
    # C-ordered buffer before handing it to the runtime.
    chw = np.ascontiguousarray(rgb.transpose(2, 0, 1))
    # Add the leading batch dimension.
    return chw[np.newaxis, ...]
def get_top_predictions(output, top_k=5, class_names=None):
    """Turn raw logits into the ``top_k`` most likely ``(label, percent)`` pairs.

    Only the first row (batch item 0) of ``output`` is ranked.

    Args:
        output: 2-D array of logits indexed as ``output[batch, class]``.
        top_k: Number of predictions to return; values <= 0 give an empty list.
        class_names: Sequence mapping class index to label. Defaults to the
            values of ``IMAGENET2012_CLASSES`` in insertion order.

    Returns:
        List of ``(label, probability_percent)`` tuples, most likely first.
    """
    logits = np.asarray(output)
    # Subtract the per-row max before exponentiating so softmax cannot overflow.
    exp_output = np.exp(logits - np.max(logits, axis=1, keepdims=True))
    probabilities = exp_output / np.sum(exp_output, axis=1, keepdims=True)

    top_k = max(int(top_k), 0)
    # Reverse first, then take the head: `[-top_k:]` would select every class
    # when top_k == 0 (because -0 == 0).
    top_indices = np.argsort(probabilities[0])[::-1][:top_k]
    top_probs = probabilities[0][top_indices] * 100

    if class_names is None:
        class_names = list(IMAGENET2012_CLASSES.values())
    labels = [class_names[i] for i in top_indices]
    return list(zip(labels, top_probs.tolist()))
def _draw_boxed_text(frame, text, org, text_size, bg_color, font_scale=0.7, thickness=2):
    """Draw white ``text`` with its baseline at ``org`` over a filled ``bg_color`` box."""
    x, y = org
    text_width, text_height = text_size
    # Pad the background 5 px beyond the text's bounding box on every side.
    cv2.rectangle(
        frame,
        (x - 5, y + 5),
        (x + text_width + 5, y - text_height - 5),
        bg_color,
        cv2.FILLED,
    )
    cv2.putText(
        frame,
        text,
        (x, y),
        cv2.FONT_HERSHEY_SIMPLEX,
        font_scale,
        (255, 255, 255),
        thickness,
    )


def draw_predictions(frame, predictions, fps, model_name="eva02_large_patch14_448"):
    """Annotate ``frame`` in place with predictions, an FPS counter and the model name.

    Args:
        frame: BGR image to draw on; it is modified in place.
        predictions: Iterable of ``(label, probability_percent)`` pairs.
        fps: Frames-per-second value shown in the top-right corner.
        model_name: Name shown in the bottom-centre label.

    Returns:
        The same ``frame`` object, after drawing.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale, thickness = 0.7, 2

    # FPS counter: right-aligned 10 px from the right edge, on a dark-blue (BGR) box.
    fps_text = f"FPS: {fps:.2f}"
    fps_size, _ = cv2.getTextSize(fps_text, font, font_scale, thickness)
    _draw_boxed_text(
        frame,
        fps_text,
        (frame.shape[1] - fps_size[0] - 10, 30),
        fps_size,
        (139, 0, 0),
        font_scale,
        thickness,
    )

    # One green line per prediction, stacked 30 px apart from the top-left corner.
    # cv2.putText needs fontFace before fontScale; omitting it makes OpenCV
    # reject the call.
    for i, (name, prob) in enumerate(predictions):
        cv2.putText(
            frame,
            f"{name}: {prob:.2f}%",
            (10, 30 + i * 30),
            font,
            font_scale,
            (0, 255, 0),
            thickness,
        )

    # Model label: horizontally centred, 20 px above the bottom edge, on a red (BGR) box.
    label = f"Model: {model_name}"
    label_size, _ = cv2.getTextSize(label, font, font_scale, thickness)
    _draw_boxed_text(
        frame,
        label,
        ((frame.shape[1] - label_size[0]) // 2, frame.shape[0] - 20),
        label_size,
        (0, 0, 255),
        font_scale,
        thickness,
    )
    return frame
def process_video(input_path, output_path, session, live_view=False, use_webcam=False):
    """Classify every frame of a video source, annotate it, and report timing.

    Frames come from camera index 0 when ``use_webcam`` is true, otherwise from
    ``input_path``. Each frame is preprocessed, run through ``session``,
    annotated with the top predictions and per-frame FPS, then optionally
    written to ``output_path`` and/or shown live (press ``q`` to stop).

    Args:
        input_path: Input video path; ignored when ``use_webcam`` is true.
        output_path: Output video path, or a falsy value to skip writing.
        session: ONNX Runtime session. Its first input receives the
            preprocessed frame and its first output is treated as class logits.
        live_view: Show annotated frames in an OpenCV window.
        use_webcam: Capture from camera index 0 instead of ``input_path``.
    """
    source = 0 if use_webcam else input_path
    cap = cv2.VideoCapture(source)
    if not cap.isOpened():
        print(f"Error: could not open video source {source!r}")
        cap.release()
        return

    out = None
    frame_count = 0
    total_time = 0.0
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep fractional rates such as 29.97 (int() would make the output
        # drift). Webcams often report 0, which VideoWriter cannot use, so fall
        # back to 30 FPS. `not fps > 0` also catches NaN.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30.0

        if output_path:
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
            if not out.isOpened():
                print(f"Warning: could not open output video {output_path!r}; frames will not be saved")

        input_name = session.get_inputs()[0].name
        output_name = session.get_outputs()[0].name

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Time preprocessing + inference + postprocessing only (not drawing or I/O).
            start_time = time.perf_counter()
            preprocessed = preprocess_frame(frame)
            output = session.run([output_name], {input_name: preprocessed})
            predictions = get_top_predictions(output[0])
            frame_time = time.perf_counter() - start_time
            current_fps = 1 / frame_time if frame_time > 0 else float("inf")

            frame_with_predictions = draw_predictions(frame, predictions, current_fps)
            if out is not None:
                out.write(frame_with_predictions)
            if live_view:
                cv2.imshow("Inference", frame_with_predictions)
                if cv2.waitKey(1) & 0xFF == ord("q"):
                    break

            total_time += frame_time
            frame_count += 1
            print(
                f"Processed frame {frame_count}, Time: {frame_time:.3f}s, FPS: {current_fps:.2f}"
            )
    finally:
        # Release capture/writer even on exceptions or Ctrl+C (the only way to
        # stop a webcam run without a live window).
        cap.release()
        if out is not None:
            out.release()
        # Only touch HighGUI when a window was used; headless OpenCV builds
        # raise on destroyAllWindows().
        if live_view:
            cv2.destroyAllWindows()

    if frame_count == 0:
        print("No frames were processed.")
        return
    avg_time = total_time / frame_count
    print(f"Average processing time per frame: {avg_time:.3f}s")
    avg_fps = 1 / avg_time if avg_time > 0 else float("inf")
    print(f"Average FPS: {avg_fps:.2f}")
def main():
    """Parse CLI arguments, build the inference session, and process the video.

    Returns:
        0 on success, 1 when neither ``--input_video`` nor ``--webcam`` is given.
    """
    import sys

    args = parse_arguments()
    # Validate the input source before creating the session: with the
    # TensorRT provider, session creation may trigger a slow engine build.
    if not (args.webcam or args.input_video):
        print("Error: Please specify either --input_video or --webcam", file=sys.stderr)
        return 1

    session = get_ort_session("merged_model_compose.onnx")
    if args.webcam:
        process_video(None, args.output_video, session, args.live, use_webcam=True)
    else:
        process_video(args.input_video, args.output_video, session, args.live)
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())