File size: 5,257 Bytes
0666166
8789461
 
c666373
8789461
c666373
acf9f46
8789461
 
c666373
 
 
 
 
acf9f46
 
 
 
c666373
5f5775f
c2cea5f
 
c666373
c2cea5f
b660e9c
 
 
 
 
 
 
 
 
 
5f5775f
 
 
b660e9c
5f5775f
 
acf9f46
c666373
bb13609
5f5775f
 
 
0666166
c666373
 
0666166
acf9f46
 
 
 
 
 
 
 
 
 
 
 
 
c666373
 
acf9f46
c666373
 
 
 
 
 
0666166
c666373
 
acf9f46
e4ea851
5f5775f
acf9f46
 
 
 
 
c2cea5f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import gradio as gr
import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
from PIL import Image

# Hugging Face Hub checkpoints downloaded/loaded at import time.
CROP_DISEASE_CHECKPOINT = "iamomtiwari/VITPEST"
GENERAL_CHECKPOINT = "microsoft/resnet-50"
VIT_IN21K_CHECKPOINT = "google/vit-base-patch16-224-in21k"

# Primary model: ViT crop-disease classifier plus its preprocessor.
# predict() maps its output indices onto the module's class label table.
model, feature_extractor = (
    ViTForImageClassification.from_pretrained(CROP_DISEASE_CHECKPOINT),
    ViTFeatureExtractor.from_pretrained(CROP_DISEASE_CHECKPOINT),
)

# General-purpose ResNet-50 classifier; predict() falls back to it when the
# crop model's index is outside the known crop classes.
fallback_model, fallback_feature_extractor = (
    AutoModelForImageClassification.from_pretrained(GENERAL_CHECKPOINT),
    AutoFeatureExtractor.from_pretrained(GENERAL_CHECKPOINT),
)

# ViT "in21k" checkpoint (the "221k" in the variable names is a misnomer, kept
# because predict() refers to these names); used when the user answers "no".
# NOTE(review): the model card for this checkpoint says it ships without a
# fine-tuned classification head, so its id2label predictions may not be
# meaningful — confirm this is the intended checkpoint.
vit_221k_model, vit_221k_feature_extractor = (
    ViTForImageClassification.from_pretrained(VIT_IN21K_CHECKPOINT),
    ViTFeatureExtractor.from_pretrained(VIT_IN21K_CHECKPOINT),
)

# Crop-disease classes with treatment advice, keyed 1..N. Keys are 1-based
# while the crop model's output indices are 0-based, so predict() looks up
# index i as class_labels[i + 1]. Keys must stay contiguous starting at 1.
class_labels = {
    1: {"label": "Stage Corn___Common_Rust", "treatment": "Apply fungicides as soon as symptoms are noticed. Practice crop rotation and remove infected plants."},
    2: {"label": "Stage Corn___Gray_Leaf_Spot", "treatment": "Rotate crops to non-host plants, apply resistant varieties, and use fungicides as needed."},
    3: {"label": "Stage safe Corn___Healthy", "treatment": "Continue good agricultural practices: ensure proper irrigation, nutrient supply, and monitor for pests."},
    4: {"label": "Stage Corn___Northern_Leaf_Blight", "treatment": "Remove and destroy infected plant debris, apply fungicides, and rotate crops."},
    5: {"label": "Stage Rice___Brown_Spot", "treatment": "Use resistant varieties, improve field drainage, and apply fungicides if necessary."},
    6: {"label": "Stage safe Rice___Healthy", "treatment": "Maintain proper irrigation, fertilization, and pest control measures."},
    7: {"label": "Stage Rice___Leaf_Blast", "treatment": "Use resistant varieties, apply fungicides during high-risk periods, and practice good field management."},
    8: {"label": "Stage Rice___Neck_Blast", "treatment": "Plant resistant varieties, improve nutrient management, and apply fungicides if symptoms appear."},
    9: {"label": "Stage Sugarcane__Bacterial Blight", "treatment": "Use disease-free planting material, practice crop rotation, and destroy infected plants."},
    10: {"label": "Stage safe Sugarcane__Healthy", "treatment": "Maintain healthy soil conditions and proper irrigation."},
    11: {"label": "Stage Sugarcane__Red_Rot", "treatment": "Plant resistant varieties and ensure good drainage."},
    12: {"label": "Stage Wheat___Brown_Rust", "treatment": "Apply fungicides and practice crop rotation with non-host crops."},
    13: {"label": "Stage safe Wheat___Healthy", "treatment": "Continue with good management practices, including proper fertilization and weed control."},
    14: {"label": "Stage Wheat___Yellow_Rust", "treatment": "Use resistant varieties, apply fungicides, and rotate crops."}
}

