|
import torch |
|
from torchvision import transforms |
|
from transformers import AutoImageProcessor, MobileViTV2ForImageClassification |
|
import gradio as gr |
|
|
|
# Hugging Face Hub checkpoint shared by the preprocessor and the classifier.
model_url = "MichalMlodawski/open-closed-eye-classification-mobilevitv2-1.0"

# Load the preprocessor first, then the model, from the same checkpoint.
# nn.Module.eval() returns the module itself, so inference mode is set inline.
image_processor = AutoImageProcessor.from_pretrained(model_url)
model = MobileViTV2ForImageClassification.from_pretrained(model_url).eval()

# torchvision pipeline: resize to 512x512, then convert the PIL image to a tensor.
transform = transforms.Compose([transforms.Resize((512, 512)), transforms.ToTensor()])
|
|
|
def classify_image(image):
    """Classify whether the eye in ``image`` is open or closed.

    Args:
        image: A PIL image, as supplied by the ``gr.Image(type="pil")`` input,
            or ``None`` when the user submits without uploading anything.

    Returns:
        A human-readable string with the predicted label, the confidence as a
        percentage with two decimals, and a 10-cell text bar. When ``image``
        is ``None``, a short prompt asking for an upload is returned instead.
    """
    # Gradio passes None for an empty image input; without this guard the
    # .convert() call below raises AttributeError and the UI shows an error.
    if image is None:
        return "⚠️ Please upload an image first."

    image = image.convert("RGB")

    # Only the image processor's output is fed to the model, so no separate
    # torchvision tensor is built here (it would be wasted work per request).
    inputs = image_processor(images=image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    confidence, predicted = torch.max(probabilities, 1)

    # NOTE(review): assumes class index 0 = closed and 1 = open; confirm this
    # matches model.config.id2label for the checkpoint in model_url.
    labels = ["👁️ Eye Closed", "👁️ Eye Open"]
    prediction = labels[predicted.item()]
    confidence = confidence.item() * 100

    # One green cell per full 10 percentage points, padded to 10 cells total.
    filled = int(confidence // 10)
    confidence_bar = "🟩" * filled + "⬜" * (10 - filled)

    return f"🔍 Prediction: {prediction}\n🎯 Confidence: {confidence:.2f}% {confidence_bar}"
|
|
|
def gradio_interface(image):
    """Gradio callback: run ``classify_image`` on the uploaded image.

    Args:
        image: Whatever the ``gr.Image`` input component supplies.

    Returns:
        The result string produced by ``classify_image``, unchanged.
    """
    result_text = classify_image(image)
    return result_text
|
|
|
# UI pieces, created in the same order the original inline arguments were.
_image_input = gr.Image(type="pil", label="📷 Upload an image")
_result_box = gr.Textbox(label="🖥️ Classification Result")
_soft_theme = gr.themes.Soft(primary_hue="blue")

# Single-function Gradio app wrapping the classifier; flagging is disabled.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=_image_input,
    outputs=_result_box,
    title="👁️ Eye State Classification 👁️",
    description="Upload an image to classify whether the eye is open or closed. Let's see what we can spot! 👀",
    theme=_soft_theme,
    allow_flagging="never",
)

if __name__ == "__main__":
    # Start the Gradio server with default launch settings when run as a script.
    iface.launch()