EfficientNet Parkinson's Prediction Model 🤗
This repository contains a Hugging Face EfficientNet model for predicting Parkinson's disease from patient drawings, with an accuracy of around 83%. Built with EfficientNet and PyTorch.
Overview
Parkinson's disease is a progressive nervous system disorder that affects movement. Symptoms start gradually, sometimes with a barely noticeable tremor in just one hand. Tremors are common, but the disorder also often causes stiffness or slowing of movement.
My model uses the EfficientNet architecture to predict the likelihood of Parkinson's disease in patients by analysing their drawings. Feel free to open a pull request and contribute if you want to.
Dataset
The model was trained on a dataset of patient drawings obtained from Kaggle.
Usage
import torch
from torch import nn
from transformers import AutoModel
from PIL import Image
import numpy as np

# Set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the trained model and move it to the device
model = AutoModel.from_pretrained('/content/final')
model = model.to(device)

# Load the new image and resize it to the model's input size
image_size = (224, 224)
new_image = Image.open('/content/health.png').convert('RGB').resize(image_size)
new_image = np.array(new_image)
# transpose(0, 2) turns (H, W, C) into (C, W, H); unsqueeze adds the batch dimension
new_image = torch.from_numpy(new_image).transpose(0, 2).float().unsqueeze(0)

# Move the data to the device
new_image = new_image.to(device)

# Make predictions using the trained model
with torch.no_grad():
    predictions = model(new_image)
    # Flatten the backbone's feature map to one vector per image
    logits = predictions.last_hidden_state
    logits = logits.view(logits.shape[0], -1)

    # Reduce the features to two class scores.
    # Note: this layer is created here, so its weights are randomly
    # initialised unless trained weights are loaded into it.
    num_classes = 2
    feature_reducer = nn.Linear(logits.shape[1], num_classes).to(device)
    logits = feature_reducer(logits.to(device))

    predicted_class = torch.argmax(logits, dim=1).item()
    confidence = torch.softmax(logits, dim=1)[0][predicted_class].item()

# Class 0 is Parkinson's, class 1 is Healthy
if predicted_class == 0:
    print(f"Predicted class: Parkinson's with confidence {confidence:.2f}")
else:
    print(f"Predicted class: Healthy with confidence {confidence:.2f}")
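The last step of the usage code turns two raw class scores into a predicted label and a confidence value. The sketch below reproduces that step in plain Python, without torch, to show what `argmax` and `softmax` compute. The helper names `softmax` and `describe` and the example logits are hypothetical. The label order (0 for Parkinson's, 1 for Healthy) is taken from the usage code above.

```python
import math

# Label order assumed from the usage code: index 0 = Parkinson's, 1 = Healthy.
LABELS = ["Parkinson's", "Healthy"]


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    # Subtracting the maximum keeps exp() numerically stable
    # and does not change the resulting probabilities.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def describe(logits):
    """Return (label, confidence) for one set of class scores."""
    probs = softmax(logits)
    # argmax: index of the highest probability
    predicted_class = max(range(len(probs)), key=probs.__getitem__)
    return LABELS[predicted_class], probs[predicted_class]


# Hypothetical logits, for illustration only
label, confidence = describe([2.0, 0.0])
print(f"Predicted class: {label} with confidence {confidence:.2f}")
```

The confidence is simply the softmax probability of the winning class, so it says how strongly the scores favour one label over the other and is not a calibrated clinical probability.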