Spaces:
Sleeping
Sleeping
iamomtiwari
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -1,28 +1,35 @@
|
|
1 |
import gradio as gr # Ensure Gradio is imported
|
2 |
import torch
|
3 |
from transformers import ViTForImageClassification, ViTFeatureExtractor
|
|
|
4 |
from PIL import Image
|
5 |
-
|
6 |
-
# Load model
|
7 |
model = ViTForImageClassification.from_pretrained("iamomtiwari/VITPEST")
|
8 |
feature_extractor = ViTFeatureExtractor.from_pretrained("iamomtiwari/VITPEST")
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
class_labels = {
|
10 |
1: {"label": "Stage Corn___Common_Rust", "treatment": "Apply fungicides as soon as symptoms are noticed. Practice crop rotation and remove infected plants."},
|
11 |
2: {"label": "Stage Corn___Gray_Leaf_Spot", "treatment": "Rotate crops to non-host plants, apply resistant varieties, and use fungicides as needed."},
|
12 |
-
3: {"label": "Stage Corn___Healthy", "treatment": "Continue good agricultural practices: ensure proper irrigation, nutrient supply, and monitor for pests."},
|
13 |
4: {"label": "Stage Corn___Northern_Leaf_Blight", "treatment": "Remove and destroy infected plant debris, apply fungicides, and rotate crops."},
|
14 |
5: {"label": "Stage Potato___Early_Blight", "treatment": "Apply fungicides and remove infected plant debris. Practice crop rotation to reduce disease pressure."},
|
15 |
-
6: {"label": "Stage Potato___Healthy", "treatment": "Maintain proper irrigation and fertility practices, and monitor for pests and diseases."},
|
16 |
7: {"label": "Stage Potato___Late_Blight", "treatment": "Apply fungicides, remove infected plant material, and use resistant potato varieties."},
|
17 |
8: {"label": "Stage Rice___Brown_Spot", "treatment": "Use resistant varieties, improve field drainage, and apply fungicides if necessary."},
|
18 |
-
9: {"label": "Stage Rice___Healthy", "treatment": "Maintain proper irrigation, fertilization, and pest control measures."},
|
19 |
10: {"label": "Stage Rice___Leaf_Blast", "treatment": "Use resistant varieties, apply fungicides during high-risk periods, and practice good field management."},
|
20 |
11: {"label": "Stage Rice___Neck_Blast", "treatment": "Plant resistant varieties, improve nutrient management, and apply fungicides if symptoms appear."},
|
21 |
12: {"label": "Stage Sugarcane__Bacterial Blight", "treatment": "Use disease-free planting material, practice crop rotation, and destroy infected plants."},
|
22 |
-
13: {"label": "Stage Sugarcane__Healthy", "treatment": "Maintain healthy soil conditions and proper irrigation."},
|
23 |
14: {"label": "Stage Sugarcane__Red_Rot", "treatment": "Plant resistant varieties and ensure good drainage."},
|
24 |
15: {"label": "Stage Wheat___Brown_Rust", "treatment": "Apply fungicides and practice crop rotation with non-host crops."},
|
25 |
-
16: {"label": "Stage Wheat___Healthy", "treatment": "Continue with good management practices, including proper fertilization and weed control."},
|
26 |
17: {"label": "Stage Wheat___Yellow_Rust", "treatment": "Use resistant varieties, apply fungicides, and rotate crops."}
|
27 |
}
|
28 |
|
@@ -31,16 +38,28 @@ labels_list = [class_labels[i]["label"] for i in range(1, 18)]
|
|
31 |
|
32 |
# Inference function
|
33 |
def predict(image):
|
|
|
34 |
inputs = feature_extractor(images=image, return_tensors="pt")
|
35 |
with torch.no_grad():
|
36 |
outputs = model(**inputs)
|
37 |
predicted_class_idx = outputs.logits.argmax(-1).item()
|
38 |
-
predicted_label = labels_list[predicted_class_idx]
|
39 |
|
40 |
-
#
|
41 |
-
|
42 |
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
# Create Gradio Interface
|
46 |
interface = gr.Interface(fn=predict, inputs="image", outputs="text")
|
|
|
1 |
import gradio as gr # Ensure Gradio is imported
|
2 |
import torch
|
3 |
from transformers import ViTForImageClassification, ViTFeatureExtractor
|
4 |
+
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
|
5 |
from PIL import Image
|
6 |
+
|
7 |
+
# Load crop model (ViT)
|
8 |
model = ViTForImageClassification.from_pretrained("iamomtiwari/VITPEST")
|
9 |
feature_extractor = ViTFeatureExtractor.from_pretrained("iamomtiwari/VITPEST")
|
10 |
+
|
11 |
+
# Load fallback model (ResNet50 for general image classification)
|
12 |
+
fallback_model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-50")
|
13 |
+
fallback_feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-50")
|
14 |
+
|
15 |
+
# Define class labels with treatment advice
|
16 |
class_labels = {
|
17 |
1: {"label": "Stage Corn___Common_Rust", "treatment": "Apply fungicides as soon as symptoms are noticed. Practice crop rotation and remove infected plants."},
|
18 |
2: {"label": "Stage Corn___Gray_Leaf_Spot", "treatment": "Rotate crops to non-host plants, apply resistant varieties, and use fungicides as needed."},
|
19 |
+
3: {"label": "Stage safe Corn___Healthy", "treatment": "Continue good agricultural practices: ensure proper irrigation, nutrient supply, and monitor for pests."},
|
20 |
4: {"label": "Stage Corn___Northern_Leaf_Blight", "treatment": "Remove and destroy infected plant debris, apply fungicides, and rotate crops."},
|
21 |
5: {"label": "Stage Potato___Early_Blight", "treatment": "Apply fungicides and remove infected plant debris. Practice crop rotation to reduce disease pressure."},
|
22 |
+
6: {"label": "Stage safe Potato___Healthy", "treatment": "Maintain proper irrigation and fertility practices, and monitor for pests and diseases."},
|
23 |
7: {"label": "Stage Potato___Late_Blight", "treatment": "Apply fungicides, remove infected plant material, and use resistant potato varieties."},
|
24 |
8: {"label": "Stage Rice___Brown_Spot", "treatment": "Use resistant varieties, improve field drainage, and apply fungicides if necessary."},
|
25 |
+
9: {"label": "Stage safe Rice___Healthy", "treatment": "Maintain proper irrigation, fertilization, and pest control measures."},
|
26 |
10: {"label": "Stage Rice___Leaf_Blast", "treatment": "Use resistant varieties, apply fungicides during high-risk periods, and practice good field management."},
|
27 |
11: {"label": "Stage Rice___Neck_Blast", "treatment": "Plant resistant varieties, improve nutrient management, and apply fungicides if symptoms appear."},
|
28 |
12: {"label": "Stage Sugarcane__Bacterial Blight", "treatment": "Use disease-free planting material, practice crop rotation, and destroy infected plants."},
|
29 |
+
13: {"label": "Stage safe Sugarcane__Healthy", "treatment": "Maintain healthy soil conditions and proper irrigation."},
|
30 |
14: {"label": "Stage Sugarcane__Red_Rot", "treatment": "Plant resistant varieties and ensure good drainage."},
|
31 |
15: {"label": "Stage Wheat___Brown_Rust", "treatment": "Apply fungicides and practice crop rotation with non-host crops."},
|
32 |
+
16: {"label": "Stage safe Wheat___Healthy", "treatment": "Continue with good management practices, including proper fertilization and weed control."},
|
33 |
17: {"label": "Stage Wheat___Yellow_Rust", "treatment": "Use resistant varieties, apply fungicides, and rotate crops."}
|
34 |
}
|
35 |
|
|
|
38 |
|
39 |
# Inference function
|
40 |
def predict(image):
|
41 |
+
# First, use the crop disease model (ViT)
|
42 |
inputs = feature_extractor(images=image, return_tensors="pt")
|
43 |
with torch.no_grad():
|
44 |
outputs = model(**inputs)
|
45 |
predicted_class_idx = outputs.logits.argmax(-1).item()
|
|
|
46 |
|
47 |
+
# Check if the predicted label corresponds to a crop disease
|
48 |
+
predicted_label = labels_list[predicted_class_idx]
|
49 |
|
50 |
+
if predicted_class_idx < len(class_labels): # It's a crop disease
|
51 |
+
treatment_advice = class_labels[predicted_class_idx + 1]["treatment"]
|
52 |
+
return f"Disease: {predicted_label}\n\nTreatment Advice: {treatment_advice}"
|
53 |
+
else:
|
54 |
+
# If not a crop disease, use the fallback model (ResNet50) for general object detection
|
55 |
+
inputs_fallback = fallback_feature_extractor(images=image, return_tensors="pt")
|
56 |
+
with torch.no_grad():
|
57 |
+
outputs_fallback = fallback_model(**inputs_fallback)
|
58 |
+
predicted_class_idx_fallback = outputs_fallback.logits.argmax(-1).item()
|
59 |
+
|
60 |
+
# Get the fallback prediction label
|
61 |
+
fallback_label = fallback_model.config.id2label[predicted_class_idx_fallback]
|
62 |
+
return f"Fallback Prediction (Not a Crop): {fallback_label}"
|
63 |
|
64 |
# Create Gradio Interface
|
65 |
interface = gr.Interface(fn=predict, inputs="image", outputs="text")
|