iamomtiwari commited on
Commit
cca355f
·
verified ·
1 Parent(s): 49337b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -16
app.py CHANGED
@@ -33,19 +33,21 @@ class_labels = {
33
  # Mapping label indices to class labels
34
  labels_list = [class_labels[i]["label"] for i in range(1, 15)]
35
 
 
 
 
36
  # Inference function
37
- def predict(image, feedback):
38
  # First, use the crop disease model (ViT)
39
  inputs = feature_extractor(images=image, return_tensors="pt")
40
  with torch.no_grad():
41
  outputs = model(**inputs)
42
- predicted_class_idx = outputs.logits.argmax(-1).item()
43
-
44
- # Check if the predicted label corresponds to a crop disease
45
- predicted_label = labels_list[predicted_class_idx]
46
 
47
- # If the feedback is "no", switch to the fallback ResNet-50 model
48
- if feedback == "no":
49
  inputs_fallback = fallback_feature_extractor(images=image, return_tensors="pt")
50
  with torch.no_grad():
51
  outputs_fallback = fallback_model(**inputs_fallback)
@@ -53,20 +55,20 @@ def predict(image, feedback):
53
 
54
  # Get the fallback prediction label
55
  fallback_label = fallback_model.config.id2label[predicted_class_idx_fallback]
56
- return f"Fallback Prediction (ResNet-50): {fallback_label}"
57
 
58
- # If feedback is "yes", return the initial disease prediction and treatment advice
59
- if predicted_class_idx < len(class_labels): # It's a crop disease
60
- treatment_advice = class_labels[predicted_class_idx + 1]["treatment"]
61
- return f"Disease: {predicted_label}\n\nTreatment Advice: {treatment_advice}"
62
- else:
63
- return f"Disease: {predicted_label}"
64
 
65
  # Create Gradio Interface
66
  interface = gr.Interface(
67
  fn=predict,
68
- inputs=["image", gr.Radio(["yes", "no"], label="Is it a crop like corn, rice, sugarcane, wheat??")],
69
- outputs="text"
 
 
70
  )
71
 
72
  if __name__ == "__main__":
 
33
  # Mapping label indices to class labels
34
  labels_list = [class_labels[i]["label"] for i in range(1, 15)]
35
 
36
+ # Confidence threshold for ViT model
37
+ CONFIDENCE_THRESHOLD = 0.5
38
+
39
  # Inference function
40
+ def predict(image):
41
  # First, use the crop disease model (ViT)
42
  inputs = feature_extractor(images=image, return_tensors="pt")
43
  with torch.no_grad():
44
  outputs = model(**inputs)
45
+ logits = outputs.logits
46
+ predicted_class_idx = logits.argmax(-1).item()
47
+ confidence = torch.softmax(logits, dim=-1)[0, predicted_class_idx].item()
 
48
 
49
+ # If confidence is below the threshold, directly switch to ResNet50
50
+ if confidence < CONFIDENCE_THRESHOLD:
51
  inputs_fallback = fallback_feature_extractor(images=image, return_tensors="pt")
52
  with torch.no_grad():
53
  outputs_fallback = fallback_model(**inputs_fallback)
 
55
 
56
  # Get the fallback prediction label
57
  fallback_label = fallback_model.config.id2label[predicted_class_idx_fallback]
58
+ return f"Low confidence in ViT prediction. ResNet-50 Prediction: {fallback_label}"
59
 
60
+ # If confidence is above the threshold, return the ViT prediction and treatment advice
61
+ predicted_label = labels_list[predicted_class_idx]
62
+ treatment_advice = class_labels[predicted_class_idx + 1]["treatment"]
63
+ return f"Disease: {predicted_label}\n\nTreatment Advice: {treatment_advice}"
 
 
64
 
65
  # Create Gradio Interface
66
  interface = gr.Interface(
67
  fn=predict,
68
+ inputs="image",
69
+ outputs="text",
70
+ title="Crop Disease Detection",
71
+ description="Upload an image of a crop plant to detect diseases. If confidence is low, ResNet-50 will classify the image."
72
  )
73
 
74
  if __name__ == "__main__":