Spaces:
Sleeping
Sleeping
turhancan97
committed on
Commit
·
929f451
1
Parent(s):
b7d4bcf
app file created
Browse files- app.py +101 -0
- images/cat.jpg +0 -0
- images/dog.jpg +0 -0
- model.py +220 -0
- requirements.txt +16 -0
- vit-t-mae-pretrain.pt +3 -0
app.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torchvision
|
5 |
+
from PIL import Image
|
6 |
+
import numpy as np
|
7 |
+
import random
|
8 |
+
from einops import rearrange
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
from torchvision.transforms import v2
|
11 |
+
|
12 |
+
from model import MAE_ViT, MAE_Encoder, MAE_Decoder, MAE_Encoder_FeatureExtractor
|
13 |
+
|
14 |
+
# Bundled sample images, one single-element row per image.
path = [['images/cat.jpg'], ['images/dog.jpg']]

# The checkpoint holds a whole pickled model object (it is used directly as a
# module below), not a state_dict.  NOTE(review): unpickling therefore needs
# the classes from model.py to be importable, and torch.load on a pickle must
# only ever be pointed at a trusted file.
model_name = "vit-t-mae-pretrain.pt"
device = torch.device("cpu")
model = torch.load(model_name, map_location='cpu')
model.to(device)
model.eval()

# Preprocessing: resize to 32x32, scale to [0, 1], then shift/scale every
# channel with mean 0.5 / std 0.5, i.e. map pixel values to [-1, 1].
transform = v2.Compose([
    v2.Resize((32, 32)),
    v2.ToTensor(),
    v2.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])
|
27 |
+
|
28 |
+
# Load and Preprocess the Image
|
29 |
+
def load_image(image_path, transform):
    """Read an image from disk and turn it into a single-item batch.

    Args:
        image_path: Location of the image, passed straight to
            ``PIL.Image.open``.
        transform: Callable applied to the RGB ``PIL.Image``; its result must
            provide ``unsqueeze`` (e.g. a tensor-producing torchvision
            pipeline).

    Returns:
        The transformed image with a new leading batch dimension of size 1.
    """
    # The context manager releases the file handle deterministically.
    # ``convert`` returns an independent, fully loaded copy, so the converted
    # image stays valid after the source file is closed.
    with Image.open(image_path) as source:
        rgb = source.convert('RGB')
    return transform(rgb).unsqueeze(0)  # add batch dimension
|
34 |
+
|
35 |
+
def show_image(img, title):
    """Draw one image tensor on the current matplotlib axes.

    Args:
        img: Tensor laid out as (channels, height, width) whose values are
            nominally in [-1, 1].
        title: Caption shown above the image.
    """
    img = rearrange(img, "c h w -> h w c")
    pixels = (img.cpu().detach().numpy() + 1) / 2  # map [-1, 1] to [0, 1]
    # Network outputs are not bounded, so reconstructed pixels can fall
    # slightly outside [0, 1].  Clip explicitly instead of relying on imshow,
    # which clips float RGB data anyway but emits a warning for each image.
    pixels = np.clip(pixels, 0.0, 1.0)

    plt.imshow(pixels)
    plt.axis('off')
    plt.title(title)
