์ฌ์ฉ์์
import onnxruntime as ort
import numpy as np
from transformers import AutoFeatureExtractor
from PIL import Image
# ONNX ๋ชจ๋ธ ๊ฒฝ๋ก
onnx_model_path = r'C:\mobilevit_model.onnx'
# ONNX ๋ฐํ์ ์ธ์
์ด๊ธฐํ
ort_session = ort.InferenceSession(onnx_model_path)
# ์๋ก์ด ์ด๋ฏธ์ง ์์ธก ํจ์ ์ ์
def predict_image(image_path):
# MobileViT ๋ชจ๋ธ์ ๋ง๋ ํน์ง ์ถ์ถ๊ธฐ ๋ก๋
feature_extractor = AutoFeatureExtractor.from_pretrained("apple/mobilevit-small")
# ์ด๋ฏธ์ง๋ฅผ ๋ก๋ํ๊ณ RGB๋ก ๋ณํ
image = Image.open(image_path).convert("RGB")
# ์ด๋ฏธ์ง๋ฅผ ํน์ง ์ถ์ถ๊ธฐ๋ก ์ ์ฒ๋ฆฌ
inputs = feature_extractor(images=image, return_tensors="np")
input_array = inputs['pixel_values'] # ONNX๋ Numpy ํ์์ ์ฌ์ฉ
# ONNX ๋ชจ๋ธ์ ์
๋ ฅ ์ ๋ฌ ๋ฐ ์ถ๋ก
ort_inputs = {ort_session.get_inputs()[0].name: input_array}
ort_outputs = ort_session.run(None, ort_inputs)
# ๊ฒฐ๊ณผ ํด์
logits = ort_outputs[0]
predicted_class = np.argmax(logits, axis=-1).item()
return "๊ทธ๋ฅ ์ฌ์ง" if predicted_class == 1 else "๋ก๋งจ์ค ์ค์บ ์ฌ์ง"
# ์์ธก ์์
image_path = r'C:\1234567.jpg'
result = predict_image(image_path)
print(result)
- Downloads last month
- 11