turhancan97 committed
Commit 929f451
Parent(s): b7d4bcf

app file created
Files changed (6)
  1. app.py +101 -0
  2. images/cat.jpg +0 -0
  3. images/dog.jpg +0 -0
  4. model.py +220 -0
  5. requirements.txt +16 -0
  6. vit-t-mae-pretrain.pt +3 -0
app.py ADDED
@@ -0,0 +1,101 @@
+ import gradio as gr
+
+ import torch
+ import torchvision
+ from PIL import Image
+ import numpy as np
+ import random
+ from einops import rearrange
+ import matplotlib.pyplot as plt
+ from torchvision.transforms import v2
+
+ from model import MAE_ViT, MAE_Encoder, MAE_Decoder, MAE_Encoder_FeatureExtractor
+
+ # Example images offered in the demo UI
+ path = [['images/cat.jpg'], ['images/dog.jpg']]
+ model_name = "vit-t-mae-pretrain.pt"
+ # The checkpoint stores the full pickled MAE_ViT module, not just a state_dict
+ model = torch.load(model_name, map_location='cpu')
+
+ model.eval()
+ device = torch.device("cpu")
+ model.to(device)
+
+ transform = v2.Compose([
+     v2.Resize((32, 32)),
+     v2.ToTensor(),
+     v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ ])
+
+ # Load and preprocess the image
+ def load_image(image_path, transform):
+     img = Image.open(image_path).convert('RGB')
+     img = transform(img).unsqueeze(0)  # add batch dimension
+     return img
+
+ def show_image(img, title):
+     img = rearrange(img, "c h w -> h w c")
+     img = (img.cpu().detach().numpy() + 1) / 2  # map from [-1, 1] back to [0, 1]
+
+     plt.imshow(img)
+     plt.axis('off')
+     plt.title(title)
+
+ # Visualize a single image: original, masked input, reconstruction, reconstruction + visible
+ def visualize_single_image(image_path, image_name, model, device):
+     img = load_image(image_path, transform).to(device)
+
+     # Run inference
+     model.eval()
+     with torch.no_grad():
+         predicted_img, mask = model(img)
+
+     # Masked image (visible patches only)
+     im_masked = img * (1 - mask)
+
+     # MAE reconstruction pasted with visible patches
+     im_paste = img * (1 - mask) + predicted_img * mask
+
+     # make the figure larger
+     fig = plt.figure(figsize=(12, 4))
+
+     plt.subplot(1, 4, 1)
+     show_image(img[0], "original")
+
+     plt.subplot(1, 4, 2)
+     show_image(im_masked[0], "masked")
+
+     plt.subplot(1, 4, 3)
+     show_image(predicted_img[0], "reconstruction")
+
+     plt.subplot(1, 4, 4)
+     show_image(im_paste[0], "reconstruction + visible")
+
+     plt.tight_layout()
+
+     return fig
+
+ # Example usage
+ image_path = 'images/dog.jpg'  # replace with the actual path to your image
+ # take the string after the last '/' as the image name
+ image_name = image_path.split('/')[-1].split('.')[0]
+ visualize_single_image(image_path, image_name, model, device)
+
+ # Gradio passes only the uploaded file's path, so wrap the visualizer
+ def predict(image_path):
+     image_name = image_path.split('/')[-1].split('.')[0]
+     return visualize_single_image(image_path, image_name, model, device)
+
+ inputs_image = [
+     gr.components.Image(type="filepath", label="Input Image"),
+ ]
+
+ outputs_image = [
+     gr.Plot(label="Output Image"),
+ ]
+
+ gr.Interface(
+     fn=predict,
+     inputs=inputs_image,
+     outputs=outputs_image,
+     examples=path,
+     title="MAE-ViT Image Reconstruction",
+     description="This is a demo of the MAE-ViT model for image reconstruction.",
+     allow_flagging="never",
+ ).launch()
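Before launching the demo, a quick sanity check of the bundled checkpoint against model.py can be useful. The snippet below is a minimal illustrative sketch, not part of this commit; it assumes vit-t-mae-pretrain.pt sits in the working directory, as it does here.

    import torch

    # Load the full pickled MAE_ViT module, exactly as app.py does
    model = torch.load("vit-t-mae-pretrain.pt", map_location="cpu")
    model.eval()

    # One forward pass on a dummy CIFAR-sized batch
    dummy = torch.rand(1, 3, 32, 32)
    with torch.no_grad():
        predicted_img, mask = model(dummy)
    print(predicted_img.shape, mask.shape)  # both torch.Size([1, 3, 32, 32])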
images/cat.jpg ADDED
images/dog.jpg ADDED
model.py ADDED
@@ -0,0 +1,220 @@
+ # --------------------------------------------------------
+ # References:
+ # MAE: https://github.com/IcarusWizard/MAE
+ # --------------------------------------------------------
+
+ import torch
+ import timm
+ import numpy as np
+
+ from einops import repeat, rearrange
+ from einops.layers.torch import Rearrange
+
+ from timm.models.layers import trunc_normal_
+ from timm.models.vision_transformer import Block
+
+ def random_indexes(size : int):
+     forward_indexes = np.arange(size)
+     np.random.shuffle(forward_indexes)
+     backward_indexes = np.argsort(forward_indexes)
+     return forward_indexes, backward_indexes
+
+ def take_indexes(sequences, indexes):
+     return torch.gather(sequences, 0, repeat(indexes, 't b -> t b c', c=sequences.shape[-1]))
+
+ class PatchShuffle(torch.nn.Module):
+     def __init__(self, ratio) -> None:
+         super().__init__()
+         self.ratio = ratio
+
+     def forward(self, patches : torch.Tensor):
+         T, B, C = patches.shape
+         remain_T = int(T * (1 - self.ratio))
+
+         indexes = [random_indexes(T) for _ in range(B)]
+         forward_indexes = torch.as_tensor(np.stack([i[0] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)
+         backward_indexes = torch.as_tensor(np.stack([i[1] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)
+
+         patches = take_indexes(patches, forward_indexes)
+         patches = patches[:remain_T]
+
+         return patches, forward_indexes, backward_indexes
+
+ class MAE_Encoder(torch.nn.Module):
+     def __init__(self,
+                  image_size=32,
+                  patch_size=2,
+                  emb_dim=192,
+                  num_layer=12,
+                  num_head=3,
+                  mask_ratio=0.75,
+                  ) -> None:
+         super().__init__()
+
+         self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
+         self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2, 1, emb_dim))
+         self.shuffle = PatchShuffle(mask_ratio)
+
+         self.patchify = torch.nn.Conv2d(3, emb_dim, patch_size, patch_size)
+
+         self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])
+
+         self.layer_norm = torch.nn.LayerNorm(emb_dim)
+
+         self.init_weight()
+
+     def init_weight(self):
+         trunc_normal_(self.cls_token, std=.02)
+         trunc_normal_(self.pos_embedding, std=.02)
+
+     def forward(self, img):
+         patches = self.patchify(img)
+         patches = rearrange(patches, 'b c h w -> (h w) b c')
+         patches = patches + self.pos_embedding
+
+         patches, forward_indexes, backward_indexes = self.shuffle(patches)
+
+         patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
+         patches = rearrange(patches, 't b c -> b t c')
+         features = self.layer_norm(self.transformer(patches))
+         features = rearrange(features, 'b t c -> t b c')
+
+         return features, backward_indexes
+
+ class MAE_Decoder(torch.nn.Module):
+     def __init__(self,
+                  image_size=32,
+                  patch_size=2,
+                  emb_dim=192,
+                  num_layer=4,
+                  num_head=3,
+                  ) -> None:
+         super().__init__()
+
+         self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
+         self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2 + 1, 1, emb_dim))
+
+         self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])
+
+         self.head = torch.nn.Linear(emb_dim, 3 * patch_size ** 2)
+         self.patch2img = Rearrange('(h w) b (c p1 p2) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=image_size//patch_size)
+
+         self.init_weight()
+
+     def init_weight(self):
+         trunc_normal_(self.mask_token, std=.02)
+         trunc_normal_(self.pos_embedding, std=.02)
+
+     def forward(self, features, backward_indexes):
+         T = features.shape[0]
+         backward_indexes = torch.cat([torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1], dim=0)
+         features = torch.cat([features, self.mask_token.expand(backward_indexes.shape[0] - features.shape[0], features.shape[1], -1)], dim=0)
+         features = take_indexes(features, backward_indexes)
+         features = features + self.pos_embedding
+
+         features = rearrange(features, 't b c -> b t c')
+         features = self.transformer(features)
+         features = rearrange(features, 'b t c -> t b c')
+         features = features[1:]  # remove global feature
+
+         patches = self.head(features)
+         mask = torch.zeros_like(patches)
+         mask[T-1:] = 1
+         mask = take_indexes(mask, backward_indexes[1:] - 1)
+         img = self.patch2img(patches)
+         mask = self.patch2img(mask)
+
+         return img, mask
+
+ class MAE_ViT(torch.nn.Module):
+     def __init__(self,
+                  image_size=32,
+                  patch_size=2,
+                  emb_dim=192,
+                  encoder_layer=12,
+                  encoder_head=3,
+                  decoder_layer=4,
+                  decoder_head=3,
+                  mask_ratio=0.75,
+                  ) -> None:
+         super().__init__()
+
+         self.encoder = MAE_Encoder(image_size, patch_size, emb_dim, encoder_layer, encoder_head, mask_ratio)
+         self.decoder = MAE_Decoder(image_size, patch_size, emb_dim, decoder_layer, decoder_head)
+
+     def forward(self, img):
+         features, backward_indexes = self.encoder(img)
+         predicted_img, mask = self.decoder(features, backward_indexes)
+         return predicted_img, mask
+
+ class ViT_Classifier(torch.nn.Module):
+     '''
+     A simple classification head on top of the ViT encoder, allowing fine-tuning on downstream tasks.
+     We don't use MAE_ViT directly because a classification head has to be added, and the
+     Masked Autoencoder feeds only a subset of patches to its encoder, so it lacks the global
+     information of the image and is unsuitable for classification as-is.
+     '''
+     def __init__(self, encoder : MAE_Encoder, dropout_p, num_classes=10) -> None:
+         super().__init__()
+         self.dropout_p = dropout_p
+         self.cls_token = encoder.cls_token
+         self.pos_embedding = encoder.pos_embedding
+         self.patchify = encoder.patchify
+         self.transformer = encoder.transformer
+         self.layer_norm = encoder.layer_norm
+         self.dropout = torch.nn.Dropout(dropout_p)  # add dropout layer
+         self.head = torch.nn.Linear(self.pos_embedding.shape[-1], num_classes)
+
+     def forward(self, img):
+         patches = self.patchify(img)
+         patches = rearrange(patches, 'b c h w -> (h w) b c')
+         patches = patches + self.pos_embedding
+         patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
+         patches = rearrange(patches, 't b c -> b t c')
+         features = self.layer_norm(self.transformer(patches))
+         # t is the number of patches, b is the batch size, c is the number of features
+         features = rearrange(features, 'b t c -> t b c')
+         if self.dropout_p > 0:
+             features = self.dropout(features)  # apply dropout before the final head
+         logits = self.head(features[0])  # only use the cls token
+         return logits
+
+ class MAE_Encoder_FeatureExtractor(torch.nn.Module):
+     '''
+     A feature extractor that extracts features from the encoder of the Masked Autoencoder.
+     '''
+     def __init__(self, encoder : MAE_Encoder) -> None:
+         super().__init__()
+         self.cls_token = encoder.cls_token
+         self.pos_embedding = encoder.pos_embedding
+         self.patchify = encoder.patchify
+         self.transformer = encoder.transformer
+         self.layer_norm = encoder.layer_norm
+
+     def forward(self, img):
+         patches = self.patchify(img)
+         patches = rearrange(patches, 'b c h w -> (h w) b c')
+         patches = patches + self.pos_embedding
+         patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
+         patches = rearrange(patches, 't b c -> b t c')
+         features = self.layer_norm(self.transformer(patches))
+         # t is the number of patches, b is the batch size, c is the number of features
+         features = rearrange(features, 'b t c -> t b c')
+         return features
+
+
+ if __name__ == '__main__':
+     shuffle = PatchShuffle(0.75)
+     a = torch.rand(16, 2, 10)
+     b, forward_indexes, backward_indexes = shuffle(a)
+     print(b.shape)
+
+     img = torch.rand(2, 3, 32, 32)
+     encoder = MAE_Encoder()
+     decoder = MAE_Decoder()
+     features, backward_indexes = encoder(img)
+     print(forward_indexes.shape)  # still the indexes from the PatchShuffle test above
+     predicted_img, mask = decoder(features, backward_indexes)
+     print(predicted_img.shape)
+     loss = torch.mean((predicted_img - img) ** 2 * mask / 0.75)
+     print(loss)
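The ViT_Classifier docstring explains why downstream fine-tuning reuses the pretrained encoder rather than the full MAE_ViT. As an illustrative sketch (not part of this commit), wiring the classes above to the checkpoint shipped here might look like the following; dropout_p=0.1 and the 10 output classes are assumed values, not settings from this repo.

    import torch
    from model import ViT_Classifier, MAE_Encoder_FeatureExtractor

    # Assumes the checkpoint stores a full MAE_ViT module, as app.py loads it
    mae = torch.load("vit-t-mae-pretrain.pt", map_location="cpu")

    # Downstream classification: reuse the pretrained encoder and add a linear head
    classifier = ViT_Classifier(mae.encoder, dropout_p=0.1, num_classes=10)
    logits = classifier(torch.rand(4, 3, 32, 32))   # shape (4, 10)

    # Plain feature extraction: cls token plus per-patch features
    extractor = MAE_Encoder_FeatureExtractor(mae.encoder)
    features = extractor(torch.rand(4, 3, 32, 32))  # shape (257, 4, 192): cls token + 256 patches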
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ # python=3.8
+ torch
+ torchvision
+ tensorboard
+ scikit-learn
+ matplotlib
+ numpy
+ einops
+ timm==0.4.12
+ tqdm
+ omega
+ pyyaml
+ opencv-python
+ wandb
+ icecream
+ torchinfo
vit-t-mae-pretrain.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:852a6a0806c42a8c725b0de82cd0e7b59d7d79ad21f8e012bc599eedcce15375
+ size 28972154