|
42 |
+
|
43 |
+
# Visualize a Single Image
|
44 |
+
def visualize_single_image(image_path, image_name=None, model=model, device=device):
    """Run the masked autoencoder on one image and plot the result.

    The figure has four panels: the preprocessed input, the input with the
    masked patches zeroed, the raw reconstruction, and the reconstruction
    pasted over the masked patches only.

    Args:
        image_path: Path of the image to load.
        image_name: Unused; kept so existing four-argument callers keep
            working.
        model: Callable returning ``(predicted_img, mask)`` for an image
            batch, where ``mask`` is 1 on masked pixels.  Defaults to the
            module-level model, bound when this function is defined, so the
            function can be used directly as a single-input Gradio callback.
        device: Device the image batch is moved to.  Defaults to the
            module-level device.

    Returns:
        The ``matplotlib.pyplot`` module, whose current figure is the plot.
    """
    img = load_image(image_path, transform).to(device)

    # Run inference without tracking gradients.
    model.eval()
    with torch.no_grad():
        predicted_img, mask = model(img)

    # Visible patches only (masked pixels zeroed).
    im_masked = img * (1 - mask)
    # Visible patches from the input, masked patches from the prediction.
    im_paste = im_masked + predicted_img * mask

    # pyplot keeps every figure alive until it is closed; drop the figures of
    # earlier calls so a long-running app does not accumulate them.
    # NOTE(review): pyplot state is global, so this assumes calls are not
    # rendered concurrently.
    plt.close("all")
    plt.figure(figsize=(12, 4))

    panels = (
        (img, "original"),
        (im_masked, "masked"),
        (predicted_img, "reconstruction"),
        (im_paste, "reconstruction + visible"),
    )
    for position, (batch, caption) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        show_image(batch[0], caption)

    plt.tight_layout()

    return plt
|
77 |
+
|
78 |
+
# Example usage; also acts as a start-up smoke test of the loaded model.
image_path = 'images/dog.jpg' # Replace with the actual path to your image
# File name without directory and extension.  rsplit keeps any dots that are
# part of the name itself ("my.dog.jpg" -> "my.dog").
image_name = image_path.rsplit('/', 1)[-1].rsplit('.', 1)[0]
visualize_single_image(image_path, image_name, model, device)
# Nothing displays the figure produced above, so release it right away.
plt.close("all")
|
83 |
+
|
84 |
+
def _reconstruct(image_path):
    """Gradio callback: plot the reconstruction for one uploaded image.

    The interface has a single input, so the remaining arguments of
    ``visualize_single_image`` are supplied here.  Returns the matplotlib
    figure that was just drawn.
    """
    visualize_single_image(image_path, None, model, device)
    return plt.gcf()


inputs_image = [
    gr.components.Image(type="filepath", label="Input Image"),
]

# ``gr.outputs.Image(type="plot")`` is the legacy output API; the Plot
# component is its replacement in the Gradio versions that provide
# ``gr.components``.
outputs_image = [
    gr.Plot(label="Output Image"),
]

# ``allow_screenshot`` and ``allow_remote_access`` are legacy options and are
# no longer passed; flagging is disabled with the string form of the option.
gr.Interface(
    fn=_reconstruct,
    inputs=inputs_image,
    outputs=outputs_image,
    title="MAE-ViT Image Reconstruction",
    description="This is a demo of the MAE-ViT model for image reconstruction.",
    allow_flagging="never",
).launch()
|
images/cat.jpg
ADDED
images/dog.jpg
ADDED
model.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# References:
|
3 |
+
# MAE: https://github.com/IcarusWizard/MAE
|
4 |
+
# --------------------------------------------------------
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import timm
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
from einops import repeat, rearrange
|
11 |
+
from einops.layers.torch import Rearrange
|
12 |
+
|
13 |
+
from timm.models.layers import trunc_normal_
|
14 |
+
from timm.models.vision_transformer import Block
|
15 |
+
|
16 |
+
def random_indexes(size: int):
    '''
    Draw a random permutation of ``range(size)`` together with its inverse.

    Returns a tuple ``(forward, backward)`` of 1-D integer arrays such that
    ``forward[backward]`` is ``arange(size)``.  Randomness comes from NumPy's
    global random state.
    '''
    permutation = np.arange(size)
    np.random.shuffle(permutation)  # in-place shuffle
    # Sorting a permutation yields the positions that undo it.
    inverse = permutation.argsort()
    return permutation, inverse
|
21 |
+
|
22 |
+
def take_indexes(sequences, indexes):
    '''
    Reorder ``sequences`` along its first axis, independently per batch column.

    ``indexes`` is a 2-D (t, b) integer tensor; it is repeated along a new
    trailing axis as long as the last axis of ``sequences`` so that
    ``torch.gather`` moves whole feature vectors.
    '''
    feature_dim = sequences.shape[-1]
    gather_index = repeat(indexes, 't b -> t b c', c=feature_dim)
    return torch.gather(sequences, 0, gather_index)
