import os

import cv2
import numpy as np
import torch
from scipy.stats import norm
from torchvision import transforms

def create_transform(config, training=False):
    """Build a torchvision transform pipeline from a config dict.

    When ``training`` is True, augmentation transforms are appended and the
    corresponding config keys become required.
    """
    required_keys = {
        "image_size",
        "normalization_mean",
        "normalization_std",
    }

    if training:
        required_keys.update({
            "flip_probability",
            "rotation_degrees",
            "brightness_jitter",
            "contrast_jitter",
            "saturation_jitter",
            "hue_jitter",
            "crop_scale_min",
            "crop_scale_max",
        })

    missing_keys = required_keys - set(config.keys())
    if missing_keys:
        raise ValueError(f"Missing required config keys: {missing_keys}")

    # Base pipeline: convert the input (e.g. an RGB ndarray) to PIL and resize.
    transform_list = [
        transforms.ToPILImage(),
        transforms.Resize((config["image_size"], config["image_size"])),
    ]

    # Training-only augmentations.
    if training:
        transform_list.extend([
            transforms.RandomHorizontalFlip(p=config["flip_probability"]),
            transforms.RandomRotation(config["rotation_degrees"]),
            transforms.ColorJitter(
                brightness=config["brightness_jitter"],
                contrast=config["contrast_jitter"],
                saturation=config["saturation_jitter"],
                hue=config["hue_jitter"],
            ),
            transforms.RandomResizedCrop(
                config["image_size"],
                scale=(config["crop_scale_min"], config["crop_scale_max"]),
            ),
        ])

    transform_list.extend([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=config["normalization_mean"],
            std=config["normalization_std"],
        ),
    ])

    return transforms.Compose(transform_list)
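
# Usage sketch (illustrative; the values below are assumptions, not defaults
# shipped with this module). training=True additionally requires the
# augmentation keys validated above:
#
#     train_config = {
#         "image_size": 224,
#         "normalization_mean": [0.485, 0.456, 0.406],  # assumed ImageNet stats
#         "normalization_std": [0.229, 0.224, 0.225],
#         "flip_probability": 0.5,
#         "rotation_degrees": 15,
#         "brightness_jitter": 0.2,
#         "contrast_jitter": 0.2,
#         "saturation_jitter": 0.2,
#         "hue_jitter": 0.1,
#         "crop_scale_min": 0.8,
#         "crop_scale_max": 1.0,
#     }
#     train_transform = create_transform(train_config, training=True)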


def extract_frames(video_path: str, config: dict, transform) -> tuple[torch.Tensor | None, bool]:
    """Extract and process frames from a video using Gaussian sampling.

    Frame indices are drawn from a Gaussian centered on the middle of the
    video, so sampling favors central frames.

    Returns:
        tuple: (frames tensor of shape (max_frames, C, H, W), success flag),
        or (None, False) if no frames could be extracted.
    """
    required_keys = {"max_frames", "sigma"}
    missing_keys = required_keys - set(config.keys())
    if missing_keys:
        raise ValueError(f"Missing required config keys for frame extraction: {missing_keys}")

    frames = []
    success = True

    if not os.path.exists(video_path):
        print(f"File not found: {video_path}")
        return None, False

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Failed to open video: {video_path}")
        return None, False

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total_frames == 0:
        print(f"Video has no frames: {video_path}")
        cap.release()
        return None, False

    # Gaussian distribution over normalized frame positions, centered at the
    # midpoint; config["sigma"] controls how tightly sampling concentrates there.
    x = np.linspace(0, 1, total_frames)
    probabilities = norm.pdf(x, loc=0.5, scale=config["sigma"])
    probabilities /= probabilities.sum()

    # Draw up to max_frames distinct indices and keep them in temporal order.
    frame_indices = np.sort(np.random.choice(
        total_frames,
        size=min(config["max_frames"], total_frames),
        replace=False,
        p=probabilities,
    ))

    for frame_idx in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_idx))
        ret, frame = cap.read()
        if not ret:
            print(f"Failed to read frame {frame_idx} from video: {video_path}")
            success = False
            break
        # OpenCV decodes BGR; convert to RGB before the transform pipeline.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if transform:
            frame = transform(frame)
        else:
            # Without a transform, convert to a tensor so the zero-padding
            # and stacking below still work.
            frame = torch.from_numpy(frame)
        frames.append(frame)

    cap.release()

    if not frames:
        print(f"No frames extracted from video: {video_path}")
        return None, False

    # Zero-pad short clips so the output always contains max_frames frames.
    while len(frames) < config["max_frames"]:
        frames.append(torch.zeros_like(frames[0]))

    return torch.stack(frames), success
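

# Minimal end-to-end sketch (illustrative only): the config values and the
# video path are assumptions for demonstration, not values mandated by this
# module.
if __name__ == "__main__":
    demo_config = {
        "image_size": 224,
        "normalization_mean": [0.485, 0.456, 0.406],  # assumed ImageNet stats
        "normalization_std": [0.229, 0.224, 0.225],
        "max_frames": 16,
        "sigma": 0.2,  # spread of the Gaussian frame-sampling distribution
    }
    demo_transform = create_transform(demo_config, training=False)
    clip, ok = extract_frames("sample.mp4", demo_config, demo_transform)  # hypothetical path
    if ok and clip is not None:
        print(f"Extracted clip of shape {tuple(clip.shape)}")  # (16, 3, 224, 224)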