# eva02_large_patch14_448 / 09_video_inference.py
# Uploaded by dnth ("Upload 11 files", commit 88359db, verified); file size 6.32 kB.
import argparse
import time
import cv2
import numpy as np
import onnxruntime as ort
from imagenet_classes import IMAGENET2012_CLASSES
def parse_arguments():
    """Parse the command-line options of the video inference script.

    Returns:
        argparse.Namespace with ``output_video`` and ``input_video`` (strings
        or ``None``) and the boolean switches ``webcam`` and ``live``.
    """
    cli = argparse.ArgumentParser(description="Video inference with TensorRT")

    # Options that take a file path.
    path_options = (
        ("--output_video", "Path to output video file"),
        ("--input_video", "Path to input video file"),
    )
    for flag, description in path_options:
        cli.add_argument(flag, type=str, help=description)

    # Boolean switches, off unless given on the command line.
    switches = (
        ("--webcam", "Use webcam as input"),
        ("--live", "View video live during inference"),
    )
    for flag, description in switches:
        cli.add_argument(flag, action="store_true", help=description)

    return cli.parse_args()
def get_ort_session(model_path):
    """Create an ONNX Runtime session that requests the TensorRT provider.

    Args:
        model_path: Path of the ONNX model handed to ``ort.InferenceSession``.

    Returns:
        The ``ort.InferenceSession`` built with the provider options below.
    """
    tensorrt_options = {
        "device_id": 0,
        # 8 * 1024**3 == 8589934592 (8 GiB if the option is read as bytes).
        "trt_max_workspace_size": 8 * 1024**3,
        "trt_fp16_enable": True,
        # Engine cache is enabled and pointed at a local directory.
        "trt_engine_cache_enable": True,
        "trt_engine_cache_path": "./trt_cache",
        "trt_force_sequential_engine_build": False,
        "trt_max_partition_iterations": 10000,
        "trt_min_subgraph_size": 1,
        "trt_builder_optimization_level": 5,
        "trt_timing_cache_enable": True,
    }
    # Only the TensorRT provider is listed explicitly.
    requested_providers = [("TensorrtExecutionProvider", tensorrt_options)]
    return ort.InferenceSession(model_path, providers=requested_providers)
def preprocess_frame(frame):
    """Turn one BGR video frame into the float32 tensor fed to the model.

    The frame is resized to 448x448 with bilinear interpolation, converted
    from BGR to RGB, cast to float32, rearranged with axis order (2, 0, 1)
    and given a leading axis of length 1.

    No value scaling or normalization is applied here.
    NOTE(review): presumably the ONNX graph does that itself -- confirm.

    Args:
        frame: Image array as returned by ``cv2.VideoCapture.read``.

    Returns:
        float32 array of shape (1, 3, 448, 448).
    """
    side = 448
    # cv2 is used for resizing instead of PIL for better performance.
    scaled = cv2.resize(frame, (side, side), interpolation=cv2.INTER_LINEAR)
    rgb = cv2.cvtColor(scaled, cv2.COLOR_BGR2RGB)
    channels_first = np.transpose(rgb.astype(np.float32), (2, 0, 1))
    return channels_first[np.newaxis, ...]
def get_top_predictions(output, top_k=5):
    """Return the ``top_k`` most probable classes for the first output row.

    Args:
        output: 2-D array of raw model scores. Softmax is taken along
            axis 1 and only row 0 is reported.
        top_k: Maximum number of predictions to return. Values that are
            zero or negative yield an empty list.

    Returns:
        List of ``(class_name, probability_percent)`` tuples ordered from
        most to least probable, with probabilities as Python floats in
        percent.
    """
    # Guard: ``[-0:]`` would select the whole array (every class) and a
    # negative ``top_k`` would produce a meaningless slice.
    if top_k <= 0:
        return []

    # Apply softmax; subtracting the row maximum keeps np.exp from overflowing.
    exp_output = np.exp(output - np.max(output, axis=1, keepdims=True))
    probabilities = exp_output / np.sum(exp_output, axis=1, keepdims=True)

    # argsort is ascending: keep the last ``top_k`` entries, then reverse
    # them so the best prediction comes first.
    top_indices = np.argsort(probabilities[0])[-top_k:][::-1]
    top_probs = probabilities[0][top_indices] * 100

    # Class names are looked up by position in the mapping's value order.
    im_classes = list(IMAGENET2012_CLASSES.values())
    class_names = [im_classes[i] for i in top_indices]
    return list(zip(class_names, top_probs.tolist()))
