wayandadang
commited on
Commit
·
ce309f9
1
Parent(s):
01f0a3d
update app.py
Browse files- app.py +32 -18
- desktop.ini +2 -0
- report_model.png +0 -0
- requirements.txt +2 -1
app.py
CHANGED
@@ -13,33 +13,49 @@ class CNNKAN(nn.Module):
|
|
13 |
def __init__(self):
|
14 |
super(CNNKAN, self).__init__()
|
15 |
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
|
|
|
16 |
self.pool1 = nn.MaxPool2d(2)
|
17 |
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
|
|
|
18 |
self.pool2 = nn.MaxPool2d(2)
|
19 |
-
self.
|
|
|
|
|
|
|
|
|
20 |
self.kan2 = KANLinear(256, 1)
|
21 |
|
22 |
def forward(self, x):
|
23 |
-
x = F.selu(self.conv1(x))
|
24 |
x = self.pool1(x)
|
25 |
-
x = F.selu(self.conv2(x))
|
26 |
x = self.pool2(x)
|
|
|
|
|
27 |
x = x.view(x.size(0), -1)
|
|
|
28 |
x = self.kan1(x)
|
|
|
29 |
x = self.kan2(x)
|
30 |
return x
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
model
|
35 |
-
model.
|
36 |
-
model
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
-
|
39 |
-
transform = transforms.Compose([
|
40 |
-
|
41 |
-
|
42 |
-
])
|
|
|
43 |
|
44 |
# Streamlit app
|
45 |
st.title("Image Classification with CNN-KAN")
|
@@ -48,10 +64,8 @@ st.sidebar.title("Upload Images")
|
|
48 |
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "webp"])
|
49 |
image_url = st.sidebar.text_input("Or enter image URL...")
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
img = Image.open(BytesIO(response.content)).convert('RGB')
|
54 |
-
return img
|
55 |
|
56 |
img = None
|
57 |
|
@@ -66,7 +80,7 @@ elif image_url:
|
|
66 |
if img is not None:
|
67 |
st.image(np.array(img), caption='Uploaded Image.', use_column_width=True)
|
68 |
if st.button('Predict'):
|
69 |
-
img_tensor =
|
70 |
|
71 |
with torch.no_grad():
|
72 |
output = model(img_tensor)
|
|
|
13 |
def __init__(self):
|
14 |
super(CNNKAN, self).__init__()
|
15 |
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
|
16 |
+
self.bn1 = nn.BatchNorm2d(32)
|
17 |
self.pool1 = nn.MaxPool2d(2)
|
18 |
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
|
19 |
+
self.bn2 = nn.BatchNorm2d(64)
|
20 |
self.pool2 = nn.MaxPool2d(2)
|
21 |
+
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
|
22 |
+
self.bn3 = nn.BatchNorm2d(128)
|
23 |
+
self.pool3 = nn.MaxPool2d(2)
|
24 |
+
self.dropout = nn.Dropout(0.5)
|
25 |
+
self.kan1 = KANLinear(128 * 25 * 25, 256)
|
26 |
self.kan2 = KANLinear(256, 1)
|
27 |
|
28 |
def forward(self, x):
|
29 |
+
x = F.selu(self.bn1(self.conv1(x)))
|
30 |
x = self.pool1(x)
|
31 |
+
x = F.selu(self.bn2(self.conv2(x)))
|
32 |
x = self.pool2(x)
|
33 |
+
x = F.selu(self.bn3(self.conv3(x)))
|
34 |
+
x = self.pool3(x)
|
35 |
x = x.view(x.size(0), -1)
|
36 |
+
x = self.dropout(x)
|
37 |
x = self.kan1(x)
|
38 |
+
x = self.dropout(x)
|
39 |
x = self.kan2(x)
|
40 |
return x
|
41 |
|
42 |
+
def load_model(weights_path, device):
|
43 |
+
model = CNNKAN().to(device)
|
44 |
+
model.load_state_dict(torch.load(weights_path, map_location=device))
|
45 |
+
model.eval()
|
46 |
+
return model
|
47 |
+
|
48 |
+
def load_image_from_url(url):
|
49 |
+
response = requests.get(url)
|
50 |
+
img = Image.open(BytesIO(response.content)).convert('RGB')
|
51 |
+
return img
|
52 |
|
53 |
+
def preprocess_image(image):
|
54 |
+
transform = transforms.Compose([
|
55 |
+
transforms.Resize((200, 200)),
|
56 |
+
transforms.ToTensor()
|
57 |
+
])
|
58 |
+
return transform(image).unsqueeze(0)
|
59 |
|
60 |
# Streamlit app
|
61 |
st.title("Image Classification with CNN-KAN")
|
|
|
64 |
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "webp"])
|
65 |
image_url = st.sidebar.text_input("Or enter image URL...")
|
66 |
|
67 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
68 |
+
model = load_model('weights/best_model_weights_KAN.pth', device)
|
|
|
|
|
69 |
|
70 |
img = None
|
71 |
|
|
|
80 |
if img is not None:
|
81 |
st.image(np.array(img), caption='Uploaded Image.', use_column_width=True)
|
82 |
if st.button('Predict'):
|
83 |
+
img_tensor = preprocess_image(img).to(device)
|
84 |
|
85 |
with torch.no_grad():
|
86 |
output = model(img_tensor)
|
desktop.ini
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
[LocalizedFileNames]
|
2 |
+
Screenshot 2024-06-05 110741.png=@Screenshot 2024-06-05 110741.png,0
|
report_model.png
ADDED
requirements.txt
CHANGED
@@ -75,4 +75,5 @@ ultralytics==8.1.30
|
|
75 |
urllib3==2.2.1
|
76 |
watchdog==4.0.0
|
77 |
pafy
|
78 |
-
youtube-dl
|
|
|
|
75 |
urllib3==2.2.1
|
76 |
watchdog==4.0.0
|
77 |
pafy
|
78 |
+
youtube-dl
|
79 |
+
optuna
|