"""Gradio app that classifies crop-leaf diseases with a fine-tuned ViT model."""

import gradio as gr
import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

MODEL_ID = "wambugu1738/crop_leaf_diseases_vit"

# Load the preprocessor and classifier once, at import time, so every request
# reuses them. ViTImageProcessor replaces the deprecated ViTFeatureExtractor;
# the module-level name `feature_extractor` is kept for existing references.
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_ID)
model = ViTForImageClassification.from_pretrained(MODEL_ID)


def predict(image):
    """Return the predicted disease label for a crop-leaf image.

    Args:
        image: The uploaded image as a NumPy array (what
            ``gr.Image(type="numpy")`` delivers) or a ``PIL.Image.Image``.
            May be ``None`` when the user submits without uploading anything.

    Returns:
        The ``model.config.id2label`` entry for the highest-scoring class,
        or a short prompt asking for an image when none was supplied.
    """
    if image is None:
        return "Please upload an image."
    if not isinstance(image, Image.Image):
        # Convert from a NumPy array to a PIL Image.
        image = Image.fromarray(image)
    # Normalize grayscale / RGBA / palette inputs so the processor always
    # receives a consistent 3-channel image.
    image = image.convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


# `gr.inputs.*` was deprecated in Gradio 3 and removed in Gradio 4;
# components now live directly on the `gr` namespace.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),  # Input delivered as a NumPy array
    outputs="text",
    title="Crop Disease Detection",
    description="Upload an image of a crop leaf to detect diseases.",
)

if __name__ == "__main__":
    iface.launch()