iamomtiwari commited on
Commit
5f5775f
·
verified ·
1 Parent(s): bb13609

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -15
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  from transformers import ViTForImageClassification, ViTFeatureExtractor
3
  from PIL import Image
4
 
@@ -6,22 +7,38 @@ from PIL import Image
6
  model = ViTForImageClassification.from_pretrained("iamomtiwari/VITPEST")
7
  feature_extractor = ViTFeatureExtractor.from_pretrained("iamomtiwari/VITPEST")
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def predict(image):
10
- # Preprocess image
11
  inputs = feature_extractor(images=image, return_tensors="pt")
12
- outputs = model(**inputs)
13
- logits = outputs.logits
14
- predicted_class_idx = logits.argmax(-1).item()
15
- class_labels = model.config.id2label # Assuming you have a class label mapping
16
- return class_labels[predicted_class_idx]
17
-
18
- # Gradio interface
19
- interface = gr.Interface(
20
- fn=predict,
21
- inputs=gr.Image(type="pil"),
22
- outputs="text",
23
- title="Crop Disease Classifier",
24
- description="Upload an image of a crop leaf to classify the disease."
25
- )
26
 
 
 
27
  interface.launch()
 
1
  import gradio as gr
2
+ import torch
3
  from transformers import ViTForImageClassification, ViTFeatureExtractor
4
  from PIL import Image
5
 
 
7
  model = ViTForImageClassification.from_pretrained("iamomtiwari/VITPEST")
8
  feature_extractor = ViTFeatureExtractor.from_pretrained("iamomtiwari/VITPEST")
9
 
10
+ # Define class labels and treatment advice
11
+ class_labels = {
12
+ "Corn___Common_Rust": "Apply fungicides as soon as symptoms are noticed. Practice crop rotation and remove infected plants.",
13
+ "Corn___Gray_Leaf_Spot": "Rotate crops to non-host plants, apply resistant varieties, and use fungicides as needed.",
14
+ "Corn___Healthy": "Continue good agricultural practices: ensure proper irrigation, nutrient supply, and monitor for pests.",
15
+ "Corn___Northern_Leaf_Blight": "Remove and destroy infected plant debris, apply fungicides, and rotate crops.",
16
+ "Rice___Brown_Spot": "Use resistant varieties, improve field drainage, and apply fungicides if necessary.",
17
+ "Rice___Healthy": "Maintain proper irrigation, fertilization, and pest control measures.",
18
+ "Rice___Leaf_Blast": "Use resistant varieties, apply fungicides during high-risk periods, and practice good field management.",
19
+ "Rice___Neck_Blast": "Plant resistant varieties, improve nutrient management, and apply fungicides if symptoms appear.",
20
+ "Wheat___Brown_Rust": "Apply fungicides and practice crop rotation with non-host crops.",
21
+ "Wheat___Healthy": "Continue with good management practices, including proper fertilization and weed control.",
22
+ "Wheat___Yellow_Rust": "Use resistant varieties, apply fungicides, and rotate crops.",
23
+ "Sugarcane__Red_Rot": "Plant resistant varieties and ensure good drainage.",
24
+ "Sugarcane__Healthy": "Maintain healthy soil conditions and proper irrigation.",
25
+ "Sugarcane__Bacterial Blight": "Use disease-free planting material, practice crop rotation, and destroy infected plants."
26
+ }
27
+
28
+ # Mapping label indices to class labels
29
+ labels_list = list(class_labels.keys())
30
+
31
+ # Inference function
32
  def predict(image):
 
33
  inputs = feature_extractor(images=image, return_tensors="pt")
34
+ with torch.no_grad():
35
+ outputs = model(**inputs)
36
+ predicted_class_idx = outputs.logits.argmax(-1).item()
37
+ predicted_label = labels_list[predicted_class_idx]
38
+ treatment_advice = class_labels[predicted_label]
39
+
40
+ return f"Disease: {predicted_label}\n\nTreatment Advice: {treatment_advice}"
 
 
 
 
 
 
 
41
 
42
+ # Create Gradio Interface
43
+ interface = gr.Interface(fn=predict, inputs="image", outputs="text")
44
  interface.launch()