# 0-based list of labels: labels_list[i] == class_labels[i + 1]["label"].
# Derived from the dict's size rather than a hard-coded bound, so adding a
# class above cannot leave it out of this list. A gap in the keys raises
# KeyError here instead of silently misaligning indices.
labels_list = [class_labels[i]["label"] for i in range(1, len(class_labels) + 1)]

# Inference function
def _top_class_index(classifier, extractor, image):
    """Return the index of the highest-scoring logit for one image.

    Args:
        classifier: Hugging Face image-classification model whose output
            exposes ``.logits``.
        extractor: Preprocessor paired with ``classifier``; called with
            ``images=image, return_tensors="pt"``.
        image: The input image as received from the UI.

    Returns:
        The arg-max class index as a Python ``int``.
    """
    inputs = extractor(images=image, return_tensors="pt")
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        logits = classifier(**inputs).logits
        return logits.argmax(-1).item()


def predict(image, feedback):
    """Classify an uploaded image and return a human-readable result string.

    Args:
        image: The uploaded image from the Gradio ``"image"`` input
            (presumably a NumPy array, the component's default; confirm),
            or ``None`` when nothing was uploaded.
        feedback: ``"yes"``, ``"no"``, or ``None`` (radio left unselected).
            Only ``"no"`` changes behavior; any other value takes the
            crop-disease path.

    Returns:
        - ``feedback == "no"``: the top label from the ViT in21k model.
        - A crop-disease index within ``labels_list``: the disease label and
          its treatment advice from ``class_labels``.
        - Otherwise: the top label from the ResNet-50 fallback model.
    """
    if image is None:
        return "Please upload an image."

    # The user rejected the crop-disease answer, so use the ViT in21k model
    # instead. The crop model is not run here because its result is unused.
    # NOTE(review): per its model card, google/vit-base-patch16-224-in21k ships
    # without a fine-tuned classification head, so these labels may not be
    # meaningful — confirm the intended checkpoint.
    if feedback == "no":
        vit_idx = _top_class_index(vit_221k_model, vit_221k_feature_extractor, image)
        vit_221k_label = vit_221k_model.config.id2label[vit_idx]
        return f"Fallback ViT 221k Prediction: {vit_221k_label}"

    predicted_class_idx = _top_class_index(model, feature_extractor, image)

    # Bounds-check before indexing labels_list, so an index outside the known
    # crop classes reaches the fallback below instead of raising IndexError.
    if predicted_class_idx < len(labels_list):
        predicted_label = labels_list[predicted_class_idx]
        # class_labels keys are 1-based; model indices are 0-based.
        treatment_advice = class_labels[predicted_class_idx + 1]["treatment"]
        return f"Disease: {predicted_label}\n\nTreatment Advice: {treatment_advice}"

    # Not one of the known crop classes: use the general ResNet-50 classifier.
    fallback_idx = _top_class_index(fallback_model, fallback_feature_extractor, image)
    fallback_label = fallback_model.config.id2label[fallback_idx]
    return f"Fallback Prediction (Not a Crop): {fallback_label}"

# Create Gradio Interface: an image upload plus a yes/no radio whose value is
# passed to predict() as `feedback`; the result is shown as plain text.
interface = gr.Interface(
    fn=predict,
    inputs=["image", gr.Radio(["yes", "no"], label="Is the prediction correct?")],
    outputs="text",
)

# Start the web server only when run as a script (e.g. `python app.py`), so
# importing this module from tests or tooling does not block on launch().
if __name__ == "__main__":
    interface.launch()