Spaces:
Sleeping
Sleeping
File size: 4,996 Bytes
0666166 8789461 c666373 8789461 c666373 acf9f46 8789461 c666373 5f5775f ebaace4 7eef8ce ebaace4 7eef8ce ebaace4 7eef8ce ebaace4 5f5775f b660e9c 5f5775f cca355f 7f05610 cca355f c666373 bb13609 5f5775f cca355f 7f05610 cca355f 7f05610 0666166 7eef8ce cca355f c666373 7f05610 0666166 c666373 7eef8ce 7f05610 855dd12 cca355f 7f05610 e4ea851 5f5775f acf9f46 cca355f acf9f46 855dd12 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import gradio as gr
import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
from PIL import Image
# Hugging Face Hub checkpoint ids used by this app.
VIT_CHECKPOINT = "iamomtiwari/VITPEST"
FALLBACK_CHECKPOINT = "microsoft/resnet-50"

# Primary classifier: the crop disease ViT checkpoint and its feature extractor.
model = ViTForImageClassification.from_pretrained(VIT_CHECKPOINT)
feature_extractor = ViTFeatureExtractor.from_pretrained(VIT_CHECKPOINT)

# Fallback classifier: ResNet-50 for general image classification, used by
# predict() when the ViT confidence is below the threshold.
fallback_model = AutoModelForImageClassification.from_pretrained(FALLBACK_CHECKPOINT)
fallback_feature_extractor = AutoFeatureExtractor.from_pretrained(FALLBACK_CHECKPOINT)
# Class labels with treatment advice.
# Keys are 1-based: predict() looks up the treatment with
# ``class_labels[predicted_class_idx + 1]``, i.e. model index 0 maps to key 1.
class_labels = {
    1: {
        "label": "Stage Corn Common Rust",
        "treatment": "Apply fungicides as soon as symptoms are noticed. Practice crop rotation and remove infected plants.",
    },
    2: {
        "label": "Stage Corn Gray Leaf Spot",
        "treatment": "Rotate crops to non-host plants, apply resistant varieties, and use fungicides as needed.",
    },
    3: {
        "label": "Stage Safe Corn Healthy",
        "treatment": "Continue good agricultural practices: ensure proper irrigation, nutrient supply, and monitor for pests.",
    },
    4: {
        "label": "Stage Corn Northern Leaf Blight",
        "treatment": "Remove and destroy infected plant debris, apply fungicides, and rotate crops.",
    },
    5: {
        "label": "Stage Rice Brown Spot",
        "treatment": "Use resistant varieties, improve field drainage, and apply fungicides if necessary.",
    },
    6: {
        "label": "Stage Safe Rice Healthy",
        "treatment": "Maintain proper irrigation, fertilization, and pest control measures.",
    },
    7: {
        "label": "Stage Rice Leaf Blast",
        "treatment": "Use resistant varieties, apply fungicides during high-risk periods, and practice good field management.",
    },
    8: {
        "label": "Stage Rice Neck Blast",
        "treatment": "Plant resistant varieties, improve nutrient management, and apply fungicides if symptoms appear.",
    },
    9: {
        "label": "Stage Sugarcane Bacterial Blight",
        "treatment": "Use disease-free planting material, practice crop rotation, and destroy infected plants.",
    },
    10: {
        "label": "Stage Safe Sugarcane Healthy",
        "treatment": "Maintain healthy soil conditions and proper irrigation.",
    },
    11: {
        "label": "Stage Sugarcane Red Rot",
        "treatment": "Plant resistant varieties and ensure good drainage.",
    },
    12: {
        "label": "Stage Wheat Brown Rust",
        "treatment": "Apply fungicides and practice crop rotation with non-host crops.",
    },
    13: {
        "label": "Stage Safe Wheat Healthy",
        "treatment": "Continue with good management practices, including proper fertilization and weed control.",
    },
    14: {
        "label": "Stage Wheat Yellow Rust",
        "treatment": "Use resistant varieties, apply fungicides, and rotate crops.",
    },
}

