manikandan9943114590 commited on
Commit
45be811
·
verified ·
1 Parent(s): 4758f10

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -55
app.py CHANGED
@@ -1,55 +1,55 @@
1
- import torch
2
- import gradio as gr
3
- from src.model import DRModel
4
- from torchvision import transforms as T
5
-
6
- CHECKPOINT_PATH = "artifacts/dr-model.ckpt"
7
- model = DRModel.load_from_checkpoint(CHECKPOINT_PATH, map_location="cpu")
8
- model.eval()
9
-
10
- labels = {
11
- 0: "No DR",
12
- 1: "Mild",
13
- 2: "Moderate",
14
- 3: "Severe",
15
- 4: "Proliferative DR",
16
- }
17
-
18
- transform = T.Compose(
19
- [
20
- T.Resize((224, 224)),
21
- T.ToTensor(),
22
- T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
23
- ]
24
- )
25
-
26
-
27
- # Define the prediction function
28
- def predict(input_img):
29
- input_img = transform(input_img).unsqueeze(0)
30
- with torch.no_grad():
31
- prediction = torch.nn.functional.softmax(model(input_img)[0], dim=0)
32
- confidences = {labels[i]: float(prediction[i]) for i in labels}
33
- return confidences
34
-
35
-
36
- # Set up the Gradio app interface
37
- dr_app = gr.Interface(
38
- fn=predict,
39
- inputs=gr.Image(type="pil"),
40
- outputs=gr.Label(),
41
- title="Diabetic Retinopathy Detection App",
42
- description="Welcome to our Diabetic Retinopathy Detection App! \
43
- This app utilizes deep learning models to detect diabetic retinopathy in retinal images.\
44
- Diabetic retinopathy is a common complication of diabetes and early detection is crucial for effective treatment.",
45
- examples=[
46
- "data/sample/10_left.jpeg",
47
- "data/sample/10_right.jpeg",
48
- "data/sample/15_left.jpeg",
49
- "data/sample/16_right.jpeg",
50
- ],
51
- )
52
-
53
- # Run the Gradio app
54
- if __name__ == "__main__":
55
- dr_app.launch()
 
1
+ import torch
2
+ import gradio as gr
3
+ from src.model import DRModel
4
+ from torchvision import transforms as T
5
+
6
+ CHECKPOINT_PATH = "dr-model.ckpt"
7
+ model = DRModel.load_from_checkpoint(CHECKPOINT_PATH, map_location="cpu")
8
+ model.eval()
9
+
10
+ labels = {
11
+ 0: "No DR",
12
+ 1: "Mild",
13
+ 2: "Moderate",
14
+ 3: "Severe",
15
+ 4: "Proliferative DR",
16
+ }
17
+
18
+ transform = T.Compose(
19
+ [
20
+ T.Resize((224, 224)),
21
+ T.ToTensor(),
22
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
23
+ ]
24
+ )
25
+
26
+
27
+ # Define the prediction function
28
+ def predict(input_img):
29
+ input_img = transform(input_img).unsqueeze(0)
30
+ with torch.no_grad():
31
+ prediction = torch.nn.functional.softmax(model(input_img)[0], dim=0)
32
+ confidences = {labels[i]: float(prediction[i]) for i in labels}
33
+ return confidences
34
+
35
+
36
+ # Set up the Gradio app interface
37
+ dr_app = gr.Interface(
38
+ fn=predict,
39
+ inputs=gr.Image(type="pil"),
40
+ outputs=gr.Label(),
41
+ title="Diabetic Retinopathy Detection App",
42
+ description="Welcome to our Diabetic Retinopathy Detection App! \
43
+ This app utilizes deep learning models to detect diabetic retinopathy in retinal images.\
44
+ Diabetic retinopathy is a common complication of diabetes and early detection is crucial for effective treatment.",
45
+ examples=[
46
+ "data/sample/10_left.jpeg",
47
+ "data/sample/10_right.jpeg",
48
+ "data/sample/15_left.jpeg",
49
+ "data/sample/16_right.jpeg",
50
+ ],
51
+ )
52
+
53
+ # Run the Gradio app
54
+ if __name__ == "__main__":
55
+ dr_app.launch()