wayandadang commited on
Commit
ce309f9
·
1 Parent(s): 01f0a3d

update app.py

Browse files
Files changed (4) hide show
  1. app.py +32 -18
  2. desktop.ini +2 -0
  3. report_model.png +0 -0
  4. requirements.txt +2 -1
app.py CHANGED
@@ -13,33 +13,49 @@ class CNNKAN(nn.Module):
13
  def __init__(self):
14
  super(CNNKAN, self).__init__()
15
  self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
 
16
  self.pool1 = nn.MaxPool2d(2)
17
  self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
 
18
  self.pool2 = nn.MaxPool2d(2)
19
- self.kan1 = KANLinear(64 * 50 * 50, 256)
 
 
 
 
20
  self.kan2 = KANLinear(256, 1)
21
 
22
  def forward(self, x):
23
- x = F.selu(self.conv1(x))
24
  x = self.pool1(x)
25
- x = F.selu(self.conv2(x))
26
  x = self.pool2(x)
 
 
27
  x = x.view(x.size(0), -1)
 
28
  x = self.kan1(x)
 
29
  x = self.kan2(x)
30
  return x
31
 
32
- # Assuming the model weights are saved in 'model.pth'
33
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
- model = CNNKAN().to(device)
35
- model.load_state_dict(torch.load('weights/model_weights_KAN1.pth', map_location=device))
36
- model.eval()
 
 
 
 
 
37
 
38
- # Define image transformations
39
- transform = transforms.Compose([
40
- transforms.Resize((200, 200)),
41
- transforms.ToTensor()
42
- ])
 
43
 
44
  # Streamlit app
45
  st.title("Image Classification with CNN-KAN")
@@ -48,10 +64,8 @@ st.sidebar.title("Upload Images")
48
  uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "webp"])
49
  image_url = st.sidebar.text_input("Or enter image URL...")
50
 
51
- def load_image_from_url(url):
52
- response = requests.get(url)
53
- img = Image.open(BytesIO(response.content)).convert('RGB')
54
- return img
55
 
56
  img = None
57
 
@@ -66,7 +80,7 @@ elif image_url:
66
  if img is not None:
67
  st.image(np.array(img), caption='Uploaded Image.', use_column_width=True)
68
  if st.button('Predict'):
69
- img_tensor = transform(img).unsqueeze(0).to(device)
70
 
71
  with torch.no_grad():
72
  output = model(img_tensor)
 
13
  def __init__(self):
14
  super(CNNKAN, self).__init__()
15
  self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
16
+ self.bn1 = nn.BatchNorm2d(32)
17
  self.pool1 = nn.MaxPool2d(2)
18
  self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
19
+ self.bn2 = nn.BatchNorm2d(64)
20
  self.pool2 = nn.MaxPool2d(2)
21
+ self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
22
+ self.bn3 = nn.BatchNorm2d(128)
23
+ self.pool3 = nn.MaxPool2d(2)
24
+ self.dropout = nn.Dropout(0.5)
25
+ self.kan1 = KANLinear(128 * 25 * 25, 256)
26
  self.kan2 = KANLinear(256, 1)
27
 
28
  def forward(self, x):
29
+ x = F.selu(self.bn1(self.conv1(x)))
30
  x = self.pool1(x)
31
+ x = F.selu(self.bn2(self.conv2(x)))
32
  x = self.pool2(x)
33
+ x = F.selu(self.bn3(self.conv3(x)))
34
+ x = self.pool3(x)
35
  x = x.view(x.size(0), -1)
36
+ x = self.dropout(x)
37
  x = self.kan1(x)
38
+ x = self.dropout(x)
39
  x = self.kan2(x)
40
  return x
41
 
42
+ def load_model(weights_path, device):
43
+ model = CNNKAN().to(device)
44
+ model.load_state_dict(torch.load(weights_path, map_location=device))
45
+ model.eval()
46
+ return model
47
+
48
+ def load_image_from_url(url):
49
+ response = requests.get(url)
50
+ img = Image.open(BytesIO(response.content)).convert('RGB')
51
+ return img
52
 
53
+ def preprocess_image(image):
54
+ transform = transforms.Compose([
55
+ transforms.Resize((200, 200)),
56
+ transforms.ToTensor()
57
+ ])
58
+ return transform(image).unsqueeze(0)
59
 
60
  # Streamlit app
61
  st.title("Image Classification with CNN-KAN")
 
64
  uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "webp"])
65
  image_url = st.sidebar.text_input("Or enter image URL...")
66
 
67
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
68
+ model = load_model('weights/best_model_weights_KAN.pth', device)
 
 
69
 
70
  img = None
71
 
 
80
  if img is not None:
81
  st.image(np.array(img), caption='Uploaded Image.', use_column_width=True)
82
  if st.button('Predict'):
83
+ img_tensor = preprocess_image(img).to(device)
84
 
85
  with torch.no_grad():
86
  output = model(img_tensor)
desktop.ini ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [LocalizedFileNames]
2
+ Screenshot 2024-06-05 110741.png=@Screenshot 2024-06-05 110741.png,0
report_model.png ADDED
requirements.txt CHANGED
@@ -75,4 +75,5 @@ ultralytics==8.1.30
75
  urllib3==2.2.1
76
  watchdog==4.0.0
77
  pafy
78
- youtube-dl
 
 
75
  urllib3==2.2.1
76
  watchdog==4.0.0
77
  pafy
78
+ youtube-dl
79
+ optuna