// Car_VS_Rest / onnx_time_Inferance.js
// Author: Nekshay — commit 6d98f13 ("Update onnx_time_Inferance.js", verified).
import React, { useState } from 'react';
import * as ort from 'onnxruntime-web';
/** Width, in pixels, every image is resized to. Replace with your model's input width. */
const MODEL_INPUT_WIDTH = 320;
/** Height, in pixels, every image is resized to. Replace with your model's input height. */
const MODEL_INPUT_HEIGHT = 320;
/** Maximum number of images kept from a single file selection. */
const MAX_IMAGES = 10;
/** Number of passes over the whole image set when benchmarking. */
const REPETITIONS = 50;

/**
 * Packs interleaved RGBA pixel bytes into interleaved RGB bytes, dropping alpha.
 *
 * @param {ArrayLike<number>} rgba - Pixel bytes in R,G,B,A order (e.g. `ImageData.data`).
 * @returns {Uint8Array} Pixel bytes in R,G,B order; 3 bytes per complete RGBA pixel.
 */
function rgbaToRgb(rgba) {
  const pixelCount = Math.floor(rgba.length / 4);
  const rgb = new Uint8Array(pixelCount * 3);
  for (let p = 0; p < pixelCount; p++) {
    rgb[p * 3] = rgba[p * 4]; // R
    rgb[p * 3 + 1] = rgba[p * 4 + 1]; // G
    rgb[p * 3 + 2] = rgba[p * 4 + 2]; // B
  }
  return rgb;
}

/**
 * Decodes an image file, resizes it to the model input size and wraps it in a tensor.
 *
 * Produces a uint8 tensor of shape [1, height, width, 3] (RGB, channels last).
 * NOTE(review): this layout/dtype is what the benchmark has always fed the model — confirm it
 * matches the model's declared input.
 *
 * @param {File} imageFile - Image file chosen by the user.
 * @param {number} [width=MODEL_INPUT_WIDTH] - Target width in pixels.
 * @param {number} [height=MODEL_INPUT_HEIGHT] - Target height in pixels.
 * @returns {Promise<ort.Tensor>} Resolves with the tensor; rejects (instead of hanging forever)
 *   if the file cannot be decoded or the canvas cannot be used.
 */
function preprocessImage(imageFile, width = MODEL_INPUT_WIDTH, height = MODEL_INPUT_HEIGHT) {
  return new Promise((resolve, reject) => {
    // A blob URL avoids base64-encoding the whole file as a data URL would.
    const url = URL.createObjectURL(imageFile);
    const img = new Image();
    img.onload = () => {
      URL.revokeObjectURL(url);
      // Errors thrown inside an event handler would not reach the Promise; forward them.
      try {
        const canvas = document.createElement('canvas');
        canvas.width = width;
        canvas.height = height;
        const context = canvas.getContext('2d');
        if (context === null) {
          throw new Error('2D canvas context is unavailable');
        }
        context.drawImage(img, 0, 0, width, height);
        const { data } = context.getImageData(0, 0, width, height);
        resolve(new ort.Tensor('uint8', rgbaToRgb(data), [1, height, width, 3]));
      } catch (err) {
        reject(err);
      }
    };
    img.onerror = () => {
      URL.revokeObjectURL(url);
      reject(new Error(`Could not decode image "${imageFile.name}"`));
    };
    img.src = url;
  });
}

/**
 * Benchmark UI: lets the user pick up to MAX_IMAGES local images, runs them through
 * `./model.onnx` REPETITIONS times, and shows the mean duration of a single `model.run` call.
 *
 * Images are decoded and converted to tensors once, before timing starts, so the reported
 * figure measures inference only (not file reading, decoding or resizing).
 */
function ObjectDetection() {
  const [averageTime, setAverageTime] = useState(null);
  const [loading, setLoading] = useState(false);
  const [images, setImages] = useState([]);
  const [errorMessage, setErrorMessage] = useState(null);

  const handleFileChange = (event) => {
    const files = Array.from(event.target.files);
    setImages(files.slice(0, MAX_IMAGES));
  };

  const runBenchmark = async () => {
    if (images.length === 0) {
      alert(`Please upload at least one image (up to ${MAX_IMAGES}).`);
      return;
    }
    setLoading(true);
    setErrorMessage(null);
    setAverageTime(null); // don't leave a previous run's result on screen if this one fails
    try {
      // Load the ONNX model once before the loop.
      const model = await ort.InferenceSession.create('./model.onnx');
      // Keep the historical 'input' feed name when the model declares it; otherwise use the
      // model's first input so models with a differently named input still work.
      const inputName = model.inputNames.includes('input') ? 'input' : model.inputNames[0];
      // Preprocess outside the timed region. The arrow wrapper stops map() from passing
      // (index, array) into preprocessImage's width/height parameters.
      const tensors = await Promise.all(images.map((file) => preprocessImage(file)));

      let totalInferenceTime = 0;
      for (let rep = 0; rep < REPETITIONS; rep++) {
        console.log(`Repetition ${rep + 1} of ${REPETITIONS}`);
        for (const inputTensor of tensors) {
          const feeds = { [inputName]: inputTensor };
          const startTime = performance.now();
          await model.run(feeds);
          totalInferenceTime += performance.now() - startTime;
        }
      }
      setAverageTime(totalInferenceTime / (REPETITIONS * tensors.length));
    } catch (error) {
      console.error('Error running inference:', error);
      const detail = error instanceof Error ? error.message : String(error);
      setErrorMessage(`Error running inference: ${detail}`);
    } finally {
      setLoading(false);
    }
  };

  return React.createElement(
    'div',
    null,
    React.createElement('h1', null, 'Object Detection Benchmark (Local Images)'),
    React.createElement('input', {
      type: 'file',
      multiple: true,
      accept: 'image/*',
      onChange: handleFileChange,
    }),
    React.createElement(
      'button',
      { onClick: runBenchmark, disabled: loading || images.length === 0 },
      loading ? 'Running Benchmark...' : 'Start Benchmark'
    ),
    React.createElement(
      'div',
      null,
      averageTime !== null
        ? React.createElement(
            'h2',
            null,
            `Average Inference Time: ${averageTime.toFixed(2)} ms`
          )
        : null,
      errorMessage !== null
        ? React.createElement('p', { role: 'alert' }, errorMessage)
        : null
    ),
    React.createElement(
      'ul',
      null,
      images.map((file, index) =>
        React.createElement('li', { key: index }, file.name)
      )
    )
  );
}
export default ObjectDetection;