def draw_predictions(frame, predictions, fps):
    """Overlay the FPS counter, the predictions and the model name on a frame.

    The drawing is done directly on ``frame``, which is also returned.

    Args:
        frame: Image array to draw on.
        predictions: Iterable of ``(name, probability_percent)`` pairs, one
            text row per pair.
        fps: Frames-per-second figure shown in the top right corner.

    Returns:
        The same ``frame`` object with the overlays drawn.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    scale = 0.7
    thickness = 2
    white = (255, 255, 255)
    pad = 5

    def text_size(label):
        # Width and height of ``label`` at the shared font settings.
        (label_w, label_h), _ = cv2.getTextSize(label, font, scale, thickness)
        return label_w, label_h

    def boxed_label(label, size, origin, background):
        # Filled rectangle padded around the text, then white text on top.
        label_w, label_h = size
        x, y = origin
        cv2.rectangle(
            frame,
            (x - pad, y + pad),
            (x + label_w + pad, y - label_h - pad),
            background,
            cv2.FILLED,
        )
        cv2.putText(frame, label, (x, y), font, scale, white, thickness)

    # FPS counter: right-aligned at the top, dark blue background.
    fps_label = f"FPS: {fps:.2f}"
    fps_w, fps_h = text_size(fps_label)
    boxed_label(
        fps_label, (fps_w, fps_h), (frame.shape[1] - fps_w - 10, 30), (139, 0, 0)
    )

    # Predictions: one green line every 30 px, starting at the top left.
    baseline_y = 30
    for name, prob in predictions:
        cv2.putText(
            frame,
            f"{name}: {prob:.2f}%",
            (10, baseline_y),
            font,
            scale,
            (0, 255, 0),
            thickness,
        )
        baseline_y += 30

    # Model name: horizontally centred near the bottom, red background.
    model_label = "Model: eva02_large_patch14_448"
    model_w, model_h = text_size(model_label)
    boxed_label(
        model_label,
        (model_w, model_h),
        ((frame.shape[1] - model_w) // 2, frame.shape[0] - 20),
        (0, 0, 255),
    )

    return frame
def process_video(input_path, output_path, session, live_view=False, use_webcam=False):
    """Classify every frame of a video source and report timing statistics.

    Each frame is preprocessed, run through ``session``, annotated with the
    top predictions and then optionally written to ``output_path`` and/or
    shown in a window. Per-frame and average timings are printed.

    Args:
        input_path: Path of the input video; ignored when ``use_webcam`` is
            true.
        output_path: Path of the annotated output video, or a falsy value to
            skip writing.
        session: Inference session exposing ``get_inputs``, ``get_outputs``
            and ``run``.
        live_view: Show each annotated frame in a window; pressing ``q``
            stops processing.
        use_webcam: Read from capture device 0 instead of ``input_path``.

    Returns:
        None. If the source cannot be opened or yields no frames, a message
        is printed and the function returns without statistics.
    """
    cap = cv2.VideoCapture(0 if use_webcam else input_path)
    if not cap.isOpened():
        # Fail with a clear message instead of a ZeroDivisionError later on.
        cap.release()
        source = "webcam" if use_webcam else input_path
        print(f"Error: Could not open video source: {source}")
        return

    out = None
    frame_count = 0
    total_time = 0.0

    try:
        if output_path:
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            # Keep fractional rates such as 29.97; some sources (notably
            # webcams) report 0 or NaN, which would give an unusable file.
            fps = cap.get(cv2.CAP_PROP_FPS)
            if not fps > 0:
                fps = 30.0
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
            if not out.isOpened():
                print(f"Warning: Could not open output video for writing: {output_path}")
                out.release()
                out = None

        input_name = session.get_inputs()[0].name
        output_name = session.get_outputs()[0].name

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # perf_counter is monotonic and high-resolution, unlike time.time.
            start_time = time.perf_counter()
            preprocessed = preprocess_frame(frame)
            output = session.run([output_name], {input_name: preprocessed})
            predictions = get_top_predictions(output[0])
            frame_time = time.perf_counter() - start_time

            # Guard against a zero reading from a coarse clock.
            current_fps = 1 / frame_time if frame_time > 0 else 0.0

            frame_with_predictions = draw_predictions(frame, predictions, current_fps)

            if out is not None:
                out.write(frame_with_predictions)

            if live_view:
                cv2.imshow("Inference", frame_with_predictions)
                if cv2.waitKey(1) & 0xFF == ord("q"):
                    break

            total_time += frame_time
            frame_count += 1
            print(
                f"Processed frame {frame_count}, Time: {frame_time:.3f}s, FPS: {current_fps:.2f}"
            )
    finally:
        # Release capture/writer even if inference raises mid-stream.
        cap.release()
        if out is not None:
            out.release()
        if live_view:
            # Windows only exist when live view was requested.
            cv2.destroyAllWindows()

    if frame_count == 0 or total_time <= 0:
        print("No frames were processed.")
        return

    avg_time = total_time / frame_count
    print(f"Average processing time per frame: {avg_time:.3f}s")
    print(f"Average FPS: {1/avg_time:.2f}")
def main():
    """Script entry point: parse options, load the model, run inference.

    The input source is validated before the inference session is created,
    so a missing ``--input_video``/``--webcam`` is reported immediately
    instead of after the model has been loaded. When both are given, the
    webcam takes precedence.
    """
    args = parse_arguments()

    if not (args.webcam or args.input_video):
        print("Error: Please specify either --input_video or --webcam")
        return

    session = get_ort_session("merged_model_compose.onnx")

    if args.webcam:
        process_video(None, args.output_video, session, args.live, use_webcam=True)
    else:
        process_video(args.input_video, args.output_video, session, args.live)


if __name__ == "__main__":
    main()