|
24 |
+
|
25 |
+
class PatchShuffle(torch.nn.Module):
    '''
    Shuffle the patch sequence of every sample and keep only its leading part.

    ``ratio`` is the fraction of patches that is dropped.
    '''
    def __init__(self, ratio) -> None:
        super().__init__()
        self.ratio = ratio

    def forward(self, patches: torch.Tensor):
        '''
        ``patches`` is (T, B, C).  Returns the kept patches
        (``int(T * (1 - ratio))`` of them per sample) plus the forward and
        backward index tensors, both of shape (T, B).
        '''
        num_patches, batch_size, _ = patches.shape
        num_kept = int(num_patches * (1 - self.ratio))

        # One independent permutation (and its inverse) per sample.
        forward_columns, backward_columns = [], []
        for _ in range(batch_size):
            fwd, bwd = random_indexes(num_patches)
            forward_columns.append(fwd)
            backward_columns.append(bwd)

        forward_indexes = torch.as_tensor(
            np.stack(forward_columns, axis=-1), dtype=torch.long
        ).to(patches.device)
        backward_indexes = torch.as_tensor(
            np.stack(backward_columns, axis=-1), dtype=torch.long
        ).to(patches.device)

        shuffled = take_indexes(patches, forward_indexes)
        return shuffled[:num_kept], forward_indexes, backward_indexes
|
42 |
+
|
43 |
+
class MAE_Encoder(torch.nn.Module):
    '''
    Encoder of the masked autoencoder.

    Splits the image into non-overlapping patches with a strided convolution,
    adds a learned positional embedding, drops a random ``mask_ratio`` share
    of the patches and encodes the rest (plus a cls token) with a stack of
    transformer blocks followed by a layer norm.

    NOTE: a checkpoint that pickles the whole module restores attributes by
    name without running ``__init__``, so ``forward`` must keep relying only
    on the attribute names assigned here.
    '''
    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=12,
                 num_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        num_patches = (image_size // patch_size) ** 2
        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros(num_patches, 1, emb_dim))
        self.shuffle = PatchShuffle(mask_ratio)

        # Kernel size == stride == patch size: one output position per patch.
        self.patchify = torch.nn.Conv2d(3, emb_dim, patch_size, patch_size)

        blocks = [Block(emb_dim, num_head) for _ in range(num_layer)]
        self.transformer = torch.nn.Sequential(*blocks)

        self.layer_norm = torch.nn.LayerNorm(emb_dim)

        self.init_weight()

    def init_weight(self):
        '''Initialise the cls token and positional embedding (truncated normal, std 0.02).'''
        for parameter in (self.cls_token, self.pos_embedding):
            trunc_normal_(parameter, std=.02)

    def forward(self, img):
        '''
        Encode the visible patches of ``img`` (b, c, h, w).

        Returns ``(features, backward_indexes)``; ``features`` is laid out as
        (t, b, c) with the cls token first, and ``backward_indexes`` undoes
        the patch shuffle.
        '''
        tokens = rearrange(self.patchify(img), 'b c h w -> (h w) b c')
        tokens = tokens + self.pos_embedding

        tokens, _, backward_indexes = self.shuffle(tokens)

        cls_tokens = self.cls_token.expand(-1, tokens.shape[1], -1)
        tokens = torch.cat([cls_tokens, tokens], dim=0)
        # The transformer blocks are fed batch-first input.
        encoded = self.transformer(rearrange(tokens, 't b c -> b t c'))
        features = rearrange(self.layer_norm(encoded), 'b t c -> t b c')

        return features, backward_indexes
