iamomtiwari commited on
Commit
bb13609
·
verified ·
1 Parent(s): e4ea851

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -13
app.py CHANGED
@@ -1,22 +1,27 @@
1
  import gradio as gr
2
  from transformers import ViTForImageClassification, ViTFeatureExtractor
3
  from PIL import Image
4
- import torch
5
 
6
- # Load model and feature extractor from Hugging Face Hub
7
  model = ViTForImageClassification.from_pretrained("iamomtiwari/VITPEST")
8
  feature_extractor = ViTFeatureExtractor.from_pretrained("iamomtiwari/VITPEST")
9
 
10
- # Define inference function
11
- def classify_image(img):
12
- inputs = feature_extractor(images=img, return_tensors="pt")
13
- with torch.no_grad():
14
- outputs = model(**inputs)
15
- logits = outputs.logits
16
- predicted_class_idx = torch.argmax(logits, dim=-1).item()
17
- class_name = model.config.id2label[predicted_class_idx]
18
- return class_name
 
 
 
 
 
 
 
 
19
 
20
- # Create Gradio interface
21
- interface = gr.Interface(fn=classify_image, inputs=gr.Image(type="pil"), outputs="text")
22
  interface.launch()
 
1
  import gradio as gr
2
  from transformers import ViTForImageClassification, ViTFeatureExtractor
3
  from PIL import Image
 
4
 
5
+ # Load model and feature extractor
6
  model = ViTForImageClassification.from_pretrained("iamomtiwari/VITPEST")
7
  feature_extractor = ViTFeatureExtractor.from_pretrained("iamomtiwari/VITPEST")
8
 
9
+ def predict(image):
10
+ # Preprocess image
11
+ inputs = feature_extractor(images=image, return_tensors="pt")
12
+ outputs = model(**inputs)
13
+ logits = outputs.logits
14
+ predicted_class_idx = logits.argmax(-1).item()
15
+ class_labels = model.config.id2label # Assuming you have a class label mapping
16
+ return class_labels[predicted_class_idx]
17
+
18
+ # Gradio interface
19
+ interface = gr.Interface(
20
+ fn=predict,
21
+ inputs=gr.Image(type="pil"),
22
+ outputs="text",
23
+ title="Crop Disease Classifier",
24
+ description="Upload an image of a crop leaf to classify the disease."
25
+ )
26
 
 
 
27
  interface.launch()