Warlord-K commited on
Commit
34a7c74
·
1 Parent(s): 6ad2059

Add Better Images and SwinFace

Browse files
app.py CHANGED
@@ -1,7 +1,8 @@
 
1
  import faiss
2
- from helpers import *
3
  import gradio as gr
4
- import os
 
5
  detector = load_detector()
6
  model = load_model()
7
 
@@ -18,14 +19,20 @@ for r, _, f in os.walk(os.getcwd() + "/images"):
18
 
19
  source_faces = []
20
  for img in source_imgs:
21
- try:
22
- source_faces.append(extract_faces(detector, img)[-1])
23
- except:
24
- print(f"{img} not found, Skipping.")
 
 
25
  source_embeddings = get_embeddings(model, source_faces)
26
 
27
  def find_names(image):
28
- imgs = extract_faces(detector, image)
 
 
 
 
29
  embeds = get_embeddings(model, imgs)
30
  d = np.zeros((len(source_embeddings), len(embeds)))
31
  for i, s in enumerate(source_embeddings):
@@ -35,12 +42,13 @@ def find_names(image):
35
  names = []
36
  for i in ids:
37
  names.append(source_imgs[i].split("/")[-1].split(".")[0])
38
- return ",".join(names)
 
39
 
40
  demo = gr.Interface(
41
  find_names,
42
  gr.Image(type="filepath"),
43
- "text",
44
  examples = [
45
  os.path.join(os.path.dirname(__file__), "examples/group1.jpg"),
46
  os.path.join(os.path.dirname(__file__), "examples/group2.jpg")
 
1
+ import os
2
  import faiss
 
3
  import gradio as gr
4
+ from helpers import *
5
+
6
  detector = load_detector()
7
  model = load_model()
8
 
 
19
 
20
  source_faces = []
21
  for img in source_imgs:
22
+ try:
23
+ faces, id = extract_faces(detector, img)
24
+ source_faces.append(faces[id])
25
+ except Exception as e:
26
+ print(f"Skipping {img}, {e}")
27
+
28
  source_embeddings = get_embeddings(model, source_faces)
29
 
30
  def find_names(image):
31
+ imgs, _ = extract_faces(detector, image)
32
+ for i, face in enumerate(imgs):
33
+ if(face.size[0] * face.size[1] < 1000):
34
+ del imgs[i]
35
+
36
  embeds = get_embeddings(model, imgs)
37
  d = np.zeros((len(source_embeddings), len(embeds)))
38
  for i, s in enumerate(source_embeddings):
 
42
  names = []
43
  for i in ids:
44
  names.append(source_imgs[i].split("/")[-1].split(".")[0])
45
+ recognition(imgs, ids, names, source_faces)
46
+ return ",".join(names), "Recognition.jpg"
47
 
48
  demo = gr.Interface(
49
  find_names,
50
  gr.Image(type="filepath"),
51
+ ["text", gr.Image(type = "filepath")],
52
  examples = [
53
  os.path.join(os.path.dirname(__file__), "examples/group1.jpg"),
54
  os.path.join(os.path.dirname(__file__), "examples/group2.jpg")
helpers.py CHANGED
@@ -1,20 +1,8 @@
1
  from ultralyticsplus import YOLO
2
  from PIL import Image
3
  import numpy as np
4
- from tensorflow.keras.models import Model, Sequential
5
- from tensorflow.keras.layers import (
6
- Convolution2D,
7
- LocallyConnected2D,
8
- MaxPooling2D,
9
- Flatten,
10
- Dense,
11
- Dropout,
12
- )
13
- import os
14
- import zipfile
15
- import gdown
16
- import tensorflow as tf
17
-
18
 
19
  def load_detector():
20
  # load model
@@ -35,46 +23,7 @@ def extract_faces(model, image):
35
  crops = []
36
  for id in ids:
37
  crops.append(Image.fromarray(np.array(img)[id[1] : id[3], id[0]: id[2]]))
38
- return crops
39
-
40
- def load_model(
41
- url="https://github.com/swghosh/DeepFace/releases/download/weights-vggface2-2d-aligned/VGGFace2_DeepFace_weights_val-0.9034.h5.zip",
42
- ):
43
- base_model = Sequential()
44
- base_model.add(
45
- Convolution2D(32, (11, 11), activation="relu", name="C1", input_shape=(152, 152, 3))
46
- )
47
- base_model.add(MaxPooling2D(pool_size=3, strides=2, padding="same", name="M2"))
48
- base_model.add(Convolution2D(16, (9, 9), activation="relu", name="C3"))
49
- base_model.add(LocallyConnected2D(16, (9, 9), activation="relu", name="L4"))
50
- base_model.add(LocallyConnected2D(16, (7, 7), strides=2, activation="relu", name="L5"))
51
- base_model.add(LocallyConnected2D(16, (5, 5), activation="relu", name="L6"))
52
- base_model.add(Flatten(name="F0"))
53
- base_model.add(Dense(4096, activation="relu", name="F7"))
54
- base_model.add(Dropout(rate=0.5, name="D0"))
55
- base_model.add(Dense(8631, activation="softmax", name="F8"))
56
-
57
- # ---------------------------------
58
-
59
- home = os.getcwd()
60
-
61
- if os.path.isfile(home + "/VGGFace2_DeepFace_weights_val-0.9034.h5") != True:
62
- print("VGGFace2_DeepFace_weights_val-0.9034.h5 will be downloaded...")
63
-
64
- output = home + "/VGGFace2_DeepFace_weights_val-0.9034.h5.zip"
65
-
66
- gdown.download(url, output, quiet=False)
67
-
68
- # unzip VGGFace2_DeepFace_weights_val-0.9034.h5.zip
69
- with zipfile.ZipFile(output, "r") as zip_ref:
70
- zip_ref.extractall(home)
71
-
72
- base_model.load_weights(home + "/VGGFace2_DeepFace_weights_val-0.9034.h5")
73
-
74
- # drop F8 and D0. F7 is the representation layer.
75
- deepface_model = Model(inputs=base_model.layers[0].input, outputs=base_model.layers[-3].output)
76
-
77
- return deepface_model
78
 
79
  def findCosineDistance(source_representation, test_representation):
80
  a = np.matmul(np.transpose(source_representation), test_representation)
@@ -82,13 +31,22 @@ def findCosineDistance(source_representation, test_representation):
82
  c = np.sum(np.multiply(test_representation, test_representation))
83
  return 1 - (a / (np.sqrt(b) * np.sqrt(c)))
84
 
85
- def get_embeddings(model, imgs):
86
- embeddings = []
87
- for i, img in enumerate(imgs):
88
- try:
89
- img = np.expand_dims(np.array(img.resize((152,152))), axis = 0)
90
- embedding = model.predict(img, verbose=0)[0]
91
- embeddings.append(embedding)
92
- except:
93
- print(f"Error at {i}, skipping")
94
- return embeddings
 
 
 
 
 
 
 
 
 
 
1
  from ultralyticsplus import YOLO
2
  from PIL import Image
3
  import numpy as np
4
+ from swin import load_model, get_embeddings
5
+ import matplotlib.pyplot as plt
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  def load_detector():
8
  # load model
 
23
  crops = []
24
  for id in ids:
25
  crops.append(Image.fromarray(np.array(img)[id[1] : id[3], id[0]: id[2]]))
26
+ return crops, np.argmax(np.array(results[0].boxes.conf))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  def findCosineDistance(source_representation, test_representation):
29
  a = np.matmul(np.transpose(source_representation), test_representation)
 
31
  c = np.sum(np.multiply(test_representation, test_representation))
32
  return 1 - (a / (np.sqrt(b) * np.sqrt(c)))
33
 
34
+ def recognition(imgs, ids, names, source_faces):
35
+ cols = 4
36
+ rows = int(np.ceil(len(imgs)/2))
37
+ img_count = 0
38
+
39
+ fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(15,rows*3))
40
+ for i in range(rows):
41
+ for j in range(cols):
42
+ if img_count < len(imgs):
43
+ if(j%2):
44
+ axes[i, j].set_title(f"Confidence: {1 - d[ids[img_count]][img_count]: .2f}")
45
+ axes[i, j].imshow(imgs[img_count])
46
+ axes[i, j].set_axis_off()
47
+ img_count+=1
48
+ else:
49
+ axes[i, j].set_title(f"Roll No.:{source_imgs[ids[img_count]].split('/')[-1].split('.')[0]}")
50
+ axes[i, j].imshow(source_faces[ids[img_count]])
51
+ axes[i, j].set_axis_off()
52
+ plt.savefig("Recognition.jpg")
images/220001002.jpg CHANGED

Git LFS Details

  • SHA256: f990e3c20b73586374029a6197efe7382222bd3657d083298beea73df42ccb41
  • Pointer size: 132 Bytes
  • Size of remote file: 3.64 MB

Git LFS Details

  • SHA256: 24406b68f6255d3bfc831356101071233dfb26d1aa586933da84e721e9c7fbc1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.14 MB
images/220001015.png CHANGED

Git LFS Details

  • SHA256: 0240cbd417522f86bd26ec141a348a0bfbc562ba8e255cd3e328548a4759bddf
  • Pointer size: 132 Bytes
  • Size of remote file: 2.51 MB

Git LFS Details

  • SHA256: 04d559adf38714b8f7ccb58f222e7a1e5d11d86368f81a0cb49a0f3faeb95458
  • Pointer size: 132 Bytes
  • Size of remote file: 3.2 MB
images/220001018.jpeg CHANGED

Git LFS Details

  • SHA256: cfd0c806d45918f446865ccea0f238dbc506e3542f9f26041edfeba642e31de3
  • Pointer size: 132 Bytes
  • Size of remote file: 3 MB

Git LFS Details

  • SHA256: 6f005c3a160f86e51cc84547b1bdeedc7054c2f8e264b8e3ce215656b9d7d87b
  • Pointer size: 131 Bytes
  • Size of remote file: 953 kB
images/220001023.jpg CHANGED
images/220001037.jpg CHANGED
images/220001042.jpg CHANGED

Git LFS Details

  • SHA256: 3668186a1bf544c33059cbad3660ce761d880d785ad3201620e97597e9a466e8
  • Pointer size: 132 Bytes
  • Size of remote file: 2.26 MB

Git LFS Details

  • SHA256: cb0c9cb044e2d7b66c08634c1f178764fbd1642fd570835e3deff04b89aec2d2
  • Pointer size: 131 Bytes
  • Size of remote file: 410 kB
images/220001048.jpg CHANGED

Git LFS Details

  • SHA256: dac24712be6fcdcf976c82de8a6112788f72760d825ef9f67e9fb0cd277b87ad
  • Pointer size: 132 Bytes
  • Size of remote file: 6.1 MB

Git LFS Details

  • SHA256: f8d19d1cf6bc65ee86d2fc011167f5d96981884a0ffadd3be9adf8c62af89aae
  • Pointer size: 132 Bytes
  • Size of remote file: 1.41 MB
images/220001050.jpg CHANGED

Git LFS Details

  • SHA256: cea1d4d2059d83f3e1998214dac2e84306943d0024ecab7b17fcb26301ba9c40
  • Pointer size: 132 Bytes
  • Size of remote file: 2.99 MB

Git LFS Details

  • SHA256: c41261694f7506f3de1411d70ac63feee1d596346829955919088dde6c9e23c4
  • Pointer size: 131 Bytes
  • Size of remote file: 959 kB
images/220001054.jpg CHANGED
images/220001055.png CHANGED

Git LFS Details

  • SHA256: 1f1d23411cb763bd08118b396caf303d9bc99fb61f9856abc61c4b7b3fcaba9a
  • Pointer size: 132 Bytes
  • Size of remote file: 4.14 MB

Git LFS Details

  • SHA256: 28df2f657d0a326b0ec2d0d0a3634591c693fc0dbe5a47ec5be0f25c9baab2b2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.68 MB
images/220001060.jpg CHANGED

Git LFS Details

  • SHA256: e3922016f14810a858051b7b8f3b6b7d3dca51d4ea4fdd32db5d6f704c813346
  • Pointer size: 132 Bytes
  • Size of remote file: 1.59 MB

Git LFS Details

  • SHA256: bb417581fbd5dd234bcc2a8bca7e4e6897af05fa6498c9bdd83d2fff9bd41416
  • Pointer size: 131 Bytes
  • Size of remote file: 597 kB
images/220001064.jpg CHANGED

Git LFS Details

  • SHA256: 2ca9ad2b0450840165216b98236c3ed912364b0c6f787dc37f758c1b9ab69051
  • Pointer size: 132 Bytes
  • Size of remote file: 1.53 MB

Git LFS Details

  • SHA256: cecec826981f27d95465bbffeba76b6d308671f287045e93571f208f93b4ca34
  • Pointer size: 131 Bytes
  • Size of remote file: 338 kB
images/220001065.png CHANGED

Git LFS Details

  • SHA256: d345ab9b5fa8ac1be7d7f45e356f372b5587a9712b835a918bb586ca56b271a4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.06 MB

Git LFS Details

  • SHA256: d1f323fb84c52f44cdcad1cd93b5c3d5a4d326bee3998209d9f5514f0548ec86
  • Pointer size: 132 Bytes
  • Size of remote file: 4.4 MB
images/220001068.jpg CHANGED

Git LFS Details

  • SHA256: d737f49fb02791ff796de15550a3f737eb812303b15c595375f57fd0c05868f8
  • Pointer size: 132 Bytes
  • Size of remote file: 7.2 MB

Git LFS Details

  • SHA256: aecedca41d69e9550ea752db73ffda2ad8fad36b9ad09c4e612ceba71f9f20a3
  • Pointer size: 131 Bytes
  • Size of remote file: 731 kB
images/Akshit.jpeg DELETED
Binary file (7.71 kB)
 
images/Gourav.jpeg DELETED
Binary file (22.7 kB)
 
images/Gourav1.jpeg DELETED
Binary file (37.3 kB)
 
images/Jayant.jpeg DELETED
Binary file (52.1 kB)
 
images/Priyanshu.jpeg DELETED
Binary file (22.3 kB)
 
images/Priyanshu1.jpeg DELETED
Binary file (24.5 kB)
 
images/Sairaj.jpeg DELETED
Binary file (158 kB)
 
images/Samip.jpeg DELETED
Binary file (30.7 kB)
 
images/Sekhar.jpeg DELETED
Binary file (143 kB)
 
images/Vikas.jpeg DELETED
Binary file (21.4 kB)
 
images/Yatharth.jpeg DELETED
Binary file (28.7 kB)
 
requirements.txt CHANGED
@@ -7,4 +7,7 @@ tensorflow
7
  numpy
8
  gdown
9
  pillow
10
- gradio
 
 
 
 
7
  numpy
8
  gdown
9
  pillow
10
+ gradio
11
+ opencv-python
12
+ timm
13
+ torch
swin.py ADDED
@@ -0,0 +1,1164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import torch
4
+ import numpy as np
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.utils.checkpoint as checkpoint
8
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
9
+ import os
10
+ import gdown
11
+ class Mlp(nn.Module):
12
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
13
+ super().__init__()
14
+ out_features = out_features or in_features
15
+ hidden_features = hidden_features or in_features
16
+ self.fc1 = nn.Linear(in_features, hidden_features)
17
+ self.act = act_layer()
18
+ self.fc2 = nn.Linear(hidden_features, out_features)
19
+ self.drop = nn.Dropout(drop)
20
+
21
+ def forward(self, x):
22
+ x = self.fc1(x)
23
+ x = self.act(x)
24
+ x = self.drop(x)
25
+ x = self.fc2(x)
26
+ x = self.drop(x)
27
+ return x
28
+
29
+
30
+ def window_partition(x, window_size):
31
+ """
32
+ Args:
33
+ x: (B, H, W, C)
34
+ window_size (int): window size
35
+
36
+ Returns:
37
+ windows: (num_windows*B, window_size, window_size, C)
38
+ """
39
+ B, H, W, C = x.shape
40
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
41
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
42
+ return windows
43
+
44
+
45
+ def window_reverse(windows, window_size, H, W):
46
+ """
47
+ Args:
48
+ windows: (num_windows*B, window_size, window_size, C)
49
+ window_size (int): Window size
50
+ H (int): Height of image
51
+ W (int): Width of image
52
+
53
+ Returns:
54
+ x: (B, H, W, C)
55
+ """
56
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
57
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
58
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
59
+ return x
60
+
61
+
62
+ class WindowAttention(nn.Module):
63
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
64
+ It supports both of shifted and non-shifted window.
65
+
66
+ Args:
67
+ dim (int): Number of input channels.
68
+ window_size (tuple[int]): The height and width of the window.
69
+ num_heads (int): Number of attention heads.
70
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
71
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
72
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
73
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
74
+ """
75
+
76
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
77
+
78
+ super().__init__()
79
+ self.dim = dim
80
+ self.window_size = window_size # Wh, Ww
81
+ self.num_heads = num_heads
82
+ head_dim = dim // num_heads
83
+ self.scale = qk_scale or head_dim ** -0.5
84
+
85
+ # define a parameter table of relative position bias
86
+ self.relative_position_bias_table = nn.Parameter(
87
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
88
+
89
+ # get pair-wise relative position index for each token inside the window
90
+ coords_h = torch.arange(self.window_size[0])
91
+ coords_w = torch.arange(self.window_size[1])
92
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
93
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
94
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
95
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
96
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
97
+ relative_coords[:, :, 1] += self.window_size[1] - 1
98
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
99
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
100
+ self.register_buffer("relative_position_index", relative_position_index)
101
+
102
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
103
+ self.attn_drop = nn.Dropout(attn_drop)
104
+ self.proj = nn.Linear(dim, dim)
105
+ self.proj_drop = nn.Dropout(proj_drop)
106
+
107
+ trunc_normal_(self.relative_position_bias_table, std=.02)
108
+ self.softmax = nn.Softmax(dim=-1)
109
+
110
+ def forward(self, x, mask=None):
111
+ """
112
+ Args:
113
+ x: input features with shape of (num_windows*B, N, C)
114
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
115
+ """
116
+ with torch.cuda.amp.autocast(True):
117
+ B_, N, C = x.shape
118
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
119
+
120
+ with torch.cuda.amp.autocast(False):
121
+ q, k, v = qkv[0].float(), qkv[1].float(), qkv[2].float() # make torchscript happy (cannot use tensor as tuple)
122
+
123
+ q = q * self.scale
124
+ attn = (q @ k.transpose(-2, -1))
125
+
126
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
127
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
128
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
129
+ attn = attn + relative_position_bias.unsqueeze(0)
130
+
131
+ if mask is not None:
132
+ nW = mask.shape[0]
133
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
134
+ attn = attn.view(-1, self.num_heads, N, N)
135
+ attn = self.softmax(attn)
136
+ else:
137
+ attn = self.softmax(attn)
138
+
139
+ attn = self.attn_drop(attn)
140
+
141
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
142
+
143
+ with torch.cuda.amp.autocast(True):
144
+ x = self.proj(x)
145
+ x = self.proj_drop(x)
146
+ return x
147
+
148
+ def extra_repr(self) -> str:
149
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
150
+
151
+ def flops(self, N):
152
+ # calculate flops for 1 window with token length of N
153
+ flops = 0
154
+ # qkv = self.qkv(x)
155
+ flops += N * self.dim * 3 * self.dim
156
+ # attn = (q @ k.transpose(-2, -1))
157
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
158
+ # x = (attn @ v)
159
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
160
+ # x = self.proj(x)
161
+ flops += N * self.dim * self.dim
162
+ return flops
163
+
164
+
165
+ class SwinTransformerBlock(nn.Module):
166
+ r""" Swin Transformer Block.
167
+
168
+ Args:
169
+ dim (int): Number of input channels.
170
+ input_resolution (tuple[int]): Input resulotion.
171
+ num_heads (int): Number of attention heads.
172
+ window_size (int): Window size.
173
+ shift_size (int): Shift size for SW-MSA.
174
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
175
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
176
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
177
+ drop (float, optional): Dropout rate. Default: 0.0
178
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
179
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
180
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
181
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
182
+ """
183
+
184
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
185
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
186
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
187
+ super().__init__()
188
+ self.dim = dim
189
+ self.input_resolution = input_resolution
190
+ self.num_heads = num_heads
191
+ self.window_size = window_size
192
+ self.shift_size = shift_size
193
+ self.mlp_ratio = mlp_ratio
194
+ if min(self.input_resolution) <= self.window_size:
195
+ # if window size is larger than input resolution, we don't partition windows
196
+ self.shift_size = 0
197
+ self.window_size = min(self.input_resolution)
198
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
199
+
200
+ self.norm1 = norm_layer(dim)
201
+ self.attn = WindowAttention(
202
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
203
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
204
+
205
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
206
+ self.norm2 = norm_layer(dim)
207
+ mlp_hidden_dim = int(dim * mlp_ratio)
208
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
209
+
210
+ if self.shift_size > 0:
211
+ # calculate attention mask for SW-MSA
212
+ H, W = self.input_resolution
213
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
214
+ h_slices = (slice(0, -self.window_size),
215
+ slice(-self.window_size, -self.shift_size),
216
+ slice(-self.shift_size, None))
217
+ w_slices = (slice(0, -self.window_size),
218
+ slice(-self.window_size, -self.shift_size),
219
+ slice(-self.shift_size, None))
220
+ cnt = 0
221
+ for h in h_slices:
222
+ for w in w_slices:
223
+ img_mask[:, h, w, :] = cnt
224
+ cnt += 1
225
+
226
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
227
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
228
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
229
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
230
+ else:
231
+ attn_mask = None
232
+
233
+ self.register_buffer("attn_mask", attn_mask)
234
+
235
+ def forward(self, x):
236
+ H, W = self.input_resolution
237
+ B, L, C = x.shape
238
+ assert L == H * W, "input feature has wrong size"
239
+
240
+ shortcut = x
241
+ x = self.norm1(x)
242
+ x = x.view(B, H, W, C)
243
+
244
+ # cyclic shift
245
+ if self.shift_size > 0:
246
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
247
+ else:
248
+ shifted_x = x
249
+
250
+ # partition windows
251
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
252
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
253
+
254
+ # W-MSA/SW-MSA
255
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
256
+
257
+ # merge windows
258
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
259
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
260
+
261
+ # reverse cyclic shift
262
+ if self.shift_size > 0:
263
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
264
+ else:
265
+ x = shifted_x
266
+ x = x.view(B, H * W, C)
267
+
268
+ # FFN
269
+ x = shortcut + self.drop_path(x)
270
+ with torch.cuda.amp.autocast(True):
271
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
272
+
273
+ return x
274
+
275
+ def extra_repr(self) -> str:
276
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
277
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
278
+
279
+ def flops(self):
280
+ flops = 0
281
+ H, W = self.input_resolution
282
+ # norm1
283
+ flops += self.dim * H * W
284
+ # W-MSA/SW-MSA
285
+ nW = H * W / self.window_size / self.window_size
286
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
287
+ # mlp
288
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
289
+ # norm2
290
+ flops += self.dim * H * W
291
+ return flops
292
+
293
+
294
+ class PatchMerging(nn.Module):
295
+ r""" Patch Merging Layer.
296
+
297
+ Args:
298
+ input_resolution (tuple[int]): Resolution of input feature.
299
+ dim (int): Number of input channels.
300
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
301
+ """
302
+
303
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
304
+ super().__init__()
305
+ self.input_resolution = input_resolution
306
+ self.dim = dim
307
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
308
+ self.norm = norm_layer(4 * dim)
309
+
310
+ def forward(self, x):
311
+ """
312
+ x: B, H*W, C
313
+ """
314
+ H, W = self.input_resolution
315
+ B, L, C = x.shape
316
+ assert L == H * W, "input feature has wrong size"
317
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
318
+
319
+ x = x.view(B, H, W, C)
320
+
321
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
322
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
323
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
324
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
325
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
326
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
327
+
328
+ x = self.norm(x)
329
+ x = self.reduction(x)
330
+
331
+ return x
332
+
333
+ def extra_repr(self) -> str:
334
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
335
+
336
+ def flops(self):
337
+ H, W = self.input_resolution
338
+ flops = H * W * self.dim
339
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
340
+ return flops
341
+
342
+
343
+ class BasicLayer(nn.Module):
344
+ """ A basic Swin Transformer layer for one stage.
345
+
346
+ Args:
347
+ dim (int): Number of input channels.
348
+ input_resolution (tuple[int]): Input resolution.
349
+ depth (int): Number of blocks.
350
+ num_heads (int): Number of attention heads.
351
+ window_size (int): Local window size.
352
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
353
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
354
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
355
+ drop (float, optional): Dropout rate. Default: 0.0
356
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
357
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
358
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
359
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
360
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
361
+ """
362
+
363
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
364
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
365
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
366
+
367
+ super().__init__()
368
+ self.dim = dim
369
+ self.input_resolution = input_resolution
370
+ self.depth = depth
371
+ self.use_checkpoint = use_checkpoint
372
+
373
+ # build blocks
374
+ self.blocks = nn.ModuleList([
375
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
376
+ num_heads=num_heads, window_size=window_size,
377
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
378
+ mlp_ratio=mlp_ratio,
379
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
380
+ drop=drop, attn_drop=attn_drop,
381
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
382
+ norm_layer=norm_layer)
383
+ for i in range(depth)])
384
+
385
+ # patch merging layer
386
+ if downsample is not None:
387
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
388
+ else:
389
+ self.downsample = None
390
+
391
+ def forward(self, x):
392
+ for blk in self.blocks:
393
+ if self.use_checkpoint:
394
+ x = checkpoint.checkpoint(blk, x)
395
+ else:
396
+ x = blk(x)
397
+ if self.downsample is not None:
398
+ x = self.downsample(x)
399
+ return x
400
+
401
+ def extra_repr(self) -> str:
402
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
403
+
404
+ def flops(self):
405
+ flops = 0
406
+ for blk in self.blocks:
407
+ flops += blk.flops()
408
+ if self.downsample is not None:
409
+ flops += self.downsample.flops()
410
+ return flops
411
+
412
+
413
+ class PatchEmbed(nn.Module):
414
+ r""" Image to Patch Embedding
415
+
416
+ Args:
417
+ img_size (int): Image size. Default: 224.
418
+ patch_size (int): Patch token size. Default: 4.
419
+ in_chans (int): Number of input image channels. Default: 3.
420
+ embed_dim (int): Number of linear projection output channels. Default: 96.
421
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
422
+ """
423
+
424
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
425
+ super().__init__()
426
+ img_size = to_2tuple(img_size)
427
+ patch_size = to_2tuple(patch_size)
428
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
429
+ self.img_size = img_size
430
+ self.patch_size = patch_size
431
+ self.patches_resolution = patches_resolution
432
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
433
+
434
+ self.in_chans = in_chans
435
+ self.embed_dim = embed_dim
436
+
437
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
438
+ if norm_layer is not None:
439
+ self.norm = norm_layer(embed_dim)
440
+ else:
441
+ self.norm = None
442
+
443
+ def forward(self, x):
444
+ B, C, H, W = x.shape
445
+ # FIXME look at relaxing size constraints
446
+ assert H == self.img_size[0] and W == self.img_size[1], \
447
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
448
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
449
+ if self.norm is not None:
450
+ x = self.norm(x)
451
+ return x
452
+
453
+ def flops(self):
454
+ Ho, Wo = self.patches_resolution
455
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
456
+ if self.norm is not None:
457
+ flops += Ho * Wo * self.embed_dim
458
+ return flops
459
+
460
+
461
+ class SwinTransformer(nn.Module):
462
+ r""" Swin Transformer
463
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
464
+ https://arxiv.org/pdf/2103.14030
465
+
466
+ Args:
467
+ img_size (int | tuple(int)): Input image size. Default 224
468
+ patch_size (int | tuple(int)): Patch size. Default: 4
469
+ in_chans (int): Number of input image channels. Default: 3
470
+ num_classes (int): Number of classes for classification head. Default: 1000
471
+ embed_dim (int): Patch embedding dimension. Default: 96
472
+ depths (tuple(int)): Depth of each Swin Transformer layer.
473
+ num_heads (tuple(int)): Number of attention heads in different layers.
474
+ window_size (int): Window size. Default: 7
475
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
476
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
477
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
478
+ drop_rate (float): Dropout rate. Default: 0
479
+ attn_drop_rate (float): Attention dropout rate. Default: 0
480
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
481
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
482
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
483
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
484
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
485
+ """
486
+
487
+ def __init__(self, img_size=112, patch_size=2, in_chans=3, num_classes=1000,
488
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
489
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
490
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
491
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
492
+ use_checkpoint=False, **kwargs):
493
+ super().__init__()
494
+
495
+ self.num_classes = num_classes
496
+ self.num_layers = len(depths)
497
+ self.embed_dim = embed_dim
498
+ self.ape = ape
499
+ self.patch_norm = patch_norm
500
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
501
+ self.mlp_ratio = mlp_ratio
502
+
503
+ # split image into non-overlapping patches
504
+ self.patch_embed = PatchEmbed(
505
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
506
+ norm_layer=norm_layer if self.patch_norm else None)
507
+ num_patches = self.patch_embed.num_patches
508
+ patches_resolution = self.patch_embed.patches_resolution
509
+ self.patches_resolution = patches_resolution
510
+
511
+ # absolute position embedding
512
+ if self.ape:
513
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
514
+ trunc_normal_(self.absolute_pos_embed, std=.02)
515
+
516
+ self.pos_drop = nn.Dropout(p=drop_rate)
517
+
518
+ # stochastic depth
519
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
520
+
521
+ # build layers
522
+ self.layers = nn.ModuleList()
523
+ for i_layer in range(self.num_layers):
524
+ layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
525
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
526
+ patches_resolution[1] // (2 ** i_layer)),
527
+ depth=depths[i_layer],
528
+ num_heads=num_heads[i_layer],
529
+ window_size=window_size,
530
+ mlp_ratio=self.mlp_ratio,
531
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
532
+ drop=drop_rate, attn_drop=attn_drop_rate,
533
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
534
+ norm_layer=norm_layer,
535
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
536
+ use_checkpoint=use_checkpoint)
537
+ self.layers.append(layer)
538
+
539
+ self.norm = norm_layer(self.num_features)
540
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
541
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
542
+
543
+ self.feature = nn.Sequential(
544
+ nn.Linear(in_features=self.num_features, out_features=self.num_features, bias=False),
545
+ nn.BatchNorm1d(num_features=self.num_features, eps=2e-5),
546
+ nn.Linear(in_features=self.num_features, out_features=num_classes, bias=False),
547
+ nn.BatchNorm1d(num_features=num_classes, eps=2e-5)
548
+ )
549
+ self.feature_resolution = (patches_resolution[0] // (2 ** (self.num_layers-1)), patches_resolution[1] // (2 ** (self.num_layers-1)))
550
+
551
+
552
+ self.apply(self._init_weights)
553
+
554
+ def _init_weights(self, m):
555
+ if isinstance(m, nn.Linear):
556
+ trunc_normal_(m.weight, std=.02)
557
+ if isinstance(m, nn.Linear) and m.bias is not None:
558
+ nn.init.constant_(m.bias, 0)
559
+ elif isinstance(m, nn.LayerNorm):
560
+ nn.init.constant_(m.bias, 0)
561
+ nn.init.constant_(m.weight, 1.0)
562
+
563
+ @torch.jit.ignore
564
+ def no_weight_decay(self):
565
+ return {'absolute_pos_embed'}
566
+
567
+ @torch.jit.ignore
568
+ def no_weight_decay_keywords(self):
569
+ return {'relative_position_bias_table'}
570
+
571
+ def forward_features(self, x):
572
+
573
+ patches_resolution = self.patch_embed.patches_resolution
574
+
575
+ x = self.patch_embed(x)
576
+ if self.ape:
577
+ x = x + self.absolute_pos_embed
578
+ x = self.pos_drop(x)
579
+
580
+ local_features = []
581
+ i = 0
582
+ for layer in self.layers:
583
+ i += 1
584
+ x = layer(x)
585
+
586
+ if not i == self.num_layers:
587
+
588
+ H = patches_resolution[0] // (2 ** i)
589
+ W = patches_resolution[1] // (2 ** i)
590
+
591
+ B, L, C = x.shape
592
+
593
+ temp = x.transpose(1, 2).reshape(B, C, H, W)
594
+ win_h = H // self.feature_resolution[0]
595
+ win_w = W // self.feature_resolution[1]
596
+ if not (win_h == 1 and win_w == 1):
597
+ temp = F.avg_pool2d(temp, kernel_size=(win_h, win_w))
598
+ local_features.append(temp)
599
+
600
+
601
+ local_features = torch.cat(local_features, dim=1)
602
+ # B, C, H, W
603
+ global_features = x
604
+ B, L, C = global_features.shape
605
+ global_features = global_features.transpose(1, 2).reshape(B, C, self.feature_resolution[0], self.feature_resolution[1])
606
+ # B, C, H, W
607
+
608
+ x = self.norm(x) # B L C
609
+ x = self.avgpool(x.transpose(1, 2)) # B C 1
610
+ x = torch.flatten(x, 1)
611
+ return local_features, global_features, x
612
+
613
+
614
+ def forward(self, x):
615
+ local_features, global_features, x = self.forward_features(x)
616
+ x = self.feature(x)
617
+ return local_features, global_features, x
618
+
619
+ def flops(self):
620
+ flops = 0
621
+ flops += self.patch_embed.flops()
622
+ for i, layer in enumerate(self.layers):
623
+ flops += layer.flops()
624
+ flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
625
+ flops += self.num_features * self.num_classes
626
+ return flops
627
+
628
+ class BasicConv(nn.Module):
629
+ def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
630
+ super(BasicConv, self).__init__()
631
+ self.out_channels = out_planes
632
+ self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
633
+ self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
634
+ self.relu = nn.ReLU() if relu else None
635
+
636
+ def forward(self, x):
637
+ x = self.conv(x)
638
+ if self.bn is not None:
639
+ x = self.bn(x)
640
+ if self.relu is not None:
641
+ x = self.relu(x)
642
+ return x
643
+
644
+ class Flatten(nn.Module):
645
+ def forward(self, x):
646
+ return x.view(x.size(0), -1)
647
+
648
+ class ChannelGate(nn.Module):
649
+ def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
650
+ super(ChannelGate, self).__init__()
651
+ self.gate_channels = gate_channels
652
+ self.mlp = nn.Sequential(
653
+ Flatten(),
654
+ nn.Linear(gate_channels, gate_channels // reduction_ratio),
655
+ nn.ReLU(),
656
+ nn.Linear(gate_channels // reduction_ratio, gate_channels)
657
+ )
658
+ self.pool_types = pool_types
659
+ def forward(self, x):
660
+ channel_att_sum = None
661
+ for pool_type in self.pool_types:
662
+ if pool_type=='avg':
663
+ avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
664
+ channel_att_raw = self.mlp( avg_pool )
665
+ elif pool_type=='max':
666
+ max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
667
+ channel_att_raw = self.mlp( max_pool )
668
+ elif pool_type=='lp':
669
+ lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
670
+ channel_att_raw = self.mlp( lp_pool )
671
+ elif pool_type=='lse':
672
+ # LSE pool only
673
+ lse_pool = logsumexp_2d(x)
674
+ channel_att_raw = self.mlp( lse_pool )
675
+
676
+ if channel_att_sum is None:
677
+ channel_att_sum = channel_att_raw
678
+ else:
679
+ channel_att_sum = channel_att_sum + channel_att_raw
680
+
681
+ scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
682
+ return x * scale
683
+
684
+ def logsumexp_2d(tensor):
685
+ tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
686
+ s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
687
+ outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
688
+ return outputs
689
+
690
+ class ChannelPool(nn.Module):
691
+ def forward(self, x):
692
+ return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )
693
+
694
+ class SpatialGate(nn.Module):
695
+ def __init__(self):
696
+ super(SpatialGate, self).__init__()
697
+ kernel_size = 7
698
+ self.compress = ChannelPool()
699
+ self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
700
+ def forward(self, x):
701
+ x_compress = self.compress(x)
702
+ x_out = self.spatial(x_compress)
703
+ scale = F.sigmoid(x_out) # broadcasting
704
+ return x * scale
705
+
706
+ class CBAM(nn.Module):
707
+ def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
708
+ super(CBAM, self).__init__()
709
+ self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
710
+ self.no_spatial=no_spatial
711
+ if not no_spatial:
712
+ self.SpatialGate = SpatialGate()
713
+ def forward(self, x):
714
+ x_out = self.ChannelGate(x)
715
+ if not self.no_spatial:
716
+ x_out = self.SpatialGate(x_out)
717
+ return x_out
718
+
719
+
720
+ class ConvLayer(torch.nn.Module):
721
+
722
+ def __init__(self, in_chans=768, out_chans=512, conv_mode="normal", kernel_size=3):
723
+ super().__init__()
724
+ self.conv_mode = conv_mode
725
+
726
+ if conv_mode == "normal":
727
+ self.conv = nn.Conv2d(in_chans, out_chans, kernel_size, stride=1, padding=(kernel_size-1)//2, bias=False)
728
+ elif conv_mode == "split":
729
+ self.convs = nn.ModuleList()
730
+ for j in range(len(in_chans)):
731
+ conv = nn.Conv2d(in_chans[j], out_chans[j], kernel_size, stride=1, padding=(kernel_size-1)//2, bias=False)
732
+ self.convs.append(conv)
733
+
734
+ self.cut = [0 for i in range(len(in_chans)+1)]
735
+ self.cut[0] = 0
736
+ for i in range(1, len(in_chans)+1):
737
+ self.cut[i] = self.cut[i - 1] + in_chans[i-1]
738
+
739
+ def forward(self, x):
740
+ if self.conv_mode == "normal":
741
+ x = self.conv(x)
742
+
743
+ elif self.conv_mode == "split":
744
+ outputs = []
745
+ for j in range(len(self.cut)-1):
746
+ input_map = x[:, self.cut[j]:self.cut[j+1]]
747
+ #print(input_map.shape)
748
+ output_map = self.convs[j](input_map)
749
+ outputs.append(output_map)
750
+ #print(output_map.shape)
751
+ x = torch.cat(outputs, dim=1)
752
+
753
+ return x
754
+
755
+
756
+ class LANet(torch.nn.Module):
757
+ def __init__(self, in_chans=512, reduction_ratio=2.0):
758
+ super().__init__()
759
+
760
+ self.in_chans = in_chans
761
+ self.mid_chans = int(self.in_chans/reduction_ratio)
762
+
763
+ self.conv1 = nn.Conv2d(self.in_chans, self.mid_chans, kernel_size=(1, 1), stride=(1, 1), bias=False)
764
+ self.conv2 = nn.Conv2d(self.mid_chans, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
765
+
766
+ def forward(self, x):
767
+
768
+ x = F.relu(self.conv1(x))
769
+ x = torch.sigmoid(self.conv2(x))
770
+
771
+ return x
772
+
773
+
774
+ def MAD(x, p=0.6):
775
+ B, C, W, H = x.shape
776
+
777
+ mask1 = torch.cat([torch.randperm(C).unsqueeze(dim=0) for j in range(B)], dim=0).cuda()
778
+ mask2 = torch.rand([B, C]).cuda()
779
+ ones = torch.ones([B, C], dtype=torch.float).cuda()
780
+ zeros = torch.zeros([B, C], dtype=torch.float).cuda()
781
+ mask = torch.where(mask1 == 0, zeros, ones)
782
+ mask = torch.where(mask2 < p, mask, ones)
783
+
784
+ x = x.permute(2, 3, 0, 1)
785
+ x = x.mul(mask)
786
+ x = x.permute(2, 3, 0, 1)
787
+ return x
788
+
789
+
790
+ class LANets(torch.nn.Module):
791
+
792
+ def __init__(self, branch_num=2, feature_dim=512, la_reduction_ratio=2.0, MAD=MAD):
793
+ super().__init__()
794
+
795
+ self.LANets = nn.ModuleList()
796
+ for i in range(branch_num):
797
+ self.LANets.append(LANet(in_chans=feature_dim, reduction_ratio=la_reduction_ratio))
798
+
799
+ self.MAD = MAD
800
+ self.branch_num = branch_num
801
+
802
+ def forward(self, x):
803
+
804
+ B, C, W, H = x.shape
805
+
806
+ outputs = []
807
+ for lanet in self.LANets:
808
+ output = lanet(x)
809
+ outputs.append(output)
810
+
811
+ LANets_output = torch.cat(outputs, dim=1)
812
+
813
+ if self.MAD and self.branch_num != 1:
814
+ LANets_output = self.MAD(LANets_output)
815
+
816
+ mask = torch.max(LANets_output, dim=1).values.reshape(B, 1, W, H)
817
+ x = x.mul(mask)
818
+
819
+ return x
820
+
821
+
822
+ class FeatureAttentionNet(torch.nn.Module):
823
+ def __init__(self, in_chans=768, feature_dim=512, kernel_size=3,
824
+ conv_shared=False, conv_mode="normal",
825
+ channel_attention=None, spatial_attention=None,
826
+ pooling="max", la_branch_num=2):
827
+ super().__init__()
828
+
829
+ self.conv_shared = conv_shared
830
+ self.channel_attention = channel_attention
831
+ self.spatial_attention = spatial_attention
832
+
833
+ if not self.conv_shared:
834
+ if conv_mode == "normal":
835
+ self.conv = ConvLayer(in_chans=in_chans, out_chans=feature_dim,
836
+ conv_mode="normal", kernel_size=kernel_size)
837
+ elif conv_mode == "split" and in_chans == 2112:
838
+ self.conv = ConvLayer(in_chans=[192, 384, 768, 768], out_chans=[47, 93, 186, 186],
839
+ conv_mode="split", kernel_size=kernel_size)
840
+
841
+ if self.channel_attention == "CBAM":
842
+ self.channel_attention = ChannelGate(gate_channels=feature_dim)
843
+
844
+ if self.spatial_attention == "CBAM":
845
+ self.spatial_attention = SpatialGate()
846
+ elif self.spatial_attention == "LANet":
847
+ self.spatial_attention = LANets(branch_num=la_branch_num, feature_dim=feature_dim)
848
+
849
+ if pooling == "max":
850
+ self.pool = nn.AdaptiveMaxPool2d((1, 1))
851
+ elif pooling == "avg":
852
+ self.pool = nn.AdaptiveAvgPool2d((1, 1))
853
+
854
+ self.act = nn.ReLU(inplace=True)
855
+ self.norm = nn.BatchNorm1d(num_features=feature_dim, eps=2e-5)
856
+
857
+ def forward(self, x):
858
+
859
+ if not self.conv_shared:
860
+ x = self.conv(x)
861
+
862
+ if self.channel_attention:
863
+ x = self.channel_attention(x)
864
+
865
+ if self.spatial_attention:
866
+ x = self.spatial_attention(x)
867
+
868
+ x = self.act(x)
869
+ B, C, _, __ = x.shape
870
+ x = self.pool(x).reshape(B, C)
871
+ x = self.norm(x)
872
+
873
+ return x
874
+
875
+
876
+ class FeatureAttentionModule(torch.nn.Module):
877
+ def __init__(self, branch_num=11, in_chans=2112, feature_dim=512, conv_shared=False, conv_mode="split", kernel_size=3,
878
+ channel_attention="CBAM", spatial_attention=None, la_num_list=[2 for j in range(11)], pooling="max"):
879
+ super().__init__()
880
+
881
+
882
+ self.conv_shared = conv_shared
883
+ if self.conv_shared:
884
+ if conv_mode == "normal":
885
+ self.conv = ConvLayer(in_chans=in_chans, out_chans=feature_dim,
886
+ conv_mode="normal", kernel_size=kernel_size)
887
+ elif conv_mode == "split" and in_chans == 2112:
888
+ self.conv = ConvLayer(in_chans=[192, 384, 768, 768], out_chans=[47, 93, 186, 186],
889
+ conv_mode="split", kernel_size=kernel_size)
890
+
891
+ self.nets = nn.ModuleList()
892
+ for i in range(branch_num):
893
+ net = FeatureAttentionNet(in_chans=in_chans, feature_dim=feature_dim,
894
+ conv_shared=conv_shared, conv_mode=conv_mode, kernel_size=kernel_size,
895
+ channel_attention=channel_attention, spatial_attention=spatial_attention,
896
+ la_branch_num=la_num_list[i], pooling=pooling)
897
+ self.nets.append(net)
898
+
899
+ self.apply(self._init_weights)
900
+
901
+ def _init_weights(self, m):
902
+ if isinstance(m, nn.Linear):
903
+ trunc_normal_(m.weight, std=.02)
904
+ if isinstance(m, nn.Linear) and m.bias is not None:
905
+ nn.init.constant_(m.bias, 0)
906
+ elif isinstance(m, nn.LayerNorm):
907
+ nn.init.constant_(m.bias, 0)
908
+ nn.init.constant_(m.weight, 1.0)
909
+
910
+ def forward(self, x):
911
+
912
+ if self.conv_shared:
913
+ x = self.conv(x)
914
+
915
+ outputs = []
916
+ for net in self.nets:
917
+ output = net(x).unsqueeze(dim=0)
918
+ outputs.append(output)
919
+ outputs = torch.cat(outputs, dim=0)
920
+
921
+ return outputs
922
+
923
+ class TaskSpecificSubnet(torch.nn.Module):
924
+ def __init__(self, feature_dim=512, drop_rate=0.5):
925
+ super().__init__()
926
+ self.feature = nn.Sequential(
927
+ nn.Linear(feature_dim, feature_dim),
928
+ nn.ReLU(True),
929
+ nn.Dropout(drop_rate),
930
+ nn.Linear(feature_dim, feature_dim),
931
+ nn.ReLU(True),
932
+ nn.Dropout(drop_rate),)
933
+
934
+ def forward(self, x):
935
+ return self.feature(x)
936
+
937
+ class TaskSpecificSubnets(torch.nn.Module):
938
+ def __init__(self, branch_num=11):
939
+ super().__init__()
940
+
941
+ self.branch_num = branch_num
942
+ self.nets = nn.ModuleList()
943
+ for i in range(self.branch_num):
944
+ net = TaskSpecificSubnet(drop_rate=0.5)
945
+ self.nets.append(net)
946
+
947
+ self.apply(self._init_weights)
948
+
949
+ def _init_weights(self, m):
950
+ if isinstance(m, nn.Linear):
951
+ trunc_normal_(m.weight, std=.02)
952
+ if isinstance(m, nn.Linear) and m.bias is not None:
953
+ nn.init.constant_(m.bias, 0)
954
+ elif isinstance(m, nn.LayerNorm):
955
+ nn.init.constant_(m.bias, 0)
956
+ nn.init.constant_(m.weight, 1.0)
957
+
958
+ def forward(self, x):
959
+
960
+ outputs = []
961
+ for i in range(self.branch_num):
962
+ net = self.nets[i]
963
+ output = net(x[i]).unsqueeze(dim=0)
964
+ outputs.append(output)
965
+ outputs = torch.cat(outputs, dim=0)
966
+
967
+ return outputs
968
+
969
+ class OutputModule(torch.nn.Module):
970
+ def __init__(self, feature_dim=512, output_type="Dict"):
971
+ super().__init__()
972
+ self.output_sizes = [[2],
973
+ [1, 2],
974
+ [7, 2],
975
+ [2 for j in range(6)],
976
+ [2 for j in range(10)],
977
+ [2 for j in range(5)],
978
+ [2, 2],
979
+ [2 for j in range(4)],
980
+ [2 for j in range(6)],
981
+ [2, 2],
982
+ [2, 2]]
983
+
984
+ self.output_fcs = nn.ModuleList()
985
+ for i in range(0, len(self.output_sizes)):
986
+ for j in range(len(self.output_sizes[i])):
987
+ output_fc = nn.Linear(feature_dim, self.output_sizes[i][j])
988
+ self.output_fcs.append(output_fc)
989
+
990
+ self.task_names = [
991
+ 'Age', 'Attractive', 'Blurry', 'Chubby', 'Heavy Makeup', 'Gender', 'Oval Face', 'Pale Skin',
992
+ 'Smiling', 'Young',
993
+ 'Bald', 'Bangs', 'Black Hair', 'Blond Hair', 'Brown Hair', 'Gray Hair', 'Receding Hairline',
994
+ 'Straight Hair', 'Wavy Hair', 'Wearing Hat',
995
+ 'Arched Eyebrows', 'Bags Under Eyes', 'Bushy Eyebrows', 'Eyeglasses', 'Narrow Eyes', 'Big Nose',
996
+ 'Pointy Nose', 'High Cheekbones', 'Rosy Cheeks', 'Wearing Earrings',
997
+ 'Sideburns', r"Five O'Clock Shadow", 'Big Lips', 'Mouth Slightly Open', 'Mustache',
998
+ 'Wearing Lipstick', 'No Beard', 'Double Chin', 'Goatee', 'Wearing Necklace',
999
+ 'Wearing Necktie', 'Expression', 'Recognition'] # Total:43
1000
+
1001
+ self.output_type = output_type
1002
+
1003
+ self.apply(self._init_weights)
1004
+
1005
+ def set_output_type(self, output_type):
1006
+ self.output_type = output_type
1007
+
1008
+ def _init_weights(self, m):
1009
+ if isinstance(m, nn.Linear):
1010
+ trunc_normal_(m.weight, std=.02)
1011
+ if isinstance(m, nn.Linear) and m.bias is not None:
1012
+ nn.init.constant_(m.bias, 0)
1013
+ elif isinstance(m, nn.LayerNorm):
1014
+ nn.init.constant_(m.bias, 0)
1015
+ nn.init.constant_(m.weight, 1.0)
1016
+
1017
+ def forward(self, x, embedding):
1018
+
1019
+ outputs = []
1020
+
1021
+ k = 0
1022
+ for i in range(0, len(self.output_sizes)):
1023
+ for j in range(len(self.output_sizes[i])):
1024
+ output_fc = self.output_fcs[k]
1025
+ output = output_fc(x[i])
1026
+ outputs.append(output)
1027
+ k += 1
1028
+
1029
+ [gender,
1030
+ age, young,
1031
+ expression, smiling,
1032
+ attractive, blurry, chubby, heavy_makeup, oval_face, pale_skin,
1033
+ bald, bangs, black_hair, blond_hair, brown_hair, gray_hair, receding_hairline, straight_hair, wavy_hair,
1034
+ wearing_hat,
1035
+ arched_eyebrows, bags_under_eyes, bushy_eyebrows, eyeglasses, narrow_eyes,
1036
+ big_nose, pointy_nose,
1037
+ high_cheekbones, rosy_cheeks, wearing_earrings, sideburns,
1038
+ five_o_clock_shadow, big_lips, mouth_slightly_open, mustache, wearing_lipstick, no_beard,
1039
+ double_chin, goatee,
1040
+ wearing_necklace, wearing_necktie] = outputs
1041
+
1042
+ outputs = [age, attractive, blurry, chubby, heavy_makeup, gender, oval_face, pale_skin, smiling, young,
1043
+ bald, bangs, black_hair, blond_hair, brown_hair, gray_hair, receding_hairline,
1044
+ straight_hair, wavy_hair, wearing_hat,
1045
+ arched_eyebrows, bags_under_eyes, bushy_eyebrows, eyeglasses, narrow_eyes, big_nose,
1046
+ pointy_nose, high_cheekbones, rosy_cheeks, wearing_earrings,
1047
+ sideburns, five_o_clock_shadow, big_lips, mouth_slightly_open, mustache,
1048
+ wearing_lipstick, no_beard, double_chin, goatee, wearing_necklace,
1049
+ wearing_necktie, expression] # Total:42
1050
+
1051
+ outputs.append(embedding)
1052
+
1053
+ result = dict()
1054
+ for j in range(43):
1055
+ result[self.task_names[j]] = outputs[j]
1056
+
1057
+ if self.output_type == "Dict":
1058
+ return result
1059
+ elif self.output_type == "List":
1060
+ return outputs
1061
+ elif self.output_type == "Attribute":
1062
+ return outputs[1: 41]
1063
+ else:
1064
+ return result[self.output_type]
1065
+
1066
+
1067
+ class ModelBox(torch.nn.Module):
1068
+
1069
+ def __init__(self, backbone=None, fam=None, tss=None, om=None,
1070
+ feature="global", output_type="Dict"):
1071
+ super().__init__()
1072
+ self.backbone = backbone
1073
+ self.fam = fam
1074
+ self.tss = tss
1075
+ self.om = om
1076
+ self.output_type = output_type
1077
+ if self.om:
1078
+ self.om.set_output_type(self.output_type)
1079
+
1080
+ self.feature = feature
1081
+
1082
+ def set_output_type(self, output_type):
1083
+ self.output_type = output_type
1084
+ if self.om:
1085
+ self.om.set_output_type(self.output_type)
1086
+
1087
+
1088
+ def forward(self, x):
1089
+
1090
+ local_features, global_features, embedding = self.backbone(x)
1091
+
1092
+ if self.feature == "all":
1093
+ x = torch.cat([local_features, global_features], dim=1)
1094
+ elif self.feature == "global":
1095
+ x = global_features
1096
+ elif self.feature == "local":
1097
+ x = local_features
1098
+
1099
+ x = self.fam(x)
1100
+ x = self.tss(x)
1101
+
1102
+ x = self.om(x, embedding)
1103
+ return x
1104
+
1105
+ def build_model(cfg):
1106
+
1107
+ backbone = SwinTransformer(num_classes=cfg.embedding_size)
1108
+
1109
+ fam = FeatureAttentionModule(
1110
+ in_chans=cfg.fam_in_chans, kernel_size=cfg.fam_kernel_size,
1111
+ conv_shared=cfg.fam_conv_shared, conv_mode=cfg.fam_conv_mode,
1112
+ channel_attention=cfg.fam_channel_attention, spatial_attention=cfg.fam_spatial_attention,
1113
+ pooling=cfg.fam_pooling, la_num_list=cfg.fam_la_num_list)
1114
+ tss = TaskSpecificSubnets()
1115
+ om = OutputModule()
1116
+
1117
+ model = ModelBox(backbone=backbone, fam=fam, tss=tss, om=om, feature=cfg.fam_feature)
1118
+
1119
+ return model
1120
+
1121
+ class SwinFaceCfg:
1122
+ network = "swin_t"
1123
+ fam_kernel_size=3
1124
+ fam_in_chans=2112
1125
+ fam_conv_shared=False
1126
+ fam_conv_mode="split"
1127
+ fam_channel_attention="CBAM"
1128
+ fam_spatial_attention=None
1129
+ fam_pooling="max"
1130
+ fam_la_num_list=[2 for j in range(11)]
1131
+ fam_feature="all"
1132
+ fam = "3x3_2112_F_s_C_N_max"
1133
+ embedding_size = 512
1134
+
1135
+ @torch.no_grad()
1136
+ def load_model():
1137
+ cfg = SwinFaceCfg()
1138
+ weight = os.getcwd() + "/weights.pt"
1139
+ if not os.path.isfile(weight):
1140
+ gdown.download("https://drive.google.com/uc?export=download&id=1fi4IuuFV8NjnWm-CufdrhMKrkjxhSmjx", weight)
1141
+
1142
+ model = build_model(cfg)
1143
+ dict_checkpoint = torch.load(weight, map_location=torch.device('cpu'))
1144
+ model.backbone.load_state_dict(dict_checkpoint["state_dict_backbone"])
1145
+ model.fam.load_state_dict(dict_checkpoint["state_dict_fam"])
1146
+ model.tss.load_state_dict(dict_checkpoint["state_dict_tss"])
1147
+ model.om.load_state_dict(dict_checkpoint["state_dict_om"])
1148
+
1149
+ model.eval()
1150
+ return model
1151
+
1152
+
1153
+ def get_embeddings(model, images):
1154
+ embeddings = []
1155
+ for img in images:
1156
+ img = cv2.resize(np.array(img), (112, 112))
1157
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
1158
+ img = np.transpose(img, (2, 0, 1))
1159
+ img = torch.from_numpy(img).unsqueeze(0).float()
1160
+ img.div_(255).sub_(0.5).div_(0.5)
1161
+ with torch.inference_mode():
1162
+ output = model(img)
1163
+ embeddings.append(output["Recognition"][0].numpy())
1164
+ return embeddings