Abrar20's picture
Update app.py
537af0d verified
import gradio as gr
import numpy as np
import cv2
import tensorflow as tf
import torch
from PIL import Image
# ============== HF Transformers / ViT Model ==============
from transformers import ViTImageProcessor, ViTForImageClassification
# ----------- 1. Load the ViT model & processor ------------
vit_processor = ViTImageProcessor.from_pretrained('wambugu1738/crop_leaf_diseases_vit')
vit_model = ViTForImageClassification.from_pretrained(
'wambugu1738/crop_leaf_diseases_vit',
ignore_mismatched_sizes=True
)
vit_label_treatment = {
"Corn___Common_rust": "Use recommended fungicides and ensure crop rotation.",
"Corn___Cercospora_leaf_spot": "Apply foliar fungicides; ensure good field sanitation.",
"Potato___Early_blight": "Apply preventive fungicides; remove infected debris.",
"Potato___Late_blight": "Use certified seed tubers; fungicide sprays when conditions favor disease.",
"Rice___Leaf_blight": "Use resistant rice varieties, maintain field hygiene.",
"Wheat___Leaf_rust": "Plant resistant wheat varieties, apply foliar fungicides if severe.",
# Fallback
"Unknown": "No specific treatment available."
}
def classify_image_vit(image):
if not isinstance(image, Image.Image):
image = Image.fromarray(image.astype('uint8'), 'RGB')
inputs = vit_processor(images=image, return_tensors="pt")
outputs = vit_model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
# Predicted label
predicted_label = vit_model.config.id2label.get(predicted_class_idx, "Unknown")
treatment_text = vit_label_treatment.get(predicted_label, "No specific treatment available.")
return predicted_label, treatment_text
# ============== TensorFlow Model (plant_model_v5-beta.h5) ==============
# Load the model
keras_model = tf.keras.models.load_model('plant_model_v5-beta.h5')
# Define the class names
class_names = {
0: 'Apple___Apple_scab',
1: 'Apple___Black_rot',
2: 'Apple___Cedar_apple_rust',
3: 'Apple___healthy',
4: 'Not a plant',
5: 'Blueberry___healthy',
6: 'Cherry___Powdery_mildew',
7: 'Cherry___healthy',
8: 'Corn___Cercospora_leaf_spot Gray_leaf_spot',
9: 'Corn___Common_rust',
10: 'Corn___Northern_Leaf_Blight',
11: 'Corn___healthy',
12: 'Grape___Black_rot',
13: 'Grape___Esca_(Black_Measles)',
14: 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
15: 'Grape___healthy',
16: 'Orange___Haunglongbing_(Citrus_greening)',
17: 'Peach___Bacterial_spot',
18: 'Peach___healthy',
19: 'Pepper,_bell___Bacterial_spot',
20: 'Pepper,_bell___healthy',
21: 'Potato___Early_blight',
22: 'Potato___Late_blight',
23: 'Potato___healthy',
24: 'Raspberry___healthy',
25: 'Soybean___healthy',
26: 'Squash___Powdery_mildew',
27: 'Strawberry___Leaf_scorch',
28: 'Strawberry___healthy',
29: 'Tomato___Bacterial_spot',
30: 'Tomato___Early_blight',
31: 'Tomato___Late_blight',
32: 'Tomato___Leaf_Mold',
33: 'Tomato___Septoria_leaf_spot',
34: 'Tomato___Spider_mites Two-spotted_spider_mite',
35: 'Tomato___Target_Spot',
36: 'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
37: 'Tomato___Tomato_mosaic_virus',
38: 'Tomato___healthy'
}
# Example dictionary of "treatments" for some classes
keras_treatments = {
'Apple___Apple_scab': "Remove fallen leaves and prune infected branches. Apply fungicides containing captan or myclobutanil.",
'Apple___Black_rot': "Prune out dead branches. Spray copper-based fungicide during early fruit development.",
'Apple___Cedar_apple_rust': "Remove nearby juniper trees. Apply fungicides before bud break.",
'Apple___healthy': "No action required. The plant is healthy.",
'Blueberry___healthy': "No action required. The plant is healthy.",
'Cherry___Powdery_mildew': "Apply sulfur-based fungicide. Ensure good air circulation around the plant.",
'Cherry___healthy': "No action required. The plant is healthy.",
'Corn___Cercospora_leaf_spot Gray_leaf_spot': "Rotate crops to avoid build-up of pathogens. Use resistant hybrids and apply foliar fungicides.",
'Corn___Common_rust': "Plant rust-resistant hybrids. Apply fungicides at the first sign of rust.",
'Corn___Northern_Leaf_Blight': "Use resistant varieties and apply fungicides when lesions are observed.",
'Corn___healthy': "No action required. The plant is healthy.",
'Grape___Black_rot': "Remove and destroy infected leaves and fruits. Apply fungicides containing myclobutanil or captan.",
'Grape___Esca_(Black_Measles)': "Prune and destroy infected wood. Apply fungicides during the growing season.",
'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)': "Maintain good air circulation. Spray protective fungicides like mancozeb.",
'Grape___healthy': "No action required. The plant is healthy.",
'Orange___Haunglongbing_(Citrus_greening)': "Remove and destroy infected trees. Control psyllid vectors with insecticides.",
'Peach___Bacterial_spot': "Apply copper-based bactericides. Use resistant varieties and avoid overhead irrigation.",
'Peach___healthy': "No action required. The plant is healthy.",
'Pepper,_bell___Bacterial_spot': "Apply copper-based sprays. Use certified seeds and avoid overhead irrigation.",
'Pepper,_bell___healthy': "No action required. The plant is healthy.",
'Potato___Early_blight': "Use certified seeds and apply preventative fungicides like chlorothalonil.",
'Potato___Late_blight': "Plant disease-free tubers and use fungicides containing metalaxyl.",
'Potato___healthy': "No action required. The plant is healthy.",
'Raspberry___healthy': "No action required. The plant is healthy.",
'Soybean___healthy': "No action required. The plant is healthy.",
'Squash___Powdery_mildew': "Use sulfur-based fungicides and ensure good ventilation.",
'Strawberry___Leaf_scorch': "Remove infected leaves. Apply fungicides containing myclobutanil.",
'Strawberry___healthy': "No action required. The plant is healthy.",
'Tomato___Bacterial_spot': "Apply copper-based sprays. Avoid overhead watering.",
'Tomato___Early_blight': "Prune infected leaves and apply fungicides containing chlorothalonil or mancozeb.",
'Tomato___Late_blight': "Remove infected plants. Apply fungicides containing chlorothalonil or metalaxyl.",
'Tomato___Leaf_Mold': "Ensure good ventilation and apply fungicides like mancozeb.",
'Tomato___Septoria_leaf_spot': "Remove infected leaves and apply fungicides containing chlorothalonil.",
'Tomato___Spider_mites Two-spotted_spider_mite': "Spray insecticidal soap or neem oil. Maintain humidity levels.",
'Tomato___Target_Spot': "Use resistant varieties. Apply fungicides containing chlorothalonil.",
'Tomato___Tomato_Yellow_Leaf_Curl_Virus': "Remove infected plants. Use resistant varieties and control whitefly vectors.",
'Tomato___Tomato_mosaic_virus': "Remove infected plants and disinfect tools. Use resistant seed varieties.",
'Tomato___healthy': "No action required. The plant is healthy.",
'Unknown': "No specific treatment available."
}
def edge_and_cut(img, threshold1, threshold2):
emb_img = img.copy()
edges = cv2.Canny(img, threshold1, threshold2)
edge_coors = []
for i in range(edges.shape[0]):
for j in range(edges.shape[1]):
if edges[i][j] != 0:
edge_coors.append((i, j))
if len(edge_coors) == 0:
return emb_img
row_min = edge_coors[np.argsort([coor[0] for coor in edge_coors])[0]][0]
row_max = edge_coors[np.argsort([coor[0] for coor in edge_coors])[-1]][0]
col_min = edge_coors[np.argsort([coor[1] for coor in edge_coors])[0]][1]
col_max = edge_coors[np.argsort([coor[1] for coor in edge_coors])[-1]][1]
new_img = img[row_min:row_max, col_min:col_max]
# Simple bounding box in white
emb_color = np.array([255], dtype=np.uint8)
emb_img[row_min-10:row_min+10, col_min:col_max] = emb_color
emb_img[row_max-10:row_max+10, col_min:col_max] = emb_color
emb_img[row_min:row_max, col_min-10:col_min+10] = emb_color
emb_img[row_min:row_max, col_max-10:col_max+10] = emb_color
return emb_img
def classify_and_visualize_keras(image):
# Preprocess the image
img_array = tf.image.resize(image, [256, 256])
img_array = tf.expand_dims(img_array, 0) / 255.0
# Make a prediction
prediction = keras_model.predict(img_array)
predicted_class_idx = tf.argmax(prediction[0], axis=-1).numpy()
confidence = np.max(prediction[0])
# Obtain the predicted label
predicted_label = class_names.get(predicted_class_idx, "Unknown")
if confidence < 0.60:
class_name = "Uncertain / Not in dataset"
bounded_image = image
treatment_text = "No treatment recommendation (uncertain prediction)."
else:
class_name = predicted_label
bounded_image = edge_and_cut(image, 200, 400)
treatment_text = keras_treatments.get(predicted_label, "No specific treatment available.")
return class_name, float(confidence), bounded_image, treatment_text
# ============== Combined Gradio App ==============
def main_model_selector(model_choice, image):
"""
Dispatch function based on user choice of model:
- 'Vit-model (Corn/Potato/Rice/Wheat)' -> use classify_image_vit
- 'Keras-model (Apple/Blueberry/Cherry/etc.)' -> use classify_and_visualize_keras
"""
if image is None:
return "No image provided.", None, None, None
if model_choice == "ViT (Corn, Potato, Rice, Wheat)":
# Return: label, treatment
predicted_label, treatment_text = classify_image_vit(image)
# For consistency with the Keras model outputs,
# we'll keep placeholders for confidence & bounding box
return predicted_label, None, image, treatment_text
elif model_choice == "Keras (Apple, Blueberry, Cherry, etc.)":
# Return: class_name, confidence, bounded_image, treatment_text
class_name, confidence, bounded_image, treatment_text = classify_and_visualize_keras(image)
return class_name, confidence, bounded_image, treatment_text
else:
return "Invalid model choice.", None, None, None
# Create Gradio interface
with gr.Blocks() as demo:
gr.Markdown("# **Plant Disease Detection**")
gr.Markdown(
"Select which model you want to use, then upload an image to see the prediction, "
"confidence (if applicable), bounding box (if applicable), and a suggested treatment."
)
with gr.Row():
model_choice = gr.Radio(
choices=["ViT (Corn, Potato, Rice, Wheat)", "Keras (Apple, Blueberry, Cherry, etc.)"],
value="Keras (Apple, Blueberry, Cherry, etc.)",
label="Select Model"
)
with gr.Row():
inp_image = gr.Image(type="numpy", label="Upload Leaf Image")
# Outputs
with gr.Row():
out_label = gr.Textbox(label="Predicted Class")
out_confidence = gr.Textbox(label="Confidence (If Available)")
out_bounded_image = gr.Image(label="Visualization (If Available)")
out_treatment = gr.Textbox(label="Treatment Recommendation")
# Button
btn = gr.Button("Classify")
# Function binding
btn.click(
fn=main_model_selector,
inputs=[model_choice, inp_image],
outputs=[out_label, out_confidence, out_bounded_image, out_treatment]
)
# Provide some example images
gr.Examples(
examples=[
["Keras (Apple, Blueberry, Cherry, etc.)", "corn.jpg"],
["Keras (Apple, Blueberry, Cherry, etc.)", "grot.jpg"],
["Keras (Apple, Blueberry, Cherry, etc.)", "Potato___Early_blight.jpg"],
["Keras (Apple, Blueberry, Cherry, etc.)", "Tomato___Target_Spot.jpg"],
["ViT (Corn, Potato, Rice, Wheat)", "corn.jpg"],
],
inputs=[model_choice, inp_image]
)
demo.launch(share=True)