# Label names ordered by key, so position i holds the label for key i + 1.
# Derived from the dict's own keys instead of a hard-coded range(1, 15) so the
# list cannot drift out of sync when an entry is added or removed.
labels_list = [class_labels[key]["label"] for key in sorted(class_labels)]

# Minimum top-class softmax probability for accepting the ViT prediction;
# below this, predict() switches to the fallback model.
CONFIDENCE_THRESHOLD = 0.5
# Inference: crop disease ViT first, ResNet-50 fallback when confidence is low
def _describe_fallback(image, vit_confidence):
    """Classify ``image`` with the fallback model and format the low-confidence message.

    Args:
        image: The same image object that was given to the ViT feature extractor.
        vit_confidence: Top-class softmax probability from the ViT model
            (a float in [0, 1]); reported to the user as a percentage.

    Returns:
        str: Message with the ViT confidence and the fallback model's top
        label and confidence.
    """
    fallback_inputs = fallback_feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        fallback_logits = fallback_model(**fallback_inputs).logits
    fallback_probs = torch.softmax(fallback_logits, dim=-1)
    fallback_idx = fallback_logits.argmax(-1).item()
    fallback_confidence = fallback_probs[0, fallback_idx].item()
    # Label names for the fallback come from the checkpoint's own config.
    fallback_label = fallback_model.config.id2label[fallback_idx]
    return (
        f"Low confidence in ViT model ({vit_confidence * 100:.2f}%).\n"
        f"ResNet-50 predicts: {fallback_label} ({fallback_confidence * 100:.2f}%).\n\n"
        "If this does not match your input, please try another image."
    )


def predict(image):
    """Classify a crop image and return a human-readable result string.

    The ViT crop disease model runs first. When its top-class softmax
    probability is below ``CONFIDENCE_THRESHOLD``, the image is classified
    again with the fallback model and that result is reported instead.

    Args:
        image: A single image as delivered by the Gradio ``"image"`` input;
            it is passed unchanged to the feature extractors. ``None`` is
            treated as "no image supplied".

    Returns:
        str: The predicted disease with treatment advice, the fallback
        model's prediction, or a short explanatory message when no usable
        prediction can be made.
    """
    # NOTE(review): Gradio is expected to pass None when the form is submitted
    # without an image; without this guard the feature extractor would raise.
    if image is None:
        return "No image provided. Please upload an image of a crop plant."

    # First, use the crop disease model (ViT).
    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    confidences = torch.softmax(logits, dim=-1)
    # .item() assumes a single image, i.e. exactly one row of logits.
    predicted_class_idx = logits.argmax(-1).item()
    confidence = confidences[0, predicted_class_idx].item()

    # If confidence is below the threshold, use the fallback model.
    if confidence < CONFIDENCE_THRESHOLD:
        return _describe_fallback(image, confidence)

    # The label table is hand-written and not checked against the checkpoint's
    # class count, so guard the lookup instead of raising IndexError/KeyError.
    if not 0 <= predicted_class_idx < len(labels_list):
        return (
            f"Model returned an unknown class index ({predicted_class_idx}); "
            "no label or treatment advice is available."
        )

    predicted_label = labels_list[predicted_class_idx]
    # class_labels is keyed from 1, while model class indices start at 0.
    treatment_advice = class_labels[predicted_class_idx + 1]["treatment"]
    return (
        f"Disease: {predicted_label} ({confidence * 100:.2f}%)\n\n"
        f"Treatment Advice: {treatment_advice}"
    )
# Gradio UI: one image input whose value is handed to predict(), one text output.
_INTERFACE_TITLE = "Crop Disease Detection"
_INTERFACE_DESCRIPTION = (
    "Upload an image of a crop plant to detect diseases. "
    "If confidence is low, ResNet-50 will classify the image."
)

interface = gr.Interface(
    fn=predict,
    inputs="image",
    outputs="text",
    title=_INTERFACE_TITLE,
    description=_INTERFACE_DESCRIPTION,
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    interface.launch()
|