Mojmir commited on
Commit
71d7efc
·
unverified ·
1 Parent(s): ffe23f3

fixing secure mode

Browse files
Files changed (1) hide show
  1. app.py +7 -15
app.py CHANGED
@@ -15,6 +15,7 @@ class_names = ['Fake', 'Real'] # Fix the incorrect mapping
15
 
16
  # Load the trained model
17
  def load_model(model_path, device):
 
18
  model = resnet18_custom(weights=None)
19
  num_ftrs = model.fc.in_features
20
  model.fc = nn.Linear(num_ftrs, len(class_names)) # Assuming 2 classes: Fake and Real
@@ -36,6 +37,9 @@ def load_secure_model(model):
36
  )
37
  return secure_model
38
 
 
 
 
39
  # Image preprocessing (match with the transforms used during training)
40
  data_transform = transforms.Compose([
41
  transforms.Resize((224, 224)),
@@ -44,17 +48,7 @@ data_transform = transforms.Compose([
44
 
45
  # Prediction function
46
  def predict(image, mode):
47
- # Device configuration
48
- device = torch.device(
49
- "cuda:0" if torch.cuda.is_available() else
50
- "mps" if torch.backends.mps.is_available() else
51
- "cpu"
52
- )
53
-
54
- print(f"Device: {device}")
55
- # Load model
56
- model_path = 'models/deepfake_detection_model.pth'
57
- model = load_model(model_path, device)
58
 
59
  # Apply transformations to the input image
60
  image = Image.open(image).convert('RGB')
@@ -69,11 +63,9 @@ def predict(image, mode):
69
  outputs = model(image)
70
  elif mode == "Secure":
71
  # Secure mode (e.g., running multiple times for higher confidence)
72
- secure_model = load_secure_model(model)
73
  detached_input = image.detach().numpy()
74
- outputs = secure_model(detached_input, fhe="simulate")
75
 
76
- print(outputs)
77
  _, preds = torch.max(outputs, 1)
78
  elapsed_time = time.time() - start_time
79
 
@@ -96,4 +88,4 @@ iface = gr.Interface(
96
  )
97
 
98
  if __name__ == "__main__":
99
- iface.launch(share=True)
 
15
 
16
  # Load the trained model
17
  def load_model(model_path, device):
18
+ print("load_model")
19
  model = resnet18_custom(weights=None)
20
  num_ftrs = model.fc.in_features
21
  model.fc = nn.Linear(num_ftrs, len(class_names)) # Assuming 2 classes: Fake and Real
 
37
  )
38
  return secure_model
39
 
40
+ model = load_model('models/deepfake_detection_model.pth', 'cpu')
41
+ secure_model = load_secure_model(model)
42
+
43
  # Image preprocessing (match with the transforms used during training)
44
  data_transform = transforms.Compose([
45
  transforms.Resize((224, 224)),
 
48
 
49
  # Prediction function
50
  def predict(image, mode):
51
+ device = 'cpu'
 
 
 
 
 
 
 
 
 
 
52
 
53
  # Apply transformations to the input image
54
  image = Image.open(image).convert('RGB')
 
63
  outputs = model(image)
64
  elif mode == "Secure":
65
  # Secure mode (e.g., running multiple times for higher confidence)
 
66
  detached_input = image.detach().numpy()
67
+ outputs = torch.from_numpy(secure_model.forward(detached_input, fhe="simulate"))
68
 
 
69
  _, preds = torch.max(outputs, 1)
70
  elapsed_time = time.time() - start_time
71
 
 
88
  )
89
 
90
  if __name__ == "__main__":
91
+ iface.launch(share=True)