File size: 4,099 Bytes
caa56d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import os
import cv2
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from training.detectors import DETECTOR
import yaml
# Module-level torch device: CUDA when it is available, otherwise the CPU.
# load_model() moves the model to this device and maps checkpoint tensors onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load the model
def load_model(model_name, config_path, weights_path):
    """Build the detector registered under ``model_name`` and load its weights.

    Args:
        model_name: Key looked up in ``DETECTOR``; it is also stored in the
            parsed config under ``"model_name"`` before the model is built.
        config_path: Path of the YAML config, parsed with ``yaml.safe_load``.
        weights_path: Path of the checkpoint read with ``torch.load``; the
            loaded object is handed directly to ``load_state_dict`` with
            ``strict=True``.

    Returns:
        The model instance, placed on the module-level ``device`` and
        switched to evaluation mode.
    """
    with open(config_path, 'r') as config_file:
        settings = yaml.safe_load(config_file)
    settings['model_name'] = model_name

    # Instantiate the registered detector class with the parsed settings.
    detector = DETECTOR[model_name](settings)
    detector = detector.to(device)

    # map_location keeps checkpoint tensors on the same device as the model.
    state = torch.load(weights_path, map_location=device)
    detector.load_state_dict(state, strict=True)

    detector.eval()
    return detector
# preprocess a single video
def preprocess_video(video_path, output_dir, frame_num=32):
    """Sample frames evenly across a video and save them as PNG files.

    Args:
        video_path: Path of the video file to read.
        output_dir: Directory that receives a ``frames`` sub-directory; both
            are created when missing.
        frame_num: Maximum number of frames to sample, spread evenly between
            the first and the last frame of the video.

    Returns:
        Paths of the PNG files that were written, in ascending frame order.
        The list is empty when the video cannot be opened or reports no
        frames. Videos shorter than ``frame_num`` yield one path per
        distinct frame instead of repeated paths.
    """
    frames_dir = os.path.join(output_dir, "frames")
    # makedirs creates output_dir as well, since it is a parent of frames_dir.
    os.makedirs(frames_dir, exist_ok=True)

    frames = []
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return frames

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            # Without a positive frame count there is nothing to index into;
            # linspace would otherwise produce negative positions.
            return frames

        # linspace repeats indices when the video has fewer frames than
        # frame_num; np.unique keeps each frame once so no frame is written
        # (and later scored) multiple times.
        frame_indices = np.unique(
            np.linspace(0, total_frames - 1, frame_num, dtype=int)
        )

        # extract frames
        for idx in frame_indices:
            idx = int(idx)
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if not ret:
                continue
            frame_path = os.path.join(frames_dir, f"frame_{idx:04d}.png")
            # Only report frames that actually reached the disk.
            if cv2.imwrite(frame_path, frame):
                frames.append(frame_path)
    finally:
        # Release the capture even if reading or writing raised.
        cap.release()
    return frames
# inference on a single video
def infer_video(video_path, model, device):
    """Classify a video by averaging per-frame predictions of ``model``.

    Frames are extracted with ``preprocess_video`` into the
    ``temp_video_frames`` directory, resized to 256x256, normalized with
    mean and std 0.5 per channel, and passed to the model one at a time.

    Args:
        video_path: Path of the video file to analyse.
        model: Callable invoked as ``model(data_dict, inference=True)`` that
            returns a mapping whose ``"cls"`` entry holds the logits.
        device: Torch device the input tensors are moved to.

    Returns:
        A ``(prediction, avg_prob)`` tuple: ``avg_prob`` is the mean over
        frames of the softmax value at class index 1, and ``prediction`` is
        ``"Fake"`` when ``avg_prob > 0.5``, otherwise ``"Real"``.

    Raises:
        ValueError: If no frame could be extracted from the video. Averaging
            an empty list would give NaN, which would be reported as "Real".
    """
    output_dir = "temp_video_frames"
    frames = preprocess_video(video_path, output_dir)
    if not frames:
        raise ValueError(
            f"No frames could be extracted from video '{video_path}'."
        )

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    probs = []
    # One no_grad context for the whole loop instead of re-entering per frame.
    with torch.no_grad():
        for frame_path in frames:
            # The context manager closes the image file once it is converted.
            with Image.open(frame_path) as image:
                frame = transform(image.convert("RGB")).unsqueeze(0).to(device)
            data_dict = {
                "image": frame,
                "label": torch.tensor([0]).to(device),  # Dummy label
                "label_spe": torch.tensor([0]).to(device),  # Dummy specific label
            }
            pred_dict = model(data_dict, inference=True)
            logits = pred_dict["cls"]  # Shape: [batch_size, num_classes]
            # Class index 1 is treated as the "fake" class.
            prob = torch.softmax(logits, dim=1)[:, 1].item()
            probs.append(prob)

    avg_prob = np.mean(probs)
    prediction = "Fake" if avg_prob > 0.5 else "Real"
    return prediction, avg_prob
# Default location of the detector configs and weights used by main().
_DEFAULT_TRAINING_ROOT = "/teamspace/studios/this_studio/DeepfakeBench/training"


# main function for terminal-based inference
def main(video_filename, model_name,
         training_root=_DEFAULT_TRAINING_ROOT):
    """Run one detector on one video and print the verdict.

    All file checks run before the model is loaded, so a missing video is
    reported without first paying for model construction and weight loading.

    Args:
        video_filename: Video path, resolved relative to the current working
            directory.
        model_name: Detector name; selects
            ``<training_root>/config/detector/<model_name>.yaml`` and
            ``<training_root>/weights/<model_name>_best.pth``.
        training_root: Directory holding the ``config`` and ``weights``
            trees. Defaults to the original hard-coded location.

    Returns:
        None. Results and errors are printed to standard output.
    """
    config_path = f"{training_root}/config/detector/{model_name}.yaml"
    weights_path = f"{training_root}/weights/{model_name}_best.pth"
    video_path = os.path.join(os.getcwd(), video_filename)

    if not os.path.exists(config_path):
        print(f"Error: Config file for model '{model_name}' not found at {config_path}.")
        return
    if not os.path.exists(weights_path):
        print(f"Error: Weights file for model '{model_name}' not found at {weights_path}.")
        return
    if not os.path.exists(video_path):
        print(f"Error: Video file '{video_filename}' not found in the current directory.")
        return

    # Only touch torch and load the model once every input is known to exist.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(model_name, config_path, weights_path)
    prediction, confidence = infer_video(video_path, model, device)

    print(f"Model: {model_name}")
    print(f"Prediction: {prediction} (Confidence: {confidence:.4f})")
# Command-line entry point: expects exactly a video filename and a model name.
if __name__ == "__main__":
    import sys

    if len(sys.argv) != 3:
        print("Usage: python inference_script.py <video_filename> <model_name>")
        print("Available models: xception, meso4, meso4Inception, efficientnetb4, ucf, etc.")
        # A non-zero status lets calling shells and scripts detect the misuse.
        sys.exit(1)

    video_filename = sys.argv[1]
    model_name = sys.argv[2]
    main(video_filename, model_name)