# Scraped page header (not code), kept as a comment for provenance:
# author: iamomtiwari · commit 1867c2e "Create app.py" (verified) · size: 978 Bytes
import gradio as gr
import torch
from PIL import Image
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    ViTImageProcessor,
)
# Hugging Face Hub repo that supplies both the classifier weights and the
# matching image-preprocessing configuration.
MODEL_ID = 'wambugu1738/crop_leaf_diseases_vit'

# Load the preprocessor and classifier once at import time so every call to
# `predict` reuses them. ViTImageProcessor is the non-deprecated replacement
# for ViTFeatureExtractor; the variable keeps its original name because
# `predict` refers to it.
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_ID)
model = ViTForImageClassification.from_pretrained(MODEL_ID)
# Define prediction function
def predict(image):
    """Classify a crop-leaf image and return the predicted label.

    Args:
        image: The uploaded image as a NumPy array (the Gradio input is
            configured with ``type="numpy"``), or ``None`` when the user
            submits without uploading an image.

    Returns:
        The label from ``model.config.id2label`` for the highest-scoring
        class, or a short prompt asking for an image when none was given.
    """
    # Gradio hands us None for an empty input; Image.fromarray(None) would
    # otherwise fail with an unhelpful AttributeError.
    if image is None:
        return "Please upload an image of a crop leaf."

    pil_image = Image.fromarray(image)  # NumPy array -> PIL Image
    inputs = feature_extractor(images=pil_image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
# Create Gradio interface: one image input wired to `predict`, with the
# predicted label shown as plain text.
iface = gr.Interface(
    fn=predict,
    # gr.Image replaces gr.inputs.Image, which was deprecated in Gradio 3 and
    # removed in Gradio 4. type="numpy" keeps predict's input a NumPy array.
    inputs=gr.Image(type="numpy"),
    outputs="text",
    title="Crop Disease Detection",
    description="Upload an image of a crop leaf to detect diseases.",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()