yalouini committed · Commit f4698e1 · 1 Parent(s): 2f564aa

Create app.py

Start the app.py. WIP.

Files changed (1)
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
+ import streamlit as st
+ import torch
+ import torch.nn as nn
+ import timm
+ import torchvision.transforms as transforms
+ import pytorch_lightning as pl
+ from PIL import Image
+ import numpy as np
+ import segmentation_models_pytorch as smp
+
+ # The accompanying inference app
+
+ # Images available for segmentation
+ PATHS = ['1.tiff', '2.tiff']
+
+ # Model configuration (placeholders, still WIP)
+ MODEL_NAME = "se_resne"
+ MODEL_PATH = "model.ckpt"
+ BACKBONE = ""
+ IN_CHANNELS = ""
+ CLASSES = ""
+ # TODO: path to weights?
+ WEIGHTS = ""
+
+ NUM_CLASSES = len(CLASSES)
+ IDS_TO_CLASSES_DICT = dict(zip(range(NUM_CLASSES), CLASSES))
+
+ DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
+ TRANSFORM = transforms.Compose([transforms.ToTensor(),
+                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+
+ class VesuviusModel(nn.Module):
+     def __init__(self, weight=None):
+         super().__init__()
+         # TODO: the `weight` argument is not used yet (WIP)
+
+         self.encoder = smp.Unet(
+             encoder_name=BACKBONE,
+             encoder_weights=WEIGHTS,
+             in_channels=IN_CHANNELS,
+             classes=NUM_CLASSES,  # number of output channels
+             activation=None,
+         )
+
+     def forward(self, image):
+         output = self.encoder(image)
+         output = output.squeeze(-1)
+         return output
+
+
+ def load_weights_into_model(model_name: str, model_path: str) -> nn.Module:
+     # Build the model and load the fine-tuned weights from the Lightning
+     # checkpoint (model_name is not wired into the architecture yet).
+     model = VesuviusModel()
+     state_dict = torch.load(model_path, map_location=DEVICE)["state_dict"]
+     model.load_state_dict(state_dict)
+     return model
+
+
+ model = load_weights_into_model(MODEL_NAME, MODEL_PATH)
+ model.to(DEVICE)
+ model.eval()
+
+ img_path = st.selectbox('Select an image to segment', PATHS)
+
+ st.write('You have selected:', img_path)
+ img = Image.open(img_path)
+
+ st.image(img, caption='Selected image to segment')
+
+ np_img = np.array(img)
+
+ # Keep only the first three channels and add a batch dimension
+ input_batch = TRANSFORM(np_img[:, :, :3]).unsqueeze(0).to(DEVICE)
+
+ with st.spinner("Segmenting the image..."):
+     with torch.no_grad():
+         # TODO: Finish...
+         prediction = model(input_batch).cpu()
+         print(prediction)
+
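
A possible continuation of the TODO above, sketched here for reference rather than as part of this commit: convert the raw logits into a binary mask and display it with st.image. The sigmoid, the 0.5 threshold, and the single-channel output are illustrative assumptions, not settings taken from the model.

    probs = torch.sigmoid(prediction)            # logits -> probabilities (assumed binary setup)
    mask = (probs > 0.5).squeeze().numpy()       # illustrative threshold to a binary mask
    st.image((mask * 255).astype(np.uint8), caption="Predicted segmentation mask")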