Nekshay commited on
Commit
6d98f13
·
verified ·
1 Parent(s): c0d965d

Update onnx_time_Inferance.js

Browse files
Files changed (1) hide show
  1. onnx_time_Inferance.js +62 -45
onnx_time_Inferance.js CHANGED
@@ -1,18 +1,24 @@
1
- import React, { useRef, useState } from 'react';
2
- import Webcam from 'react-webcam';
3
  import * as ort from 'onnxruntime-web';
4
 
5
  function ObjectDetection() {
6
  const [averageTime, setAverageTime] = useState(null);
7
  const [loading, setLoading] = useState(false);
8
- const webcamRef = useRef(null);
 
 
 
 
 
9
 
10
  const runBenchmark = async () => {
11
- if (!webcamRef.current) return;
12
- setLoading(true);
 
 
13
 
 
14
  const repetitions = 50;
15
- const imageCount = 10;
16
  let totalInferenceTime = 0;
17
 
18
  try {
@@ -22,15 +28,12 @@ function ObjectDetection() {
22
  for (let rep = 0; rep < repetitions; rep++) {
23
  console.log(`Repetition ${rep + 1} of ${repetitions}`);
24
 
25
- // Capture 10 images and measure inference time
26
- for (let i = 0; i < imageCount; i++) {
27
  const startTime = performance.now();
28
 
29
- // Capture image from webcam
30
- const imageSrc = webcamRef.current.getScreenshot();
31
-
32
- // Preprocess the image
33
- const inputTensor = await preprocessImage(imageSrc);
34
 
35
  // Define model input
36
  const feeds = { input: inputTensor };
@@ -43,7 +46,7 @@ function ObjectDetection() {
43
  }
44
  }
45
 
46
- const avgInferenceTime = totalInferenceTime / (repetitions * imageCount);
47
  setAverageTime(avgInferenceTime);
48
  } catch (error) {
49
  console.error('Error running inference:', error);
@@ -52,51 +55,58 @@ function ObjectDetection() {
52
  setLoading(false);
53
  };
54
 
55
- const preprocessImage = async (imageSrc) => {
56
- const img = new Image();
57
- img.src = imageSrc;
58
- await new Promise((resolve) => (img.onload = resolve));
59
 
60
- const canvas = document.createElement('canvas');
61
- const context = canvas.getContext('2d');
 
62
 
63
- // Resize to model input size
64
- const modelInputWidth = 320; // Replace with your model's input width
65
- const modelInputHeight = 320; // Replace with your model's input height
66
- canvas.width = modelInputWidth;
67
- canvas.height = modelInputHeight;
68
 
69
- context.drawImage(img, 0, 0, modelInputWidth, modelInputHeight);
 
 
 
 
70
 
71
- const imageData = context.getImageData(0, 0, modelInputWidth, modelInputHeight);
72
 
73
- // Convert RGBA to RGB
74
- const rgbData = new Uint8Array((imageData.data.length / 4) * 3); // 3 channels for RGB
75
- for (let i = 0, j = 0; i < imageData.data.length; i += 4) {
76
- rgbData[j++] = imageData.data[i]; // R
77
- rgbData[j++] = imageData.data[i + 1]; // G
78
- rgbData[j++] = imageData.data[i + 2]; // B
79
- // Skip A (alpha) channel
80
- }
81
 
82
- // Create a tensor with shape [1, 320, 320, 3]
83
- return new ort.Tensor('uint8', rgbData, [1, modelInputHeight, modelInputWidth, 3]);
 
 
 
 
 
 
 
 
 
 
 
 
84
  };
85
 
86
  return React.createElement(
87
  'div',
88
  null,
89
- React.createElement('h1', null, 'Object Detection Benchmark'),
90
- React.createElement(Webcam, {
91
- audio: false,
92
- ref: webcamRef,
93
- screenshotFormat: 'image/jpeg',
94
- width: 320,
95
- height: 320,
96
  }),
97
  React.createElement(
98
  'button',
99
- { onClick: runBenchmark, disabled: loading },
100
  loading ? 'Running Benchmark...' : 'Start Benchmark'
101
  ),
102
  React.createElement(
@@ -109,6 +119,13 @@ function ObjectDetection() {
109
  `Average Inference Time: ${averageTime.toFixed(2)} ms`
110
  )
111
  : null
 
 
 
 
 
 
 
112
  )
113
  );
114
  }
 
1
+ import React, { useState } from 'react';
 
2
  import * as ort from 'onnxruntime-web';
3
 
4
  function ObjectDetection() {
5
  const [averageTime, setAverageTime] = useState(null);
6
  const [loading, setLoading] = useState(false);
7
+ const [images, setImages] = useState([]);
8
+
9
+ const handleFileChange = (event) => {
10
+ const files = Array.from(event.target.files);
11
+ setImages(files.slice(0, 10)); // Limit to the first 10 images
12
+ };
13
 
14
  const runBenchmark = async () => {
15
+ if (images.length === 0) {
16
+ alert('Please upload 10 images.');
17
+ return;
18
+ }
19
 
20
+ setLoading(true);
21
  const repetitions = 50;
 
22
  let totalInferenceTime = 0;
23
 
24
  try {
 
28
  for (let rep = 0; rep < repetitions; rep++) {
29
  console.log(`Repetition ${rep + 1} of ${repetitions}`);
30
 
31
+ // Process each image
32
+ for (const imageFile of images) {
33
  const startTime = performance.now();
34
 
35
+ // Convert image to tensor
36
+ const inputTensor = await preprocessImage(imageFile);
 
 
 
37
 
38
  // Define model input
39
  const feeds = { input: inputTensor };
 
46
  }
47
  }
48
 
49
+ const avgInferenceTime = totalInferenceTime / (repetitions * images.length);
50
  setAverageTime(avgInferenceTime);
51
  } catch (error) {
52
  console.error('Error running inference:', error);
 
55
  setLoading(false);
56
  };
57
 
58
+ const preprocessImage = async (imageFile) => {
59
+ return new Promise((resolve) => {
60
+ const img = new Image();
61
+ const reader = new FileReader();
62
 
63
+ reader.onload = () => {
64
+ img.src = reader.result;
65
+ };
66
 
67
+ img.onload = () => {
68
+ const canvas = document.createElement('canvas');
69
+ const context = canvas.getContext('2d');
 
 
70
 
71
+ // Resize to model input size
72
+ const modelInputWidth = 320; // Replace with your model's input width
73
+ const modelInputHeight = 320; // Replace with your model's input height
74
+ canvas.width = modelInputWidth;
75
+ canvas.height = modelInputHeight;
76
 
77
+ context.drawImage(img, 0, 0, modelInputWidth, modelInputHeight);
78
 
79
+ const imageData = context.getImageData(0, 0, modelInputWidth, modelInputHeight);
 
 
 
 
 
 
 
80
 
81
+ // Convert RGBA to RGB
82
+ const rgbData = new Uint8Array((imageData.data.length / 4) * 3); // 3 channels for RGB
83
+ for (let i = 0, j = 0; i < imageData.data.length; i += 4) {
84
+ rgbData[j++] = imageData.data[i]; // R
85
+ rgbData[j++] = imageData.data[i + 1]; // G
86
+ rgbData[j++] = imageData.data[i + 2]; // B
87
+ }
88
+
89
+ // Create a tensor with shape [1, 320, 320, 3]
90
+ resolve(new ort.Tensor('uint8', rgbData, [1, modelInputHeight, modelInputWidth, 3]));
91
+ };
92
+
93
+ reader.readAsDataURL(imageFile);
94
+ });
95
  };
96
 
97
  return React.createElement(
98
  'div',
99
  null,
100
+ React.createElement('h1', null, 'Object Detection Benchmark (Local Images)'),
101
+ React.createElement('input', {
102
+ type: 'file',
103
+ multiple: true,
104
+ accept: 'image/*',
105
+ onChange: handleFileChange,
 
106
  }),
107
  React.createElement(
108
  'button',
109
+ { onClick: runBenchmark, disabled: loading || images.length === 0 },
110
  loading ? 'Running Benchmark...' : 'Start Benchmark'
111
  ),
112
  React.createElement(
 
119
  `Average Inference Time: ${averageTime.toFixed(2)} ms`
120
  )
121
  : null
122
+ ),
123
+ React.createElement(
124
+ 'ul',
125
+ null,
126
+ images.map((img, index) =>
127
+ React.createElement('li', { key: index }, img.name)
128
+ )
129
  )
130
  );
131
  }