Abrar20 commited on
Commit
8fc219d
·
verified ·
1 Parent(s): cd7d8c9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +233 -0
app.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import cv2
4
+ import tensorflow as tf
5
+ import torch
6
+ from PIL import Image
7
+
8
+ # ============== HF Transformers / ViT Model ==============
9
+ from transformers import ViTImageProcessor, ViTForImageClassification
10
+
11
+ # ----------- 1. Load the ViT model & processor ------------
12
+ vit_processor = ViTImageProcessor.from_pretrained('wambugu1738/crop_leaf_diseases_vit')
13
+ vit_model = ViTForImageClassification.from_pretrained(
14
+ 'wambugu1738/crop_leaf_diseases_vit',
15
+ ignore_mismatched_sizes=True
16
+ )
17
+
18
+ # Define label-to-treatment text (Example for demonstration)
19
+ vit_label_treatment = {
20
+ # The HF model was originally for "Corn, Potato, Rice, Wheat diseases".
21
+ # If it can predict more, add them here.
22
+ "Corn___Common_rust": "Use recommended fungicides and ensure crop rotation.",
23
+ "Corn___Cercospora_leaf_spot": "Apply foliar fungicides; ensure good field sanitation.",
24
+ "Potato___Early_blight": "Apply preventive fungicides; remove infected debris.",
25
+ "Potato___Late_blight": "Use certified seed tubers; fungicide sprays when conditions favor disease.",
26
+ "Rice___Leaf_blight": "Use resistant rice varieties, maintain field hygiene.",
27
+ "Wheat___Leaf_rust": "Plant resistant wheat varieties, apply foliar fungicides if severe.",
28
+ # Fallback
29
+ "Unknown": "No specific treatment available."
30
+ }
31
+
32
+ def classify_image_vit(image):
33
+ # Convert to PIL Image in case input is numpy
34
+ if not isinstance(image, Image.Image):
35
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
36
+ inputs = vit_processor(images=image, return_tensors="pt")
37
+ outputs = vit_model(**inputs)
38
+ logits = outputs.logits
39
+ predicted_class_idx = logits.argmax(-1).item()
40
+
41
+ # Predicted label
42
+ predicted_label = vit_model.config.id2label.get(predicted_class_idx, "Unknown")
43
+ # Example: If your id2label from HF is something like "corn diseased" or "rice healthy",
44
+ # match it to the dictionary key for treatments (above). For demonstration:
45
+ treatment_text = vit_label_treatment.get(predicted_label, "No specific treatment available.")
46
+ return predicted_label, treatment_text
47
+
48
+
49
+ # ============== TensorFlow Model (plant_model_v5-beta.h5) ==============
50
+ # Load the model
51
+ keras_model = tf.keras.models.load_model('plant_model_v5-beta.h5')
52
+
53
+ # Define the class names
54
+ class_names = {
55
+ 0: 'Apple___Apple_scab',
56
+ 1: 'Apple___Black_rot',
57
+ 2: 'Apple___Cedar_apple_rust',
58
+ 3: 'Apple___healthy',
59
+ 4: 'Not a plant',
60
+ 5: 'Blueberry___healthy',
61
+ 6: 'Cherry___Powdery_mildew',
62
+ 7: 'Cherry___healthy',
63
+ 8: 'Corn___Cercospora_leaf_spot Gray_leaf_spot',
64
+ 9: 'Corn___Common_rust',
65
+ 10: 'Corn___Northern_Leaf_Blight',
66
+ 11: 'Corn___healthy',
67
+ 12: 'Grape___Black_rot',
68
+ 13: 'Grape___Esca_(Black_Measles)',
69
+ 14: 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
70
+ 15: 'Grape___healthy',
71
+ 16: 'Orange___Haunglongbing_(Citrus_greening)',
72
+ 17: 'Peach___Bacterial_spot',
73
+ 18: 'Peach___healthy',
74
+ 19: 'Pepper,_bell___Bacterial_spot',
75
+ 20: 'Pepper,_bell___healthy',
76
+ 21: 'Potato___Early_blight',
77
+ 22: 'Potato___Late_blight',
78
+ 23: 'Potato___healthy',
79
+ 24: 'Raspberry___healthy',
80
+ 25: 'Soybean___healthy',
81
+ 26: 'Squash___Powdery_mildew',
82
+ 27: 'Strawberry___Leaf_scorch',
83
+ 28: 'Strawberry___healthy',
84
+ 29: 'Tomato___Bacterial_spot',
85
+ 30: 'Tomato___Early_blight',
86
+ 31: 'Tomato___Late_blight',
87
+ 32: 'Tomato___Leaf_Mold',
88
+ 33: 'Tomato___Septoria_leaf_spot',
89
+ 34: 'Tomato___Spider_mites Two-spotted_spider_mite',
90
+ 35: 'Tomato___Target_Spot',
91
+ 36: 'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
92
+ 37: 'Tomato___Tomato_mosaic_virus',
93
+ 38: 'Tomato___healthy'
94
+ }
95
+
96
+ # Example dictionary of "treatments" for some classes
97
+ keras_treatments = {
98
+ 'Apple___Apple_scab': "Remove fallen leaves, apply fungicides.",
99
+ 'Apple___Black_rot': "Prune out dead branches; apply copper-based fungicide.",
100
+ 'Corn___Common_rust': "Use resistant hybrids; apply fungicide if needed.",
101
+ 'Corn___Cercospora_leaf_spot Gray_leaf_spot': "Rotate crops; use foliar fungicides.",
102
+ 'Potato___Early_blight': "Use certified seeds; apply preventative fungicides.",
103
+ 'Tomato___Target_Spot': "Use resistant varieties and mulches to reduce disease.",
104
+ # Fallback:
105
+ 'Unknown': "No specific treatment available."
106
+ }
107
+
108
+ def edge_and_cut(img, threshold1, threshold2):
109
+ emb_img = img.copy()
110
+ edges = cv2.Canny(img, threshold1, threshold2)
111
+ edge_coors = []
112
+ for i in range(edges.shape[0]):
113
+ for j in range(edges.shape[1]):
114
+ if edges[i][j] != 0:
115
+ edge_coors.append((i, j))
116
+
117
+ if len(edge_coors) == 0:
118
+ return emb_img
119
+
120
+ row_min = edge_coors[np.argsort([coor[0] for coor in edge_coors])[0]][0]
121
+ row_max = edge_coors[np.argsort([coor[0] for coor in edge_coors])[-1]][0]
122
+ col_min = edge_coors[np.argsort([coor[1] for coor in edge_coors])[0]][1]
123
+ col_max = edge_coors[np.argsort([coor[1] for coor in edge_coors])[-1]][1]
124
+ new_img = img[row_min:row_max, col_min:col_max]
125
+
126
+ # Simple bounding box in white
127
+ emb_color = np.array([255], dtype=np.uint8)
128
+ emb_img[row_min-10:row_min+10, col_min:col_max] = emb_color
129
+ emb_img[row_max-10:row_max+10, col_min:col_max] = emb_color
130
+ emb_img[row_min:row_max, col_min-10:col_min+10] = emb_color
131
+ emb_img[row_min:row_max, col_max-10:col_max+10] = emb_color
132
+
133
+ return emb_img
134
+
135
+ def classify_and_visualize_keras(image):
136
+ # Preprocess the image
137
+ img_array = tf.image.resize(image, [256, 256])
138
+ img_array = tf.expand_dims(img_array, 0) / 255.0
139
+
140
+ # Make a prediction
141
+ prediction = keras_model.predict(img_array)
142
+ predicted_class_idx = tf.argmax(prediction[0], axis=-1).numpy()
143
+ confidence = np.max(prediction[0])
144
+
145
+ # Obtain the predicted label
146
+ predicted_label = class_names.get(predicted_class_idx, "Unknown")
147
+
148
+ if confidence < 0.60:
149
+ class_name = "Uncertain / Not in dataset"
150
+ bounded_image = image
151
+ treatment_text = "No treatment recommendation (uncertain prediction)."
152
+ else:
153
+ class_name = predicted_label
154
+ bounded_image = edge_and_cut(image, 200, 400)
155
+ treatment_text = keras_treatments.get(predicted_label, "No specific treatment available.")
156
+
157
+ return class_name, float(confidence), bounded_image, treatment_text
158
+
159
+
160
+ # ============== Combined Gradio App ==============
161
+ def main_model_selector(model_choice, image):
162
+ """
163
+ Dispatch function based on user choice of model:
164
+ - 'Vit-model (Corn/Potato/Rice/Wheat)' -> use classify_image_vit
165
+ - 'Keras-model (Apple/Blueberry/Cherry/etc.)' -> use classify_and_visualize_keras
166
+ """
167
+ if image is None:
168
+ return "No image provided.", None, None, None
169
+
170
+ if model_choice == "ViT (Corn, Potato, Rice, Wheat)":
171
+ # Return: label, treatment
172
+ predicted_label, treatment_text = classify_image_vit(image)
173
+ # For consistency with the Keras model outputs,
174
+ # we'll keep placeholders for confidence & bounding box
175
+ return predicted_label, None, image, treatment_text
176
+
177
+ elif model_choice == "Keras (Apple, Blueberry, Cherry, etc.)":
178
+ # Return: class_name, confidence, bounded_image, treatment_text
179
+ class_name, confidence, bounded_image, treatment_text = classify_and_visualize_keras(image)
180
+ return class_name, confidence, bounded_image, treatment_text
181
+
182
+ else:
183
+ return "Invalid model choice.", None, None, None
184
+
185
+
186
+ # Create Gradio interface
187
+ with gr.Blocks() as demo:
188
+ gr.Markdown("# **Plant Disease Detection**")
189
+ gr.Markdown(
190
+ "Select which model you want to use, then upload an image to see the prediction, "
191
+ "confidence (if applicable), bounding box (if applicable), and a suggested treatment."
192
+ )
193
+
194
+ with gr.Row():
195
+ model_choice = gr.Radio(
196
+ choices=["ViT (Corn, Potato, Rice, Wheat)", "Keras (Apple, Blueberry, Cherry, etc.)"],
197
+ value="Keras (Apple, Blueberry, Cherry, etc.)",
198
+ label="Select Model"
199
+ )
200
+
201
+ with gr.Row():
202
+ inp_image = gr.Image(type="numpy", label="Upload Leaf Image")
203
+
204
+ # Outputs
205
+ with gr.Row():
206
+ out_label = gr.Textbox(label="Predicted Class")
207
+ out_confidence = gr.Textbox(label="Confidence (If Available)")
208
+ out_bounded_image = gr.Image(label="Visualization (If Available)")
209
+ out_treatment = gr.Textbox(label="Treatment Recommendation")
210
+
211
+ # Button
212
+ btn = gr.Button("Classify")
213
+
214
+ # Function binding
215
+ btn.click(
216
+ fn=main_model_selector,
217
+ inputs=[model_choice, inp_image],
218
+ outputs=[out_label, out_confidence, out_bounded_image, out_treatment]
219
+ )
220
+
221
+ # Provide some example images
222
+ gr.Examples(
223
+ examples=[
224
+ ["Keras (Apple, Blueberry, Cherry, etc.)", "corn.jpg"],
225
+ ["Keras (Apple, Blueberry, Cherry, etc.)", "grot.jpg"],
226
+ ["Keras (Apple, Blueberry, Cherry, etc.)", "Potato___Early_blight.jpg"],
227
+ ["Keras (Apple, Blueberry, Cherry, etc.)", "Tomato___Target_Spot.jpg"],
228
+ ["ViT (Corn, Potato, Rice, Wheat)", "corn.jpg"],
229
+ ],
230
+ inputs=[model_choice, inp_image]
231
+ )
232
+
233
+ demo.launch(share=True)