|
83 |
+
|
84 |
+
class MAE_Decoder(torch.nn.Module):
    '''
    Decoder of the masked autoencoder.

    Fills the dropped positions with a learned mask token, restores the
    original patch order, runs transformer blocks and projects every patch
    token back to pixels.

    NOTE: a checkpoint that pickles the whole module restores attributes by
    name without running ``__init__``, so ``forward`` must keep relying only
    on the attribute names assigned here.
    '''
    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=4,
                 num_head=3,
                 ) -> None:
        super().__init__()

        grid = image_size // patch_size
        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        # One extra position for the cls token.
        self.pos_embedding = torch.nn.Parameter(torch.zeros(grid ** 2 + 1, 1, emb_dim))

        blocks = [Block(emb_dim, num_head) for _ in range(num_layer)]
        self.transformer = torch.nn.Sequential(*blocks)

        self.head = torch.nn.Linear(emb_dim, 3 * patch_size ** 2)
        self.patch2img = Rearrange('(h w) b (c p1 p2) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=grid)

        self.init_weight()

    def init_weight(self):
        '''Initialise the mask token and positional embedding (truncated normal, std 0.02).'''
        for parameter in (self.mask_token, self.pos_embedding):
            trunc_normal_(parameter, std=.02)

    def forward(self, features, backward_indexes):
        '''
        Reconstruct the image from encoder ``features`` (t, b, c; cls first).

        Returns ``(img, mask)``, both shaped like the image batch; ``mask`` is
        1 on pixels of patches that were filled with the mask token and 0 on
        pixels of patches the encoder saw.
        '''
        num_encoded = features.shape[0]  # cls token + visible patches

        # Prepend index 0 for the cls token and shift the patch indexes by one.
        cls_index = torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes)
        backward_indexes = torch.cat([cls_index, backward_indexes + 1], dim=0)

        # Pad with mask tokens up to the full sequence length, then unshuffle.
        num_missing = backward_indexes.shape[0] - num_encoded
        padding = self.mask_token.expand(num_missing, features.shape[1], -1)
        features = torch.cat([features, padding], dim=0)
        features = take_indexes(features, backward_indexes)
        features = features + self.pos_embedding

        decoded = self.transformer(rearrange(features, 't b c -> b t c'))
        decoded = rearrange(decoded, 'b t c -> t b c')
        decoded = decoded[1:]  # remove global feature

        patches = self.head(decoded)
        # In shuffled order the visible patches come first, so everything from
        # position num_encoded - 1 onwards was masked.
        mask = torch.zeros_like(patches)
        mask[num_encoded - 1:] = 1
        mask = take_indexes(mask, backward_indexes[1:] - 1)

        img = self.patch2img(patches)
        mask = self.patch2img(mask)

        return img, mask
|
128 |
+
|
129 |
+
class MAE_ViT(torch.nn.Module):
    '''
    Complete masked autoencoder: an ``MAE_Encoder`` followed by an
    ``MAE_Decoder`` that share image size, patch size and embedding width.
    '''
    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 encoder_layer=12,
                 encoder_head=3,
                 decoder_layer=4,
                 decoder_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        shared = (image_size, patch_size, emb_dim)
        self.encoder = MAE_Encoder(*shared, encoder_layer, encoder_head, mask_ratio)
        self.decoder = MAE_Decoder(*shared, decoder_layer, decoder_head)

    def forward(self, img):
        '''Return ``(predicted_img, mask)`` as produced by the decoder for ``img``.'''
        features, backward_indexes = self.encoder(img)
        return self.decoder(features, backward_indexes)
