iamomtiwari commited on
Commit
1867c2e
·
verified ·
1 Parent(s): 70ad7af

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -0
app.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ from transformers import ViTFeatureExtractor, ViTForImageClassification
4
+
5
+ # Load the model and feature extractor
6
+ feature_extractor = ViTFeatureExtractor.from_pretrained('wambugu1738/crop_leaf_diseases_vit')
7
+ model = ViTForImageClassification.from_pretrained('wambugu1738/crop_leaf_diseases_vit')
8
+
9
+ # Define prediction function
10
+ def predict(image):
11
+ image = Image.fromarray(image) # Convert image from numpy array to PIL Image
12
+ inputs = feature_extractor(images=image, return_tensors="pt")
13
+ outputs = model(**inputs)
14
+ logits = outputs.logits
15
+ predicted_class_idx = logits.argmax(-1).item()
16
+ return model.config.id2label[predicted_class_idx]
17
+
18
+ # Create Gradio interface
19
+ iface = gr.Interface(
20
+ fn=predict,
21
+ inputs=gr.inputs.Image(type="numpy"), # Input type as a numpy array
22
+ outputs="text",
23
+ title="Crop Disease Detection",
24
+ description="Upload an image of a crop leaf to detect diseases."
25
+ )
26
+
27
+ iface.launch()