Kimata committed
Commit bfe3d1b · 1 Parent(s): 87811d3

solution 2

.gitattributes CHANGED
@@ -1,5 +1,5 @@
  <<<<<<< HEAD
- checkpoints/model_best.pt filter=lfs diff=lfs merge=lfs -text
+ checkpoints/model_best.pt filter=lfs diff=lfs merge=lfs -text
  =======
  *.7z filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
@@ -37,3 +37,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  >>>>>>> 10b34e9e01a793df83cca1499ece5c6b29f10a90
+ checkpoints/model.pth filter=lfs diff=lfs merge=lfs -text
__pycache__/inference_2.cpython-39.pyc ADDED
Binary file (5.56 kB).
 
app.py CHANGED
@@ -1,5 +1,5 @@
  import gradio as gr
- import inference
+ import inference_2 as inference
  
  
  title="Multimodal deepfake detector"
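
The diff above only touches app.py's import line; the rest of the file is not shown in this commit. As a rough sketch of how the renamed module's predict helpers could be wired into Gradio (the interface layout below is an assumption, not the actual contents of app.py):

```python
# Hypothetical wiring sketch -- app.py's full body is not part of this diff.
import gradio as gr
import inference_2 as inference

title = "Multimodal deepfake detector"

image_tab = gr.Interface(fn=inference.deepfakes_image_predict,
                         inputs=gr.Image(), outputs=gr.Textbox(), title=title)
video_tab = gr.Interface(fn=inference.deepfakes_video_predict,
                         inputs=gr.Video(), outputs=gr.Textbox(), title=title)
audio_tab = gr.Interface(fn=inference.deepfakes_spec_predict,
                         inputs=gr.Audio(), outputs=gr.Textbox(), title=title)

demo = gr.TabbedInterface([image_tab, video_tab, audio_tab],
                          ["Image", "Video", "Audio"])

if __name__ == "__main__":
    demo.launch()
```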
checkpoints/model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dd7e7092a26ba6b2927a05150d25f03fb19e4562006835cfa585a055b419f2f2
+ size 604878654
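
The committed checkpoints/model.pth is a Git LFS pointer; the ~605 MB weights themselves live in LFS storage and are fetched with `git lfs pull`. A minimal sketch of loading the checkpoint on CPU, matching the torch.load calls in inference_2.py below (it assumes the pointer has already been replaced by the real file):

```python
# Minimal sketch: load the LFS-tracked checkpoint on CPU, as inference_2.py does.
# Assumes `git lfs pull` has already replaced this pointer file with the real weights.
import torch

ckpt = torch.load('checkpoints/model.pth', map_location=torch.device('cpu'))
print(type(ckpt))  # typically a state_dict-style mapping of parameter names to tensors
```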
inference_2.py ADDED
@@ -0,0 +1,217 @@
+ import os
+ import cv2
+ import torch
+ import argparse
+ import numpy as np
+ import torch.nn as nn
+ from models.TMC import ETMC
+ from models import image
+ 
+ #Set random seed for reproducibility.
+ torch.manual_seed(42)
+ 
+ 
+ # Define the audio_args dictionary
+ audio_args = {
+     'nb_samp': 64600,
+     'first_conv': 1024,
+     'in_channels': 1,
+     'filts': [20, [20, 20], [20, 128], [128, 128]],
+     'blocks': [2, 4],
+     'nb_fc_node': 1024,
+     'gru_node': 1024,
+     'nb_gru_layer': 3,
+     'nb_classes': 2
+ }
+ 
+ 
+ def get_args(parser):
+     parser.add_argument("--batch_size", type=int, default=8)
+     parser.add_argument("--data_dir", type=str, default="datasets/train/fakeavceleb*")
+     parser.add_argument("--LOAD_SIZE", type=int, default=256)
+     parser.add_argument("--FINE_SIZE", type=int, default=224)
+     parser.add_argument("--dropout", type=float, default=0.2)
+     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
+     parser.add_argument("--hidden", nargs="*", type=int, default=[])
+     parser.add_argument("--hidden_sz", type=int, default=768)
+     parser.add_argument("--img_embed_pool_type", type=str, default="avg", choices=["max", "avg"])
+     parser.add_argument("--img_hidden_sz", type=int, default=1024)
+     parser.add_argument("--include_bn", type=int, default=True)
+     parser.add_argument("--lr", type=float, default=1e-4)
+     parser.add_argument("--lr_factor", type=float, default=0.3)
+     parser.add_argument("--lr_patience", type=int, default=10)
+     parser.add_argument("--max_epochs", type=int, default=500)
+     parser.add_argument("--n_workers", type=int, default=12)
+     parser.add_argument("--name", type=str, default="MMDF")
+     parser.add_argument("--num_image_embeds", type=int, default=1)
+     parser.add_argument("--patience", type=int, default=20)
+     parser.add_argument("--savedir", type=str, default="./savepath/")
+     parser.add_argument("--seed", type=int, default=1)
+     parser.add_argument("--n_classes", type=int, default=2)
+     parser.add_argument("--annealing_epoch", type=int, default=10)
+     parser.add_argument("--device", type=str, default='cpu')
+     parser.add_argument("--pretrained_image_encoder", type=bool, default=False)
+     parser.add_argument("--freeze_image_encoder", type=bool, default=False)
+     parser.add_argument("--pretrained_audio_encoder", type=bool, default=False)
+     parser.add_argument("--freeze_audio_encoder", type=bool, default=False)
+     parser.add_argument("--augment_dataset", type=bool, default=True)
+ 
+     for key, value in audio_args.items():
+         parser.add_argument(f"--{key}", type=type(value), default=value)
+ 
+ def model_summary(args):
+     '''Prints the model summary.'''
+     model = ETMC(args)
+ 
+     for name, layer in model.named_modules():
+         print(name, layer)
+ 
+ def load_multimodal_model(args):
+     '''Load multimodal model'''
+     model = ETMC(args)
+     ckpt = torch.load('checkpoints/model.pth', map_location=torch.device('cpu'))
+     model.load_state_dict(ckpt, strict=True)
+     model.eval()
+     return model
+ 
+ def load_img_modality_model(args):
+     '''Loads image modality model.'''
+     rgb_encoder = image.ImageEncoder(args)
+     ckpt = torch.load('checkpoints/model.pth', map_location=torch.device('cpu'))
+     rgb_encoder.load_state_dict(ckpt, strict=False)
+     rgb_encoder.eval()
+     return rgb_encoder
+ 
+ def load_spec_modality_model(args):
+     spec_encoder = image.RawNet(args)
+     ckpt = torch.load('checkpoints/model.pth', map_location=torch.device('cpu'))
+     spec_encoder.load_state_dict(ckpt, strict=False)
+     spec_encoder.eval()
+     return spec_encoder
+ 
+ 
+ #Load models.
+ parser = argparse.ArgumentParser(description="Train Models")
+ get_args(parser)
+ args, remaining_args = parser.parse_known_args()
+ assert remaining_args == [], remaining_args
+ 
+ # multimodal = load_multimodal_model(args)
+ spec_model = load_spec_modality_model(args)
+ 
+ # print(f"Spec model is: {spec_model}")
+ 
+ img_model = load_img_modality_model(args)
+ 
+ # print(f"Image model is: {img_model}")
+ 
+ # spec_in = torch.randn(1, 10_000)
+ # rgb_in = torch.randn([1, 3, 256, 256])
+ 
+ # rgb_out = img_model(rgb_in)
+ # spec_out = spec_model(spec_in)
+ 
+ # print(f"Img input shape is: {rgb_in.shape}, output shape: {rgb_out}")
+ # print(f"Spec input shape is: {spec_in.shape}, output shape is: {spec_out.shape} output: {spec_out}")
+ 
+ def preprocess_img(face):
+     face = face / 255
+     face = cv2.resize(face, (256, 256))
+     face = face.transpose(2, 0, 1)  #(H, W, C) -> (C, H, W)
+     face_pt = torch.unsqueeze(torch.Tensor(face), dim=0)
+     return face_pt
+ 
+ def preprocess_audio(audio_file):
+     audio_pt = torch.unsqueeze(torch.Tensor(audio_file), dim=0)
+     return audio_pt
+ 
+ def deepfakes_spec_predict(input_audio):
+     x, _ = input_audio
+     audio = preprocess_audio(x)
+     spec_grads = spec_model.forward(audio)
+     spec_grads_inv = np.exp(spec_grads.cpu().numpy().squeeze())
+ 
+     # multimodal_grads = multimodal.spec_depth[0].forward(spec_grads)
+ 
+     # out = nn.Softmax()(multimodal_grads)
+     # max = torch.argmax(out, dim = -1) #Index of the max value in the tensor.
+     # max_value = out[max] #Actual value of the tensor.
+     max_value = np.argmax(spec_grads_inv)
+ 
+     if max_value > 0.5:
+         preds = round(100 - (max_value * 100), 3)
+         text2 = f"The audio is REAL."
+ 
+     else:
+         preds = round(max_value * 100, 3)
+         text2 = f"The audio is FAKE."
+ 
+     return text2
+ 
+ def deepfakes_image_predict(input_image):
+     face = preprocess_img(input_image)
+ 
+     img_grads = img_model.forward(face)
+     img_grads = img_grads.cpu().detach().numpy()
+     img_grads_np = np.squeeze(img_grads)
+ 
+     if img_grads_np > 0.5:
+         preds = round(100 - (img_grads_np * 100), 3)
+         text2 = f"The image is REAL."
+ 
+     else:
+         preds = round(img_grads_np * 100, 3)
+         text2 = f"The image is FAKE."
+ 
+     return text2
+ 
+ 
+ def preprocess_video(input_video, n_frames=3):
+     v_cap = cv2.VideoCapture(input_video)
+     v_len = int(v_cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ 
+     # Pick 'n_frames' evenly spaced frames to sample
+     if n_frames is None:
+         sample = np.arange(0, v_len)
+     else:
+         sample = np.linspace(0, v_len - 1, n_frames).astype(int)
+ 
+     #Loop through frames.
+     frames = []
+     for j in range(v_len):
+         success = v_cap.grab()
+         if j in sample:
+             # Load frame
+             success, frame = v_cap.retrieve()
+             if not success:
+                 continue
+             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             frame = preprocess_img(frame)
+             frames.append(frame)
+     v_cap.release()
+     return frames
+ 
+ 
+ def deepfakes_video_predict(input_video):
+     '''Perform inference on a video.'''
+     video_frames = preprocess_video(input_video)
+     grads_list = []
+ 
+     for face in video_frames:
+         # face = preprocess_img(face)
+ 
+         img_grads = img_model.forward(face)
+         img_grads = img_grads.cpu().detach().numpy()
+         img_grads_np = np.squeeze(img_grads)
+         grads_list.append(img_grads_np)
+ 
+     grads_list_mean = np.mean(grads_list)
+ 
+     if grads_list_mean > 0.5:
+         res = round(grads_list_mean * 100, 3)
+         text = f"The video is REAL."
+     else:
+         res = round(100 - (grads_list_mean * 100), 3)
+         text = f"The video is FAKE."
+     return text
+ 
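
Because inference_2.py builds the encoders and parses its arguments at import time, the predict helpers above can be smoke-tested directly. A minimal example, where the media file names are placeholders rather than files shipped in this repo:

```python
# Hypothetical smoke test for the helpers defined in inference_2.py; sample paths are placeholders.
import cv2
import inference_2 as inference

frame = cv2.imread("sample_face.jpg")                        # H x W x 3 BGR uint8 array
print(inference.deepfakes_image_predict(frame))              # "The image is REAL." or "The image is FAKE."

print(inference.deepfakes_video_predict("sample_clip.mp4"))  # samples 3 frames and averages the scores
```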
models/__pycache__/image.cpython-39.pyc CHANGED
Binary files a/models/__pycache__/image.cpython-39.pyc and b/models/__pycache__/image.cpython-39.pyc differ
 
models/image.py CHANGED
@@ -14,7 +14,8 @@ class ImageEncoder(nn.Module):
          self.device = args.device
          self.args = args
          self.flatten = nn.Flatten()
-         self.fc = nn.Linear(in_features=2560, out_features = 1024)
+         self.sigmoid = nn.Sigmoid()
+         # self.fc = nn.Linear(in_features=2560, out_features = 2)
          self.pretrained_image_encoder = args.pretrained_image_encoder
          self.freeze_image_encoder = args.freeze_image_encoder
  
@@ -22,24 +23,25 @@ class ImageEncoder(nn.Module):
              self.model = DeepFakeClassifier(encoder = "tf_efficientnet_b7_ns").to(self.device)
  
          else:
-             self.pretrained_ckpt = torch.load('DFDT TMC/pretrained/final_999_DeepFakeClassifier_tf_efficientnet_b7_ns_0_23', map_location = torch.device(self.args.device))
+             self.pretrained_ckpt = torch.load('pretrained\\final_999_DeepFakeClassifier_tf_efficientnet_b7_ns_0_23', map_location = torch.device(self.args.device))
              self.state_dict = self.pretrained_ckpt.get("state_dict", self.pretrained_ckpt)
  
              self.model = DeepFakeClassifier(encoder = "tf_efficientnet_b7_ns").to(self.device)
              print("Loading pretrained image encoder...")
-             self.model.load_state_dict({re.sub("^module.", "", k): v for k, v in self.state_dict.items()}, strict=False)
+             self.model.load_state_dict({re.sub("^module.", "", k): v for k, v in self.state_dict.items()}, strict=True)
              print("Loaded pretrained image encoder.")
  
          if self.freeze_image_encoder == True:
              for idx, param in self.model.named_parameters():
                  param.requires_grad = False
  
-         self.model.fc = nn.Identity()
+         # self.model.fc = nn.Identity()
  
      def forward(self, x):
          x = self.model(x)
-         x = self.flatten(x)
-         out = self.fc(x)
+         out = self.sigmoid(x)
+         # x = self.flatten(x)
+         # out = self.fc(x)
          return out
  
  
@@ -84,7 +86,12 @@ class RawNet(nn.Module):
                            hidden_size = args.gru_node,
                            num_layers = args.nb_gru_layer,
                            batch_first = True)
- 
+ 
+         self.fc1_gru = nn.Linear(in_features = args.gru_node,
+                                  out_features = args.nb_fc_node)
+ 
+         self.fc2_gru = nn.Linear(in_features = args.nb_fc_node,
+                                  out_features = args.nb_classes, bias=True)
  
          self.sig = nn.Sigmoid()
          self.logsoftmax = nn.LogSoftmax(dim=1)
@@ -93,9 +100,9 @@
  
          if self.pretrained_audio_encoder == True:
              print("Loading pretrained audio encoder")
-             ckpt = torch.load('DFDT TMC/pretrained/RawNet.pth', map_location = torch.device(self.device))
+             ckpt = torch.load('pretrained\\RawNet.pth', map_location = torch.device(self.device))
              print("Loaded pretrained audio encoder")
-             self.load_state_dict(ckpt, strict = False)
+             self.load_state_dict(ckpt, strict = True)
  
          if self.freeze_audio_encoder:
              for param in self.parameters():
@@ -155,7 +162,10 @@
          x = x.permute(0, 2, 1)  #(batch, filt, time) >> (batch, time, filt)
          self.gru.flatten_parameters()
          x, _ = self.gru(x)
-         output = x[:,-1,:]
+         x = x[:,-1,:]
+         x = self.fc1_gru(x)
+         x = self.fc2_gru(x)
+         output = self.logsoftmax(x)
  
          return output
  
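
For reference, the net effect of the RawNet change is that the GRU's last time step now passes through the two restored fully connected layers before the log-softmax. A small shape walk-through under the defaults in audio_args (gru_node = 1024, nb_fc_node = 1024, nb_classes = 2); the sequence length below is illustrative:

```python
# Shape walk-through of the restored RawNet head (standalone sketch, not the model itself).
import torch
import torch.nn as nn

gru_out = torch.randn(1, 201, 1024)   # (batch, time, gru_node); time length is arbitrary here
x = gru_out[:, -1, :]                 # keep the last time step       -> (1, 1024)
x = nn.Linear(1024, 1024)(x)          # fc1_gru                       -> (1, 1024)
x = nn.Linear(1024, 2)(x)             # fc2_gru                       -> (1, 2)
output = nn.LogSoftmax(dim=1)(x)      # per-class log-probabilities   -> (1, 2)
print(output.shape)                   # torch.Size([1, 2])
```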