iamomtiwari commited on
Commit
acf9f46
·
verified ·
1 Parent(s): a5bca54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -6
app.py CHANGED
@@ -4,7 +4,7 @@ from transformers import ViTForImageClassification, ViTFeatureExtractor
4
  from transformers import AutoModelForImageClassification, AutoFeatureExtractor
5
  from PIL import Image
6
 
7
- # Load crop model (ViT)
8
  model = ViTForImageClassification.from_pretrained("iamomtiwari/VITPEST")
9
  feature_extractor = ViTFeatureExtractor.from_pretrained("iamomtiwari/VITPEST")
10
 
@@ -12,6 +12,10 @@ feature_extractor = ViTFeatureExtractor.from_pretrained("iamomtiwari/VITPEST")
12
  fallback_model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-50")
13
  fallback_feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-50")
14
 
 
 
 
 
15
  # Define class labels with treatment advice
16
  class_labels = {
17
  1: {"label": "Stage Corn___Common_Rust", "treatment": "Apply fungicides as soon as symptoms are noticed. Practice crop rotation and remove infected plants."},
@@ -34,7 +38,7 @@ class_labels = {
34
  labels_list = [class_labels[i]["label"] for i in range(1, 15)]
35
 
36
  # Inference function
37
- def predict(image):
38
  # First, use the crop disease model (ViT)
39
  inputs = feature_extractor(images=image, return_tensors="pt")
40
  with torch.no_grad():
@@ -44,10 +48,22 @@ def predict(image):
44
  # Check if the predicted label corresponds to a crop disease
45
  predicted_label = labels_list[predicted_class_idx]
46
 
47
- # Ask user if the prediction is correct
 
 
 
 
 
 
 
 
 
 
 
 
48
  if predicted_class_idx < len(class_labels): # It's a crop disease
49
  treatment_advice = class_labels[predicted_class_idx + 1]["treatment"]
50
- return f"Disease: {predicted_label}\n\nTreatment Advice: {treatment_advice}\n\nIs this prediction correct? (yes/no)"
51
  else:
52
  # If not a crop disease, use the fallback model (ResNet50) for general object detection
53
  inputs_fallback = fallback_feature_extractor(images=image, return_tensors="pt")
@@ -57,8 +73,12 @@ def predict(image):
57
 
58
  # Get the fallback prediction label
59
  fallback_label = fallback_model.config.id2label[predicted_class_idx_fallback]
60
- return f"Fallback Prediction (Not a Crop): {fallback_label}\n\nIs this prediction correct? (yes/no)"
61
 
62
  # Create Gradio Interface
63
- interface = gr.Interface(fn=predict, inputs="image", outputs="text")
 
 
 
 
64
  interface.launch()
 
4
  from transformers import AutoModelForImageClassification, AutoFeatureExtractor
5
  from PIL import Image
6
 
7
+ # Load crop disease model (ViT)
8
  model = ViTForImageClassification.from_pretrained("iamomtiwari/VITPEST")
9
  feature_extractor = ViTFeatureExtractor.from_pretrained("iamomtiwari/VITPEST")
10
 
 
12
  fallback_model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-50")
13
  fallback_feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-50")
14
 
15
+ # Load additional ViT model (221k model) for a different classification if the user feedback is "no"
16
+ vit_221k_model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k")
17
+ vit_221k_feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
18
+
19
  # Define class labels with treatment advice
20
  class_labels = {
21
  1: {"label": "Stage Corn___Common_Rust", "treatment": "Apply fungicides as soon as symptoms are noticed. Practice crop rotation and remove infected plants."},
 
38
  labels_list = [class_labels[i]["label"] for i in range(1, 15)]
39
 
40
  # Inference function
41
+ def predict(image, feedback):
42
  # First, use the crop disease model (ViT)
43
  inputs = feature_extractor(images=image, return_tensors="pt")
44
  with torch.no_grad():
 
48
  # Check if the predicted label corresponds to a crop disease
49
  predicted_label = labels_list[predicted_class_idx]
50
 
51
+ # If the feedback is "no", switch to ViT 221k model for a different class prediction
52
+ if feedback == "no":
53
+ # Use ViT 221k model
54
+ inputs_vit_221k = vit_221k_feature_extractor(images=image, return_tensors="pt")
55
+ with torch.no_grad():
56
+ outputs_vit_221k = vit_221k_model(**inputs_vit_221k)
57
+ predicted_class_idx_vit_221k = outputs_vit_221k.logits.argmax(-1).item()
58
+
59
+ # Get the ViT 221k prediction label
60
+ vit_221k_label = vit_221k_model.config.id2label[predicted_class_idx_vit_221k]
61
+ return f"Fallback ViT 221k Prediction: {vit_221k_label}"
62
+
63
+ # If feedback is "yes", return the initial disease prediction and treatment advice
64
  if predicted_class_idx < len(class_labels): # It's a crop disease
65
  treatment_advice = class_labels[predicted_class_idx + 1]["treatment"]
66
+ return f"Disease: {predicted_label}\n\nTreatment Advice: {treatment_advice}"
67
  else:
68
  # If not a crop disease, use the fallback model (ResNet50) for general object detection
69
  inputs_fallback = fallback_feature_extractor(images=image, return_tensors="pt")
 
73
 
74
  # Get the fallback prediction label
75
  fallback_label = fallback_model.config.id2label[predicted_class_idx_fallback]
76
+ return f"Fallback Prediction (Not a Crop): {fallback_label}"
77
 
78
  # Create Gradio Interface
79
+ interface = gr.Interface(
80
+ fn=predict,
81
+ inputs=["image", gr.Radio(["yes", "no"], label="Is the prediction correct?")],
82
+ outputs="text"
83
+ )
84
  interface.launch()