import os
import sys

import cv2
import numpy as np
import torch
import yaml
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

from training.detectors import DETECTOR

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model(model_name, config_path, weights_path):
    """Build the requested DeepfakeBench detector and load its trained weights."""
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    config['model_name'] = model_name

    # Look up the detector class registered under this name and move it to the device.
    model_class = DETECTOR[model_name]
    model = model_class(config).to(device)

    # strict=True requires the checkpoint to match the model's state dict exactly.
    checkpoint = torch.load(weights_path, map_location=device)
    model.load_state_dict(checkpoint, strict=True)
    model.eval()
    return model


def preprocess_video(video_path, output_dir, frame_num=32):
    """Sample frame_num evenly spaced frames from the video and save them as PNG files."""
    os.makedirs(output_dir, exist_ok=True)
    frames_dir = os.path.join(output_dir, "frames")
    os.makedirs(frames_dir, exist_ok=True)

    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_indices = np.linspace(0, total_frames - 1, frame_num, dtype=int)

    frames = []
    for idx in frame_indices:
        # Seek directly to the sampled frame index before reading.
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ret, frame = cap.read()
        if ret:
            frame_path = os.path.join(frames_dir, f"frame_{idx:04d}.png")
            cv2.imwrite(frame_path, frame)
            frames.append(frame_path)

    cap.release()
    return frames


def infer_video(video_path, model, device):
    """Run the detector on sampled frames and return (prediction, average fake probability)."""
    output_dir = "temp_video_frames"
    frames = preprocess_video(video_path, output_dir)

    # Resize to the detector's input size and normalize each channel to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    probs = []
    for frame_path in frames:
        frame = Image.open(frame_path).convert("RGB")
        frame = transform(frame).unsqueeze(0).to(device)

        # Dummy labels: the detectors consume a data_dict, so these keys are
        # filled in even though they are not needed for inference.
        data_dict = {
            "image": frame,
            "label": torch.tensor([0]).to(device),
            "label_spe": torch.tensor([0]).to(device),
        }

        with torch.no_grad():
            pred_dict = model(data_dict, inference=True)

        # Softmax over the class logits; index 1 is the "fake" class.
        logits = pred_dict["cls"]
        prob = torch.softmax(logits, dim=1)[:, 1].item()
        probs.append(prob)

    # Video-level score: mean of the per-frame fake probabilities.
    avg_prob = np.mean(probs)
    prediction = "Fake" if avg_prob > 0.5 else "Real"
    return prediction, avg_prob


def main(video_filename, model_name):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config_path = f"/teamspace/studios/this_studio/DeepfakeBench/training/config/detector/{model_name}.yaml"
    weights_path = f"/teamspace/studios/this_studio/DeepfakeBench/training/weights/{model_name}_best.pth"

    if not os.path.exists(config_path):
        print(f"Error: Config file for model '{model_name}' not found at {config_path}.")
        return
    if not os.path.exists(weights_path):
        print(f"Error: Weights file for model '{model_name}' not found at {weights_path}.")
        return

    model = load_model(model_name, config_path, weights_path)

    video_path = os.path.join(os.getcwd(), video_filename)
    if not os.path.exists(video_path):
        print(f"Error: Video file '{video_filename}' not found in the current directory.")
        return

    prediction, confidence = infer_video(video_path, model, device)
    print(f"Model: {model_name}")
    print(f"Prediction: {prediction} (Confidence: {confidence:.4f})")


if __name__ == "__main__":
    if len(sys.argv) != 3:
        print("Usage: python inference_script.py <video_filename> <model_name>")
        print("Available models: xception, meso4, meso4Inception, efficientnetb4, ucf, etc.")
    else:
        video_filename = sys.argv[1]
        model_name = sys.argv[2]
        main(video_filename, model_name)
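
# Example invocation (assumes the corresponding config and weights exist at the paths
# built in main(), and that a hypothetical my_video.mp4 sits in the working directory):
#   python inference_script.py my_video.mp4 xception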