Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
import tensorflow as tf
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
# ============== HF Transformers / ViT Model ==============
|
9 |
+
from transformers import ViTImageProcessor, ViTForImageClassification
|
10 |
+
|
11 |
+
# ----------- 1. Load the ViT model & processor ------------
|
12 |
+
vit_processor = ViTImageProcessor.from_pretrained('wambugu1738/crop_leaf_diseases_vit')
|
13 |
+
vit_model = ViTForImageClassification.from_pretrained(
|
14 |
+
'wambugu1738/crop_leaf_diseases_vit',
|
15 |
+
ignore_mismatched_sizes=True
|
16 |
+
)
|
17 |
+
|
18 |
+
# Define label-to-treatment text (Example for demonstration)
|
19 |
+
vit_label_treatment = {
|
20 |
+
# The HF model was originally for "Corn, Potato, Rice, Wheat diseases".
|
21 |
+
# If it can predict more, add them here.
|
22 |
+
"Corn___Common_rust": "Use recommended fungicides and ensure crop rotation.",
|
23 |
+
"Corn___Cercospora_leaf_spot": "Apply foliar fungicides; ensure good field sanitation.",
|
24 |
+
"Potato___Early_blight": "Apply preventive fungicides; remove infected debris.",
|
25 |
+
"Potato___Late_blight": "Use certified seed tubers; fungicide sprays when conditions favor disease.",
|
26 |
+
"Rice___Leaf_blight": "Use resistant rice varieties, maintain field hygiene.",
|
27 |
+
"Wheat___Leaf_rust": "Plant resistant wheat varieties, apply foliar fungicides if severe.",
|
28 |
+
# Fallback
|
29 |
+
"Unknown": "No specific treatment available."
|
30 |
+
}
|
31 |
+
|
32 |
+
def classify_image_vit(image):
|
33 |
+
# Convert to PIL Image in case input is numpy
|
34 |
+
if not isinstance(image, Image.Image):
|
35 |
+
image = Image.fromarray(image.astype('uint8'), 'RGB')
|
36 |
+
inputs = vit_processor(images=image, return_tensors="pt")
|
37 |
+
outputs = vit_model(**inputs)
|
38 |
+
logits = outputs.logits
|
39 |
+
predicted_class_idx = logits.argmax(-1).item()
|
40 |
+
|
41 |
+
# Predicted label
|
42 |
+
predicted_label = vit_model.config.id2label.get(predicted_class_idx, "Unknown")
|
43 |
+
# Example: If your id2label from HF is something like "corn diseased" or "rice healthy",
|
44 |
+
# match it to the dictionary key for treatments (above). For demonstration:
|
45 |
+
treatment_text = vit_label_treatment.get(predicted_label, "No specific treatment available.")
|
46 |
+
return predicted_label, treatment_text
|
47 |
+
|
48 |
+
|
49 |
+
# ============== TensorFlow Model (plant_model_v5-beta.h5) ==============
|
50 |
+
# Load the model
|
51 |
+
keras_model = tf.keras.models.load_model('plant_model_v5-beta.h5')
|
52 |
+
|
53 |
+
# Define the class names
|
54 |
+
class_names = {
|
55 |
+
0: 'Apple___Apple_scab',
|
56 |
+
1: 'Apple___Black_rot',
|
57 |
+
2: 'Apple___Cedar_apple_rust',
|
58 |
+
3: 'Apple___healthy',
|
59 |
+
4: 'Not a plant',
|
60 |
+
5: 'Blueberry___healthy',
|
61 |
+
6: 'Cherry___Powdery_mildew',
|
62 |
+
7: 'Cherry___healthy',
|
63 |
+
8: 'Corn___Cercospora_leaf_spot Gray_leaf_spot',
|
64 |
+
9: 'Corn___Common_rust',
|
65 |
+
10: 'Corn___Northern_Leaf_Blight',
|
66 |
+
11: 'Corn___healthy',
|
67 |
+
12: 'Grape___Black_rot',
|
68 |
+
13: 'Grape___Esca_(Black_Measles)',
|
69 |
+
14: 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
|
70 |
+
15: 'Grape___healthy',
|
71 |
+
16: 'Orange___Haunglongbing_(Citrus_greening)',
|
72 |
+
17: 'Peach___Bacterial_spot',
|
73 |
+
18: 'Peach___healthy',
|
74 |
+
19: 'Pepper,_bell___Bacterial_spot',
|
75 |
+
20: 'Pepper,_bell___healthy',
|
76 |
+
21: 'Potato___Early_blight',
|
77 |
+
22: 'Potato___Late_blight',
|
78 |
+
23: 'Potato___healthy',
|
79 |
+
24: 'Raspberry___healthy',
|
80 |
+
25: 'Soybean___healthy',
|
81 |
+
26: 'Squash___Powdery_mildew',
|
82 |
+
27: 'Strawberry___Leaf_scorch',
|
83 |
+
28: 'Strawberry___healthy',
|
84 |
+
29: 'Tomato___Bacterial_spot',
|
85 |
+
30: 'Tomato___Early_blight',
|
86 |
+
31: 'Tomato___Late_blight',
|
87 |
+
32: 'Tomato___Leaf_Mold',
|
88 |
+
33: 'Tomato___Septoria_leaf_spot',
|
89 |
+
34: 'Tomato___Spider_mites Two-spotted_spider_mite',
|
90 |
+
35: 'Tomato___Target_Spot',
|
91 |
+
36: 'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
|
92 |
+
37: 'Tomato___Tomato_mosaic_virus',
|
93 |
+
38: 'Tomato___healthy'
|
94 |
+
}
|
95 |
+
|
96 |
+
# Example dictionary of "treatments" for some classes
|
97 |
+
keras_treatments = {
|
98 |
+
'Apple___Apple_scab': "Remove fallen leaves, apply fungicides.",
|
99 |
+
'Apple___Black_rot': "Prune out dead branches; apply copper-based fungicide.",
|
100 |
+
'Corn___Common_rust': "Use resistant hybrids; apply fungicide if needed.",
|
101 |
+
'Corn___Cercospora_leaf_spot Gray_leaf_spot': "Rotate crops; use foliar fungicides.",
|
102 |
+
'Potato___Early_blight': "Use certified seeds; apply preventative fungicides.",
|
103 |
+
'Tomato___Target_Spot': "Use resistant varieties and mulches to reduce disease.",
|
104 |
+
# Fallback:
|
105 |
+
'Unknown': "No specific treatment available."
|
106 |
+
}
|
107 |
+
|
108 |
+
def edge_and_cut(img, threshold1, threshold2):
|
109 |
+
emb_img = img.copy()
|
110 |
+
edges = cv2.Canny(img, threshold1, threshold2)
|
111 |
+
edge_coors = []
|
112 |
+
for i in range(edges.shape[0]):
|
113 |
+
for j in range(edges.shape[1]):
|
114 |
+
if edges[i][j] != 0:
|
115 |
+
edge_coors.append((i, j))
|
116 |
+
|
117 |
+
if len(edge_coors) == 0:
|
118 |
+
return emb_img
|
119 |
+
|
120 |
+
row_min = edge_coors[np.argsort([coor[0] for coor in edge_coors])[0]][0]
|
121 |
+
row_max = edge_coors[np.argsort([coor[0] for coor in edge_coors])[-1]][0]
|
122 |
+
col_min = edge_coors[np.argsort([coor[1] for coor in edge_coors])[0]][1]
|
123 |
+
col_max = edge_coors[np.argsort([coor[1] for coor in edge_coors])[-1]][1]
|
124 |
+
new_img = img[row_min:row_max, col_min:col_max]
|
125 |
+
|
126 |
+
# Simple bounding box in white
|
127 |
+
emb_color = np.array([255], dtype=np.uint8)
|
128 |
+
emb_img[row_min-10:row_min+10, col_min:col_max] = emb_color
|
129 |
+
emb_img[row_max-10:row_max+10, col_min:col_max] = emb_color
|
130 |
+
emb_img[row_min:row_max, col_min-10:col_min+10] = emb_color
|
131 |
+
emb_img[row_min:row_max, col_max-10:col_max+10] = emb_color
|
132 |
+
|
133 |
+
return emb_img
|
134 |
+
|
135 |
+
def classify_and_visualize_keras(image):
|
136 |
+
# Preprocess the image
|
137 |
+
img_array = tf.image.resize(image, [256, 256])
|
138 |
+
img_array = tf.expand_dims(img_array, 0) / 255.0
|
139 |
+
|
140 |
+
# Make a prediction
|
141 |
+
prediction = keras_model.predict(img_array)
|
142 |
+
predicted_class_idx = tf.argmax(prediction[0], axis=-1).numpy()
|
143 |
+
confidence = np.max(prediction[0])
|
144 |
+
|
145 |
+
# Obtain the predicted label
|
146 |
+
predicted_label = class_names.get(predicted_class_idx, "Unknown")
|
147 |
+
|
148 |
+
if confidence < 0.60:
|
149 |
+
class_name = "Uncertain / Not in dataset"
|
150 |
+
bounded_image = image
|
151 |
+
treatment_text = "No treatment recommendation (uncertain prediction)."
|
152 |
+
else:
|
153 |
+
class_name = predicted_label
|
154 |
+
bounded_image = edge_and_cut(image, 200, 400)
|
155 |
+
treatment_text = keras_treatments.get(predicted_label, "No specific treatment available.")
|
156 |
+
|
157 |
+
return class_name, float(confidence), bounded_image, treatment_text
|
158 |
+
|
159 |
+
|
160 |
+
# ============== Combined Gradio App ==============
|
161 |
+
def main_model_selector(model_choice, image):
|
162 |
+
"""
|
163 |
+
Dispatch function based on user choice of model:
|
164 |
+
- 'Vit-model (Corn/Potato/Rice/Wheat)' -> use classify_image_vit
|
165 |
+
- 'Keras-model (Apple/Blueberry/Cherry/etc.)' -> use classify_and_visualize_keras
|
166 |
+
"""
|
167 |
+
if image is None:
|
168 |
+
return "No image provided.", None, None, None
|
169 |
+
|
170 |
+
if model_choice == "ViT (Corn, Potato, Rice, Wheat)":
|
171 |
+
# Return: label, treatment
|
172 |
+
predicted_label, treatment_text = classify_image_vit(image)
|
173 |
+
# For consistency with the Keras model outputs,
|
174 |
+
# we'll keep placeholders for confidence & bounding box
|
175 |
+
return predicted_label, None, image, treatment_text
|
176 |
+
|
177 |
+
elif model_choice == "Keras (Apple, Blueberry, Cherry, etc.)":
|
178 |
+
# Return: class_name, confidence, bounded_image, treatment_text
|
179 |
+
class_name, confidence, bounded_image, treatment_text = classify_and_visualize_keras(image)
|
180 |
+
return class_name, confidence, bounded_image, treatment_text
|
181 |
+
|
182 |
+
else:
|
183 |
+
return "Invalid model choice.", None, None, None
|
184 |
+
|
185 |
+
|
186 |
+
# Create Gradio interface
|
187 |
+
with gr.Blocks() as demo:
|
188 |
+
gr.Markdown("# **Plant Disease Detection**")
|
189 |
+
gr.Markdown(
|
190 |
+
"Select which model you want to use, then upload an image to see the prediction, "
|
191 |
+
"confidence (if applicable), bounding box (if applicable), and a suggested treatment."
|
192 |
+
)
|
193 |
+
|
194 |
+
with gr.Row():
|
195 |
+
model_choice = gr.Radio(
|
196 |
+
choices=["ViT (Corn, Potato, Rice, Wheat)", "Keras (Apple, Blueberry, Cherry, etc.)"],
|
197 |
+
value="Keras (Apple, Blueberry, Cherry, etc.)",
|
198 |
+
label="Select Model"
|
199 |
+
)
|
200 |
+
|
201 |
+
with gr.Row():
|
202 |
+
inp_image = gr.Image(type="numpy", label="Upload Leaf Image")
|
203 |
+
|
204 |
+
# Outputs
|
205 |
+
with gr.Row():
|
206 |
+
out_label = gr.Textbox(label="Predicted Class")
|
207 |
+
out_confidence = gr.Textbox(label="Confidence (If Available)")
|
208 |
+
out_bounded_image = gr.Image(label="Visualization (If Available)")
|
209 |
+
out_treatment = gr.Textbox(label="Treatment Recommendation")
|
210 |
+
|
211 |
+
# Button
|
212 |
+
btn = gr.Button("Classify")
|
213 |
+
|
214 |
+
# Function binding
|
215 |
+
btn.click(
|
216 |
+
fn=main_model_selector,
|
217 |
+
inputs=[model_choice, inp_image],
|
218 |
+
outputs=[out_label, out_confidence, out_bounded_image, out_treatment]
|
219 |
+
)
|
220 |
+
|
221 |
+
# Provide some example images
|
222 |
+
gr.Examples(
|
223 |
+
examples=[
|
224 |
+
["Keras (Apple, Blueberry, Cherry, etc.)", "corn.jpg"],
|
225 |
+
["Keras (Apple, Blueberry, Cherry, etc.)", "grot.jpg"],
|
226 |
+
["Keras (Apple, Blueberry, Cherry, etc.)", "Potato___Early_blight.jpg"],
|
227 |
+
["Keras (Apple, Blueberry, Cherry, etc.)", "Tomato___Target_Spot.jpg"],
|
228 |
+
["ViT (Corn, Potato, Rice, Wheat)", "corn.jpg"],
|
229 |
+
],
|
230 |
+
inputs=[model_choice, inp_image]
|
231 |
+
)
|
232 |
+
|
233 |
+
demo.launch(share=True)
|