|
import React, { useState } from 'react'; |
|
import * as ort from 'onnxruntime-web'; |
|
|
|
/** Default location of the ONNX model, resolved relative to the page URL. */
const DEFAULT_MODEL_URL = './model.onnx';

/** Default number of full passes over the selected images. */
const DEFAULT_REPETITIONS = 50;

/** Maximum number of images kept from a single file selection. */
const MAX_IMAGES = 10;

// Every image is stretched (aspect ratio not preserved) to this size before inference.
// NOTE(review): the tensor built below is uint8 NHWC [1, 320, 320, 3]; confirm this
// matches the model's declared input shape/dtype.
const MODEL_INPUT_WIDTH = 320;
const MODEL_INPUT_HEIGHT = 320;

/**
 * Drop the alpha channel from interleaved RGBA pixel data.
 *
 * @param {ArrayLike<number>} rgba - Interleaved RGBA bytes (e.g. `ImageData.data`).
 * @returns {Uint8Array} Interleaved RGB bytes, 3 per complete input pixel.
 */
function rgbaToRgb(rgba) {
  const pixelCount = Math.floor(rgba.length / 4);
  const rgb = new Uint8Array(pixelCount * 3);
  for (let src = 0, dst = 0; src < pixelCount * 4; src += 4) {
    rgb[dst++] = rgba[src];
    rgb[dst++] = rgba[src + 1];
    rgb[dst++] = rgba[src + 2];
  }
  return rgb;
}

/**
 * Choose the feed name to bind the image tensor to.
 * Prefers `'input'` (the name this component historically used) when the model
 * declares it, otherwise falls back to the model's first declared input.
 *
 * @param {readonly string[]} inputNames - Input names reported by the session.
 * @returns {string} The input name to use in the feeds object.
 * @throws {Error} If the model declares no inputs.
 */
function pickInputName(inputNames) {
  if (inputNames.includes('input')) {
    return 'input';
  }
  if (inputNames.length === 0) {
    throw new Error('Model declares no inputs');
  }
  return inputNames[0];
}

/**
 * Decode an image file into a loaded `HTMLImageElement`.
 *
 * @param {File} file - Image file chosen by the user.
 * @returns {Promise<HTMLImageElement>} Resolves once decoded.
 * @throws {Error} (rejects) If the browser cannot decode the file, so callers never hang.
 */
function loadImage(file) {
  return new Promise((resolve, reject) => {
    const url = URL.createObjectURL(file);
    const img = new Image();
    img.onload = () => {
      URL.revokeObjectURL(url);
      resolve(img);
    };
    img.onerror = () => {
      URL.revokeObjectURL(url);
      reject(new Error(`Could not decode image "${file.name}"`));
    };
    img.src = url;
  });
}

/**
 * Turn an image file into the uint8 NHWC tensor fed to the model.
 *
 * @param {File} imageFile - Image file chosen by the user.
 * @param {number} [width=MODEL_INPUT_WIDTH] - Target width in pixels.
 * @param {number} [height=MODEL_INPUT_HEIGHT] - Target height in pixels.
 * @returns {Promise<ort.Tensor>} Tensor of shape [1, height, width, 3], dtype uint8.
 * @throws {Error} (rejects) If the image cannot be decoded or no 2D context is available.
 */
async function preprocessImage(imageFile, width = MODEL_INPUT_WIDTH, height = MODEL_INPUT_HEIGHT) {
  const img = await loadImage(imageFile);

  const canvas = document.createElement('canvas');
  canvas.width = width;
  canvas.height = height;
  const context = canvas.getContext('2d');
  if (!context) {
    throw new Error('2D canvas context is unavailable');
  }

  context.drawImage(img, 0, 0, width, height);
  const { data } = context.getImageData(0, 0, width, height);

  return new ort.Tensor('uint8', rgbaToRgb(data), [1, height, width, 3]);
}

/**
 * Benchmark page: the user picks up to MAX_IMAGES local images, and the component
 * reports the mean wall-clock time of a single `session.run` call over
 * `repetitions` passes across those images. Image decoding/resizing is done once,
 * before timing starts, so it is not counted as inference time.
 *
 * @param {object} [props]
 * @param {string} [props.modelUrl='./model.onnx'] - URL of the ONNX model to load.
 * @param {number} [props.repetitions=50] - Passes over the image set; non-positive or
 *   non-integer values fall back to the default.
 */
function ObjectDetection({ modelUrl = DEFAULT_MODEL_URL, repetitions = DEFAULT_REPETITIONS } = {}) {
  const [averageTime, setAverageTime] = useState(null);
  const [loading, setLoading] = useState(false);
  const [images, setImages] = useState([]);
  const [error, setError] = useState(null);

  const handleFileChange = (event) => {
    const files = Array.from(event.target.files ?? []);
    setImages(files.slice(0, MAX_IMAGES));
    // A new selection invalidates any result computed for the previous one.
    setAverageTime(null);
    setError(null);
  };

  const runBenchmark = async () => {
    if (images.length === 0) {
      alert(`Please upload at least one image (up to ${MAX_IMAGES}).`);
      return;
    }

    const passes = Number.isInteger(repetitions) && repetitions > 0 ? repetitions : DEFAULT_REPETITIONS;

    setLoading(true);
    // Clear stale output so a failed run never shows the previous run's number.
    setAverageTime(null);
    setError(null);

    let session = null;
    try {
      session = await ort.InferenceSession.create(modelUrl);
      const inputName = pickInputName(session.inputNames);

      // Preprocess once per image, outside the timed region; the resulting tensors
      // are reused across every repetition.
      const tensors = await Promise.all(images.map((file) => preprocessImage(file)));

      let totalInferenceTime = 0;
      for (let rep = 0; rep < passes; rep++) {
        console.log(`Repetition ${rep + 1} of ${passes}`);
        for (const tensor of tensors) {
          const startTime = performance.now();
          await session.run({ [inputName]: tensor });
          totalInferenceTime += performance.now() - startTime;
        }
      }

      setAverageTime(totalInferenceTime / (passes * tensors.length));
    } catch (err) {
      console.error('Error running inference:', err);
      setError(err instanceof Error ? err.message : String(err));
    } finally {
      if (session) {
        try {
          // Free the session's backend memory; optional call because older
          // onnxruntime-web versions may not expose release() — TODO confirm version.
          await session.release?.();
        } catch (releaseError) {
          console.error('Error releasing inference session:', releaseError);
        }
      }
      setLoading(false);
    }
  };

  return React.createElement(
    'div',
    null,
    React.createElement('h1', null, 'Object Detection Benchmark (Local Images)'),
    React.createElement('input', {
      type: 'file',
      multiple: true,
      accept: 'image/*',
      onChange: handleFileChange,
    }),
    React.createElement(
      'button',
      { onClick: runBenchmark, disabled: loading || images.length === 0 },
      loading ? 'Running Benchmark...' : 'Start Benchmark',
    ),
    React.createElement(
      'div',
      null,
      averageTime !== null
        ? React.createElement('h2', null, `Average Inference Time: ${averageTime.toFixed(2)} ms`)
        : null,
      error !== null ? React.createElement('p', { role: 'alert' }, `Benchmark failed: ${error}`) : null,
    ),
    React.createElement(
      'ul',
      null,
      images.map((img, index) => React.createElement('li', { key: index }, img.name)),
    ),
  );
}

export default ObjectDetection;
|
|