์‚ฌ์šฉ์˜ˆ์‹œ

import onnxruntime as ort
import numpy as np
from transformers import AutoFeatureExtractor
from PIL import Image

# ONNX ๋ชจ๋ธ ๊ฒฝ๋กœ
onnx_model_path = r'C:\mobilevit_model.onnx'

# ONNX ๋Ÿฐํƒ€์ž„ ์„ธ์…˜ ์ดˆ๊ธฐํ™”
ort_session = ort.InferenceSession(onnx_model_path)

# ์ƒˆ๋กœ์šด ์ด๋ฏธ์ง€ ์˜ˆ์ธก ํ•จ์ˆ˜ ์ •์˜
def predict_image(image_path):
    # MobileViT ๋ชจ๋ธ์— ๋งž๋Š” ํŠน์ง• ์ถ”์ถœ๊ธฐ ๋กœ๋“œ
    feature_extractor = AutoFeatureExtractor.from_pretrained("apple/mobilevit-small")
    
    # ์ด๋ฏธ์ง€๋ฅผ ๋กœ๋“œํ•˜๊ณ  RGB๋กœ ๋ณ€ํ™˜
    image = Image.open(image_path).convert("RGB")
    
    # ์ด๋ฏธ์ง€๋ฅผ ํŠน์ง• ์ถ”์ถœ๊ธฐ๋กœ ์ „์ฒ˜๋ฆฌ
    inputs = feature_extractor(images=image, return_tensors="np")
    input_array = inputs['pixel_values']  # ONNX๋Š” Numpy ํ˜•์‹์„ ์‚ฌ์šฉ
    
    # ONNX ๋ชจ๋ธ์— ์ž…๋ ฅ ์ „๋‹ฌ ๋ฐ ์ถ”๋ก 
    ort_inputs = {ort_session.get_inputs()[0].name: input_array}
    ort_outputs = ort_session.run(None, ort_inputs)
    
    # ๊ฒฐ๊ณผ ํ•ด์„
    logits = ort_outputs[0]
    predicted_class = np.argmax(logits, axis=-1).item()
    
    return "๊ทธ๋ƒฅ ์‚ฌ์ง„" if predicted_class == 1 else "๋กœ๋งจ์Šค ์Šค์บ  ์‚ฌ์ง„"

# ์˜ˆ์ธก ์˜ˆ์‹œ
image_path = r'C:\1234567.jpg'
result = predict_image(image_path)
print(result)
Downloads last month
11
Safetensors
Model size
4.95M params
Tensor type
F32
ยท
Inference API
Unable to determine this model's library. Check the docs .