from typing import Callable, Optional, Union

import numpy as np
from torch import Tensor

from comfy.model_base import BaseModel

from .utils_motion import get_sorted_list_via_attr


class ContextFuseMethod:
    FLAT = "flat"
    PYRAMID = "pyramid"
    RELATIVE = "relative"

    LIST = [PYRAMID, FLAT]
    LIST_STATIC = [PYRAMID, RELATIVE, FLAT]


class ContextType:
    UNIFORM_WINDOW = "uniform window"


class ContextOptions:
    def __init__(self, context_length: int=None, context_stride: int=None, context_overlap: int=None,
                 context_schedule: str=None, closed_loop: bool=False, fuse_method: str=ContextFuseMethod.FLAT,
                 use_on_equal_length: bool=False, view_options: 'ContextOptions'=None,
                 start_percent=0.0, guarantee_steps=1):
        # permanent settings
        self.context_length = context_length
        self.context_stride = context_stride
        self.context_overlap = context_overlap
        self.context_schedule = context_schedule
        self.closed_loop = closed_loop
        self.fuse_method = fuse_method
        self.sync_context_to_pe = False  # this feature is likely bad and should stay unused, so I might remove it
        self.use_on_equal_length = use_on_equal_length
        self.view_options = view_options.clone() if view_options else view_options
        # scheduling
        self.start_percent = float(start_percent)
        self.start_t = 999999999.9
        self.guarantee_steps = guarantee_steps
        # temporary vars
        self._step: int = 0

    @property
    def step(self):
        return self._step

    @step.setter
    def step(self, value: int):
        self._step = value
        if self.view_options:
            self.view_options.step = value

    def clone(self):
        n = ContextOptions(context_length=self.context_length, context_stride=self.context_stride,
                           context_overlap=self.context_overlap, context_schedule=self.context_schedule,
                           closed_loop=self.closed_loop, fuse_method=self.fuse_method,
                           use_on_equal_length=self.use_on_equal_length, view_options=self.view_options,
                           start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
        n.start_t = self.start_t
        return n


class ContextOptionsGroup:
    def __init__(self):
        self.contexts: list[ContextOptions] = []
        self._current_context: ContextOptions = None
        self._current_used_steps: int = 0
        self._current_index: int = 0
        self.step = 0

    def reset(self):
        self._current_context = None
        self._current_used_steps = 0
        self._current_index = 0
        self.step = 0
        self._set_first_as_current()

    @classmethod
    def default(cls):
        def_context = ContextOptions()
        new_group = ContextOptionsGroup()
        new_group.add(def_context)
        return new_group

    def add(self, context: ContextOptions):
        # add to end of list, then sort
        self.contexts.append(context)
        self.contexts = get_sorted_list_via_attr(self.contexts, "start_percent")
        self._set_first_as_current()

    def add_to_start(self, context: ContextOptions):
        # add to start of list, then sort
        self.contexts.insert(0, context)
        self.contexts = get_sorted_list_via_attr(self.contexts, "start_percent")
        self._set_first_as_current()

    def has_index(self, index: int) -> bool:
        return index >= 0 and index < len(self.contexts)

    def is_empty(self) -> bool:
        return len(self.contexts) == 0

    def clone(self):
        cloned = ContextOptionsGroup()
        for context in self.contexts:
            cloned.contexts.append(context)
        cloned._set_first_as_current()
        return cloned

    def initialize_timesteps(self, model: BaseModel):
        for context in self.contexts:
            context.start_t = model.model_sampling.percent_to_sigma(context.start_percent)

    def prepare_current_context(self, t: Tensor):
        curr_t: float = t[0]
        prev_index = self._current_index
        # if met guaranteed steps, look for next context in case need to switch
        if self._current_used_steps >= self._current_context.guarantee_steps:
            # if has next index, loop through and see if need to switch
            if self.has_index(self._current_index+1):
                for i in range(self._current_index+1, len(self.contexts)):
                    eval_c = self.contexts[i]
                    # check if start_t is greater or equal to curr_t
                    # NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling
                    if eval_c.start_t >= curr_t:
                        self._current_index = i
                        self._current_context = eval_c
                        self._current_used_steps = 0
                        # if guarantee_steps greater than zero, stop searching for other keyframes
                        if self._current_context.guarantee_steps > 0:
                            break
                    # if eval_c is outside the percent range, stop looking further
                    else:
                        break
        # update steps current context is used
        self._current_used_steps += 1

    def _set_first_as_current(self):
        if len(self.contexts) > 0:
            self._current_context = self.contexts[0]

    # properties shadow those of ContextOptions
    @property
    def context_length(self):
        return self._current_context.context_length
    @property
    def context_overlap(self):
        return self._current_context.context_overlap
    @property
    def context_stride(self):
        return self._current_context.context_stride
    @property
    def context_schedule(self):
        return self._current_context.context_schedule
    @property
    def closed_loop(self):
        return self._current_context.closed_loop
    @property
    def fuse_method(self):
        return self._current_context.fuse_method
    @property
    def use_on_equal_length(self):
        return self._current_context.use_on_equal_length
    @property
    def view_options(self):
        return self._current_context.view_options


class ContextSchedules:
    UNIFORM_LOOPED = "looped_uniform"
    UNIFORM_STANDARD = "standard_uniform"
    STATIC_STANDARD = "standard_static"
    BATCHED = "batched"
    VIEW_AS_CONTEXT = "view_as_context"

    LEGACY_UNIFORM_LOOPED = "uniform"
    LEGACY_UNIFORM_SCHEDULE_LIST = [LEGACY_UNIFORM_LOOPED]


# from https://github.com/neggles/animatediff-cli/blob/main/src/animatediff/pipelines/context.py
def create_windows_uniform_looped(num_frames: int, opts: Union[ContextOptionsGroup, ContextOptions]):
    windows = []
    if num_frames < opts.context_length:
        windows.append(list(range(num_frames)))
        return windows

    context_stride = min(opts.context_stride, int(np.ceil(np.log2(num_frames / opts.context_length))) + 1)
    # obtain uniform windows as normal, looping and all
    for context_step in 1 << np.arange(context_stride):
        pad = int(round(num_frames * ordered_halving(opts.step)))
        for j in range(
            int(ordered_halving(opts.step) * context_step) + pad,
            num_frames + pad + (0 if opts.closed_loop else -opts.context_overlap),
            (opts.context_length * context_step - opts.context_overlap),
        ):
            windows.append([e % num_frames for e in range(j, j + opts.context_length * context_step, context_step)])

    return windows
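# Example for create_windows_uniform_looped (illustrative, assumed values):
# num_frames=8, context_length=4, context_overlap=2, context_stride=1,
# closed_loop=True, opts.step=0 yields
# [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 0, 1]],
# where the last window wraps back around to frame 0.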


def create_windows_uniform_standard(num_frames: int, opts: Union[ContextOptionsGroup, ContextOptions]):
    # unlike looped, uniform_standard does NOT allow windows that loop back to the beginning;
    # instead, they get shifted to the corresponding end of the frames.
    # in the case that a window (shifted or not) is identical to the previous one, it gets skipped.
    windows = []
    if num_frames <= opts.context_length:
        windows.append(list(range(num_frames)))
        return windows

    context_stride = min(opts.context_stride, int(np.ceil(np.log2(num_frames / opts.context_length))) + 1)
    # first, obtain uniform windows as normal, looping and all
    for context_step in 1 << np.arange(context_stride):
        pad = int(round(num_frames * ordered_halving(opts.step)))
        for j in range(
            int(ordered_halving(opts.step) * context_step) + pad,
            num_frames + pad + (-opts.context_overlap),
            (opts.context_length * context_step - opts.context_overlap),
        ):
            windows.append([e % num_frames for e in range(j, j + opts.context_length * context_step, context_step)])

    # now that windows are created, shift any windows that loop, and delete duplicate windows
    delete_idxs = []
    win_i = 0
    while win_i < len(windows):
        # if window rolls over itself, need to shift it
        is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames)
        if is_roll:
            roll_val = windows[win_i][roll_idx]  # roll_val might not be 0 for windows of higher strides
            shift_window_to_end(windows[win_i], num_frames=num_frames)
            # check if next window (cyclical) is missing roll_val
            if roll_val not in windows[(win_i+1) % len(windows)]:
                # need to insert new window here - just insert window starting at roll_val
                windows.insert(win_i+1, list(range(roll_val, roll_val + opts.context_length)))
        # delete window if it's not unique
        for pre_i in range(0, win_i):
            if windows[win_i] == windows[pre_i]:
                delete_idxs.append(win_i)
                break
        win_i += 1

    # reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation
    delete_idxs.reverse()
    for i in delete_idxs:
        windows.pop(i)

    return windows


def create_windows_static_standard(num_frames: int, opts: Union[ContextOptionsGroup, ContextOptions]):
    windows = []
    if num_frames <= opts.context_length:
        windows.append(list(range(num_frames)))
        return windows
    # always return the same set of windows
    delta = opts.context_length - opts.context_overlap
    for start_idx in range(0, num_frames, delta):
        # if past the end of frames, move start_idx back to allow same context_length
        ending = start_idx + opts.context_length
        if ending >= num_frames:
            final_delta = ending - num_frames
            final_start_idx = start_idx - final_delta
            windows.append(list(range(final_start_idx, final_start_idx + opts.context_length)))
            break
        windows.append(list(range(start_idx, start_idx + opts.context_length)))
    return windows
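# Example for create_windows_static_standard (illustrative, assumed values):
# num_frames=10, context_length=4, context_overlap=2 yields
# [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]];
# the final window is pulled back so it still spans a full context_length
# without running past the last frame.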


def create_windows_batched(num_frames: int, opts: Union[ContextOptionsGroup, ContextOptions]):
    windows = []
    if num_frames <= opts.context_length:
        windows.append(list(range(num_frames)))
        return windows
    # always return the same set of windows;
    # no overlap, just cut up based on context_length;
    # last window size will be different if num_frames % opts.context_length != 0
    for start_idx in range(0, num_frames, opts.context_length):
        windows.append(list(range(start_idx, min(start_idx + opts.context_length, num_frames))))
    return windows
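# Example for create_windows_batched (illustrative, assumed values):
# num_frames=10, context_length=4 yields [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]];
# no overlap, and the final window is simply shorter.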


def create_windows_default(num_frames: int, opts: Union[ContextOptionsGroup, ContextOptions]):
    return [list(range(num_frames))]


def get_context_windows(num_frames: int, opts: Union[ContextOptionsGroup, ContextOptions]):
    context_func = CONTEXT_MAPPING.get(opts.context_schedule, None)
    if not context_func:
        raise ValueError(f"Unknown context_schedule '{opts.context_schedule}'.")
    return context_func(num_frames, opts)


CONTEXT_MAPPING = {
    ContextSchedules.UNIFORM_LOOPED: create_windows_uniform_looped,
    ContextSchedules.UNIFORM_STANDARD: create_windows_uniform_standard,
    ContextSchedules.STATIC_STANDARD: create_windows_static_standard,
    ContextSchedules.BATCHED: create_windows_batched,
    ContextSchedules.VIEW_AS_CONTEXT: create_windows_default,  # just return all to allow Views to do all the work
}


def get_context_weights(num_frames: int, fuse_method: str):
    weights_func = FUSE_MAPPING.get(fuse_method, None)
    if not weights_func:
        raise ValueError(f"Unknown fuse_method '{fuse_method}'.")
    return weights_func(num_frames)


def create_weights_flat(length: int, **kwargs) -> list[float]:
    # weight is the same for all
    return [1.0] * length


def create_weights_pyramid(length: int, **kwargs) -> list[float]:
    # weight is based on the distance away from the edge of the context window;
    # based on weighted average concept in FreeNoise paper
    if length % 2 == 0:
        max_weight = length // 2
        weight_sequence = list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1))
    else:
        max_weight = (length + 1) // 2
        weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
    return weight_sequence
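# Example for create_weights_pyramid (illustrative):
# length=6 -> [1, 2, 3, 3, 2, 1]; length=5 -> [1, 2, 3, 2, 1].
# Frames near the center of a window get more weight when overlapping windows are fused.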


FUSE_MAPPING = {
    ContextFuseMethod.FLAT: create_weights_flat,
    ContextFuseMethod.PYRAMID: create_weights_pyramid,
    ContextFuseMethod.RELATIVE: create_weights_pyramid,
}


# Returns a fraction whose denominator is a power of 2
def ordered_halving(val):
    # get binary value, padded with 0s for 64 bits
    bin_str = f"{val:064b}"
    # flip binary value, padding included
    bin_flip = bin_str[::-1]
    # convert binary to int
    as_int = int(bin_flip, 2)
    # divide by 1 << 64, equivalent to 2**64, or 18446744073709551616,
    # i.e. binary 1 followed by 64 zeros
    return as_int / (1 << 64)
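# Example for ordered_halving (illustrative):
# ordered_halving(0) == 0.0, ordered_halving(1) == 0.5, ordered_halving(2) == 0.25,
# ordered_halving(3) == 0.75, ordered_halving(4) == 0.125;
# successive step values jump around [0, 1) instead of increasing monotonically.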


def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
    all_indexes = list(range(num_frames))
    for w in windows:
        for val in w:
            try:
                all_indexes.remove(val)
            except ValueError:
                pass
    return all_indexes


def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
    prev_val = -1
    for i, val in enumerate(window):
        val = val % num_frames
        if val < prev_val:
            return True, i
        prev_val = val
    return False, -1
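# Example for does_window_roll_over (illustrative):
# does_window_roll_over([6, 7, 0, 1], num_frames=8) -> (True, 2),
# since the value at index 2 wraps back below its predecessor.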


def shift_window_to_start(window: list[int], num_frames: int):
    start_val = window[0]
    for i in range(len(window)):
        # 1) subtract each element by start_val to move vals relative to the start of all frames
        # 2) add num_frames and take modulus to get adjusted vals
        window[i] = ((window[i] - start_val) + num_frames) % num_frames


def shift_window_to_end(window: list[int], num_frames: int):
    # 1) shift window to start
    shift_window_to_start(window, num_frames)
    end_val = window[-1]
    end_delta = num_frames - end_val - 1
    for i in range(len(window)):
        # 2) add end_delta to each val to slide windows to end
        window[i] = window[i] + end_delta
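

# Example for the shift helpers (illustrative, in-place mutation):
# with num_frames=16, shift_window_to_start([14, 15, 0, 1], 16) rewrites the window
# as [0, 1, 2, 3], while shift_window_to_end([14, 15, 0, 1], 16) rewrites it as
# [12, 13, 14, 15].


# Minimal usage sketch (not part of the original module; values are arbitrary examples):
# build a single ContextOptions, generate its windows for a frame count, then get the
# per-frame fuse weights for one window.
#
#     opts = ContextOptions(context_length=16, context_stride=1, context_overlap=4,
#                           context_schedule=ContextSchedules.STATIC_STANDARD)
#     windows = get_context_windows(num_frames=32, opts=opts)
#     # -> three overlapping 16-frame windows: [0..15], [12..27], [16..31]
#     weights = get_context_weights(opts.context_length, opts.fuse_method)
#     # -> [1.0] * 16, since the default fuse_method is ContextFuseMethod.FLAT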