|
149 |
+
|
150 |
+
class ViT_Classifier(torch.nn.Module):
    '''
    Classification model built on top of a (pre-trained) MAE encoder.

    The encoder's cls token, positional embedding, patchify convolution,
    transformer and layer norm are reused as-is (the same objects, not
    copies); a dropout layer and a linear head are added.  Unlike
    ``MAE_Encoder.forward`` no patches are shuffled away, so the classifier
    sees the whole image.
    '''
    def __init__(self, encoder: MAE_Encoder, dropout_p, num_classes=10) -> None:
        super().__init__()
        self.dropout_p = dropout_p
        self.cls_token = encoder.cls_token
        self.pos_embedding = encoder.pos_embedding
        self.patchify = encoder.patchify
        self.transformer = encoder.transformer
        self.layer_norm = encoder.layer_norm
        self.dropout = torch.nn.Dropout(dropout_p)
        emb_dim = self.pos_embedding.shape[-1]
        self.head = torch.nn.Linear(emb_dim, num_classes)

    def forward(self, img):
        '''Return class logits for the image batch ``img`` (b, c, h, w).'''
        tokens = rearrange(self.patchify(img), 'b c h w -> (h w) b c')
        tokens = tokens + self.pos_embedding
        cls_tokens = self.cls_token.expand(-1, tokens.shape[1], -1)
        tokens = torch.cat([cls_tokens, tokens], dim=0)
        encoded = self.transformer(rearrange(tokens, 't b c -> b t c'))
        # Back to (tokens, batch, features) so that index 0 is the cls token.
        features = rearrange(self.layer_norm(encoded), 'b t c -> t b c')
        if self.dropout_p > 0:
            features = self.dropout(features)  # dropout before the final head
        cls_features = features[0]  # only the cls token is classified
        return self.head(cls_features)
|
181 |
+
|
182 |
+
class MAE_Encoder_FeatureExtractor(torch.nn.Module):
    '''
    Feature extractor that reuses the parts of an ``MAE_Encoder`` (the same
    objects, not copies) and encodes every patch of the image, without
    shuffling any away.
    '''
    def __init__(self, encoder: MAE_Encoder) -> None:
        super().__init__()
        self.cls_token = encoder.cls_token
        self.pos_embedding = encoder.pos_embedding
        self.patchify = encoder.patchify
        self.transformer = encoder.transformer
        self.layer_norm = encoder.layer_norm

    def forward(self, img):
        '''Return features laid out as (tokens, batch, features), cls token first.'''
        tokens = rearrange(self.patchify(img), 'b c h w -> (h w) b c')
        tokens = tokens + self.pos_embedding
        cls_tokens = self.cls_token.expand(-1, tokens.shape[1], -1)
        tokens = torch.cat([cls_tokens, tokens], dim=0)
        encoded = self.transformer(rearrange(tokens, 't b c -> b t c'))
        return rearrange(self.layer_norm(encoded), 'b t c -> t b c')
|
204 |
+
|
205 |
+
|
206 |
+
if __name__ == '__main__':
    # Smoke test 1: shuffling 16 patches with ratio 0.75 keeps a quarter of them.
    shuffle = PatchShuffle(0.75)
    dummy_patches = torch.rand(16, 2, 10)
    kept, forward_indexes, backward_indexes = shuffle(dummy_patches)
    print(kept.shape)

    # Smoke test 2: encoder -> decoder round trip on random images.
    img = torch.rand(2, 3, 32, 32)
    encoder = MAE_Encoder()
    decoder = MAE_Decoder()
    features, backward_indexes = encoder(img)
    # NOTE: forward_indexes still comes from the PatchShuffle check above;
    # the encoder does not return its forward indexes.
    print(forward_indexes.shape)
    predicted_img, mask = decoder(features, backward_indexes)
    print(predicted_img.shape)
    # Squared error restricted to masked pixels, divided by the mask ratio.
    squared_error = (predicted_img - img) ** 2
    loss = torch.mean(squared_error * mask / 0.75)
    print(loss)
|
requirements.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# python=3.8
|
2 |
+
torch
|
3 |
+
torchvision
|
4 |
+
tensorboard
|
5 |
+
scikit-learn
|
6 |
+
matplotlib
|
7 |
+
numpy
|
8 |
+
einops
|
9 |
+
timm==0.4.12
|
10 |
+
tqdm
|
11 |
+
omega
|
12 |
+
pyyaml
|
13 |
+
opencv-python
|
14 |
+
wandb
|
15 |
+
icecream
|
16 |
+
torchinfo
|
vit-t-mae-pretrain.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:852a6a0806c42a8c725b0de82cd0e7b59d7d79ad21f8e012bc599eedcce15375
|
3 |
+
size 28972154
|