kadirnar commited on
Commit
c018310
·
1 Parent(s): b57130c

Upload 3 files

Browse files
Files changed (3) hide show
  1. dataloader.py +55 -0
  2. download.py +17 -0
  3. istanbul_unet.py +21 -0
dataloader.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import albumentations as albu
2
+ import numpy as np
3
+ import cv2
4
+ import os
5
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
6
+
7
+
8
+ class Dataset:
9
+ def __init__(
10
+ self,
11
+ image_path,
12
+ augmentation=None,
13
+ preprocessing=None,
14
+ ):
15
+ self.pil_image = image_path
16
+ self.augmentation = augmentation
17
+ self.preprocessing = preprocessing
18
+
19
+ def get(self):
20
+ # pil image > numpy array
21
+ image = np.array(self.pil_image)
22
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
23
+
24
+ # apply augmentations
25
+ if self.augmentation:
26
+ sample = self.augmentation(image=image)
27
+ image = sample['image']
28
+
29
+ # apply preprocessing
30
+ if self.preprocessing:
31
+ sample = self.preprocessing(image=image)
32
+ image = sample['image']
33
+
34
+ return image
35
+
36
+
37
+ def get_validation_augmentation():
38
+ """Add paddings to make image shape divisible by 32"""
39
+ test_transform = [
40
+ albu.PadIfNeeded(384, 480)
41
+ ]
42
+ return albu.Compose(test_transform)
43
+
44
+
45
+ def to_tensor(x, **kwargs):
46
+ return x.transpose(2, 0, 1).astype('float32')
47
+
48
+
49
+ def get_preprocessing(preprocessing_fn):
50
+
51
+ _transform = [
52
+ albu.Lambda(image=preprocessing_fn),
53
+ albu.Lambda(image=to_tensor),
54
+ ]
55
+ return albu.Compose(_transform)
download.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def attempt_download_from_hub(repo_id, hf_token=None):
2
+ # https://github.com/fcakyon/yolov5-pip/blob/main/yolov5/utils/downloads.py
3
+ from huggingface_hub import hf_hub_download, list_repo_files
4
+ from huggingface_hub.utils._errors import RepositoryNotFoundError
5
+ from huggingface_hub.utils._validators import HFValidationError
6
+ try:
7
+ repo_files = list_repo_files(repo_id=repo_id, repo_type='model', token=hf_token)
8
+ model_file = [f for f in repo_files if f.endswith('.pth')][0]
9
+ file = hf_hub_download(
10
+ repo_id=repo_id,
11
+ filename=model_file,
12
+ repo_type='model',
13
+ token=hf_token,
14
+ )
15
+ return file
16
+ except (RepositoryNotFoundError, HFValidationError):
17
+ return None
istanbul_unet.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from download import attempt_download_from_hub
2
+ import segmentation_models_pytorch as smp
3
+ from dataloader import *
4
+ import torch
5
+
6
+
7
+ def unet_prediction(input_path, model_path):
8
+ model_path = attempt_download_from_hub(model_path)
9
+ best_model = torch.load(model_path)
10
+ preprocessing_fn = smp.encoders.get_preprocessing_fn('efficientnet-b6', 'imagenet')
11
+
12
+ test_dataset = Dataset(input_path, augmentation=get_validation_augmentation(), preprocessing=get_preprocessing(preprocessing_fn))
13
+ image = test_dataset.get()
14
+
15
+ x_tensor = torch.from_numpy(image).to("cuda").unsqueeze(0)
16
+ pr_mask = best_model.predict(x_tensor)
17
+ pr_mask = (pr_mask.squeeze().cpu().numpy().round())*255
18
+
19
+ # Save the predicted mask
20
+ cv2.imwrite("output.png", pr_mask)
21
+ return 'output.png'