Live2Diff / live2diff /image_filter.py
leoxing1996
add demo
d16b52d
raw
history blame
1.52 kB
import random
from typing import Optional
import torch
class SimilarImageFilter:
    """Stochastically drop frames that closely resemble the last emitted frame.

    Every incoming tensor is compared, by cosine similarity over its flattened
    values, with the most recent tensor this filter returned. The closer the
    similarity is to 1, the more likely the frame is skipped; a similarity at
    or below ``threshold`` gives a skip probability of 0. No more than
    ``max_skip_frame`` consecutive frames are skipped before one is forced
    through.
    """

    def __init__(self, threshold: float = 0.98, max_skip_frame: float = 10) -> None:
        """Initialise the filter.

        Args:
            threshold: Cosine similarity at or below which a frame is always
                kept. A value ``>= 1`` disables skipping entirely.
            max_skip_frame: Maximum number of consecutive frames that may be
                skipped. ``0`` disables skipping.
        """
        self.threshold = threshold
        # Last tensor returned to the caller; ``None`` until the first call.
        self.prev_tensor = None
        self.cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6)
        self.max_skip_frame = max_skip_frame
        # Number of frames skipped since the last emitted frame.
        self.skip_count = 0

    def __call__(self, x: torch.Tensor) -> Optional[torch.Tensor]:
        """Decide whether ``x`` should be processed.

        Args:
            x: Incoming frame. It is flattened before comparison, so it must
                have the same number of elements as the previously emitted
                frame.

        Returns:
            ``x`` itself when the frame is kept, or ``None`` when it is
            skipped.
        """
        # The very first frame has nothing to be compared against.
        if self.prev_tensor is None:
            return self._emit(x)

        cos_sim = self.cos(self.prev_tensor.reshape(-1), x.reshape(-1)).item()
        # Drawn on every call so the consumption of the global RNG does not
        # depend on the threshold setting.
        sample = random.uniform(0, 1)

        if self.threshold >= 1:
            skip_prob = 0.0
        else:
            # Linear ramp: 0 at cos_sim == threshold, 1 at cos_sim == 1.
            skip_prob = max(0.0, 1 - (1 - cos_sim) / (1 - self.threshold))

        # Skip iff sample < skip_prob, so a probability of 0 can never skip,
        # even if the RNG returns exactly 0.0.
        wants_skip = sample < skip_prob

        # Honour the skip only while fewer than ``max_skip_frame`` frames have
        # been skipped in a row; otherwise force the frame through.
        if wants_skip and self.skip_count < self.max_skip_frame:
            self.skip_count += 1
            return None

        return self._emit(x)

    def _emit(self, x: torch.Tensor) -> torch.Tensor:
        """Record ``x`` as the new reference frame, reset the skip run, return ``x``."""
        self.prev_tensor = x.detach().clone()
        self.skip_count = 0
        return x

    def set_threshold(self, threshold: float) -> None:
        """Replace the similarity threshold used for subsequent frames."""
        self.threshold = threshold

    def set_max_skip_frame(self, max_skip_frame: float) -> None:
        """Replace the limit on consecutive skipped frames."""
        self.max_skip_frame = max_skip_frame