# daquanzhou
# merge github repos and lfs track ckpt/path/safetensors/pt
# 613c9ab
from typing import Callable, Optional, Union
import numpy as np
from torch import Tensor
from comfy.model_base import BaseModel
from .utils_motion import get_sorted_list_via_attr
class ContextFuseMethod:
    """Names of the available fuse methods.

    The individual names are the keys of FUSE_MAPPING, which pairs each one
    with a function that produces per-frame weights.
    """

    # individual method names
    FLAT = "flat"
    PYRAMID = "pyramid"
    RELATIVE = "relative"

    # groupings of the names above; LIST_STATIC additionally contains RELATIVE
    LIST_STATIC = [PYRAMID, RELATIVE, FLAT]
    LIST = [PYRAMID, FLAT]
class ContextType:
    """Namespace holding context type identifier strings."""

    # the only context type defined in this file
    UNIFORM_WINDOW = "uniform window"
class ContextOptions:
    """Settings for one context configuration, plus when it becomes active.

    The "permanent" attributes (context_length, context_stride,
    context_overlap, context_schedule, closed_loop, fuse_method,
    sync_context_to_pe, use_on_equal_length, view_options) are stored as
    given. start_percent, start_t and guarantee_steps are scheduling values;
    start_t starts out as a large placeholder until a real value is assigned
    (in this file ContextOptionsGroup.initialize_timesteps does that).
    """

    def __init__(self, context_length: Optional[int]=None, context_stride: Optional[int]=None,
                 context_overlap: Optional[int]=None, context_schedule: Optional[str]=None,
                 closed_loop: bool=False, fuse_method: str=ContextFuseMethod.FLAT,
                 use_on_equal_length: bool=False, view_options: Optional['ContextOptions']=None,
                 start_percent: float=0.0, guarantee_steps: int=1):
        # permanent settings
        self.context_length = context_length
        self.context_stride = context_stride
        self.context_overlap = context_overlap
        self.context_schedule = context_schedule
        self.closed_loop = closed_loop
        self.fuse_method = fuse_method
        self.sync_context_to_pe = False  # this feature is likely bad and stay unused, so I might remove this
        self.use_on_equal_length = use_on_equal_length
        # keep a private copy so later changes to the caller's object do not leak in
        self.view_options = view_options.clone() if view_options else view_options
        # scheduling
        self.start_percent = float(start_percent)
        self.start_t = 999999999.9  # placeholder until a real value is assigned
        self.guarantee_steps = guarantee_steps
        # temporary vars
        self._step: int = 0

    @property
    def step(self):
        """Current step value stored on this object."""
        return self._step

    @step.setter
    def step(self, value: int):
        # propagate to the nested view options so both always agree
        self._step = value
        if self.view_options:
            self.view_options.step = value

    def clone(self):
        """Return a new ContextOptions carrying the same settings.

        view_options is cloned as well (by __init__). start_t and
        sync_context_to_pe are copied explicitly because __init__ does not
        accept them. The temporary step value is not carried over.
        """
        n = ContextOptions(context_length=self.context_length, context_stride=self.context_stride,
                           context_overlap=self.context_overlap, context_schedule=self.context_schedule,
                           closed_loop=self.closed_loop, fuse_method=self.fuse_method,
                           use_on_equal_length=self.use_on_equal_length, view_options=self.view_options,
                           start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
        n.start_t = self.start_t
        # permanent setting that is not a constructor argument; without this a
        # clone would silently fall back to False
        n.sync_context_to_pe = self.sync_context_to_pe
        return n
class ContextOptionsGroup:
    """Collection of ContextOptions with a notion of the "current" one.

    add/add_to_start re-sort the contexts with get_sorted_list_via_attr on
    "start_percent". The read-only properties at the bottom forward to the
    current context, so a group can be passed where the code only reads
    those attributes (plus ``step``) from a ContextOptions.
    """

    def __init__(self):
        self.contexts: list[ContextOptions] = []
        # None until at least one context has been added
        self._current_context: Optional[ContextOptions] = None
        self._current_used_steps: int = 0
        self._current_index: int = 0
        self.step = 0

    def reset(self):
        """Clear all tracking state and make the first context current again."""
        self._current_context = None
        self._current_used_steps = 0
        self._current_index = 0
        self.step = 0
        self._set_first_as_current()

    @classmethod
    def default(cls):
        """Return a group holding a single default-constructed ContextOptions."""
        new_group = cls()
        new_group.add(ContextOptions())
        return new_group

    def add(self, context: ContextOptions):
        """Append ``context``, re-sort by start_percent, make the first entry current."""
        # add to end of list, then sort
        self.contexts.append(context)
        self.contexts = get_sorted_list_via_attr(self.contexts, "start_percent")
        self._set_first_as_current()

    def add_to_start(self, context: ContextOptions):
        """Insert ``context`` at the front, re-sort by start_percent, make the first entry current."""
        # add to start of list, then sort
        self.contexts.insert(0, context)
        self.contexts = get_sorted_list_via_attr(self.contexts, "start_percent")
        self._set_first_as_current()

    def has_index(self, index: int) -> bool:
        """Return True when ``index`` is a valid non-negative position in ``contexts``."""
        return 0 <= index < len(self.contexts)

    def is_empty(self) -> bool:
        """Return True when no contexts have been added."""
        return not self.contexts

    def clone(self):
        """Return a new group with its own list that references the same context objects.

        Tracking state (used steps, index, step) starts fresh in the copy.
        """
        cloned = ContextOptionsGroup()
        cloned.contexts.extend(self.contexts)
        cloned._set_first_as_current()
        return cloned

    def initialize_timesteps(self, model: BaseModel):
        """Set start_t on every context from its start_percent via model.model_sampling.percent_to_sigma."""
        for context in self.contexts:
            context.start_t = model.model_sampling.percent_to_sigma(context.start_percent)

    def prepare_current_context(self, t: Tensor):
        """Advance to a later context if one has become active at ``t``, then count this step.

        Only the first element of ``t`` is consulted.
        """
        curr_t = t[0]
        # if met guaranteed steps, look for next context in case need to switch
        if self._current_used_steps >= self._current_context.guarantee_steps:
            # if has next index, loop through and see if need to switch
            if self.has_index(self._current_index + 1):
                for i in range(self._current_index + 1, len(self.contexts)):
                    eval_c = self.contexts[i]
                    # check if start_t is greater or equal to curr_t
                    # NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling
                    if eval_c.start_t >= curr_t:
                        self._current_index = i
                        self._current_context = eval_c
                        self._current_used_steps = 0
                        # if guarantee_steps greater than zero, stop searching for other keyframes
                        if self._current_context.guarantee_steps > 0:
                            break
                    # if eval_c is outside the percent range, stop looking further
                    else:
                        break
        # update steps current context is used
        self._current_used_steps += 1

    def _set_first_as_current(self):
        # leaves _current_context untouched when the group is empty
        if len(self.contexts) > 0:
            self._current_context = self.contexts[0]

    # properties shadow those of ContextOptions
    @property
    def context_length(self):
        return self._current_context.context_length

    @property
    def context_overlap(self):
        return self._current_context.context_overlap

    @property
    def context_stride(self):
        return self._current_context.context_stride

    @property
    def context_schedule(self):
        return self._current_context.context_schedule

    @property
    def closed_loop(self):
        return self._current_context.closed_loop

    @property
    def fuse_method(self):
        return self._current_context.fuse_method

    @property
    def use_on_equal_length(self):
        return self._current_context.use_on_equal_length

    @property
    def view_options(self):
        return self._current_context.view_options
class ContextSchedules:
    """Names of the context schedules.

    The non-legacy names are the keys of CONTEXT_MAPPING, which selects the
    function used to build context windows.
    """

    # current schedule names
    UNIFORM_LOOPED = "looped_uniform"
    UNIFORM_STANDARD = "standard_uniform"
    STATIC_STANDARD = "standard_static"
    BATCHED = "batched"
    VIEW_AS_CONTEXT = "view_as_context"

    # legacy name(s), kept separately from the current ones
    LEGACY_UNIFORM_LOOPED = "uniform"
    LEGACY_UNIFORM_SCHEDULE_LIST = [LEGACY_UNIFORM_LOOPED]
# from https://github.com/neggles/animatediff-cli/blob/main/src/animatediff/pipelines/context.py
def create_windows_uniform_looped(num_frames: int, opts: Union[ContextOptionsGroup, ContextOptions]):
    """Build uniform context windows whose frame indexes wrap around modulo ``num_frames``.

    Windows are produced for frame steps 1, 2, 4, ... (as many levels as
    ``opts.context_stride`` allows, capped by how many doublings of
    ``opts.context_length`` fit into ``num_frames``). The starting offset is
    derived from ``ordered_halving(opts.step)``. When ``opts.closed_loop`` is
    false the start positions stop ``opts.context_overlap`` frames earlier.

    Returns a list of windows, each a list of frame indexes.
    """
    # fewer frames than one window holds: a single window covers everything
    if num_frames < opts.context_length:
        return [list(range(num_frames))]

    max_levels = int(np.ceil(np.log2(num_frames / opts.context_length))) + 1
    context_stride = min(opts.context_stride, max_levels)

    windows = []
    # obtain uniform windows as normal, looping and all
    for context_step in 1 << np.arange(context_stride):
        fraction = ordered_halving(opts.step)
        pad = int(round(num_frames * fraction))
        # number of frames spanned by one window at this step size
        span = opts.context_length * context_step
        first_start = int(fraction * context_step) + pad
        if opts.closed_loop:
            stop_before = num_frames + pad
        else:
            stop_before = num_frames + pad - opts.context_overlap
        for j in range(first_start, stop_before, span - opts.context_overlap):
            windows.append([e % num_frames for e in range(j, j + span, context_step)])
    return windows
def create_windows_uniform_standard(num_frames: int, opts: Union[ContextOptionsGroup, ContextOptions]):
    """Build uniform context windows that never wrap past the last frame.

    Windows are first generated as in the looped variant (without the
    closed-loop option). Any window that rolls over is then shifted so it
    ends on the last frame, and windows identical to an earlier one are
    removed.

    Returns a list of windows, each a list of frame indexes.
    """
    # unlike looped, uniform_straight does NOT allow windows that loop back to the beginning;
    # instead, they get shifted to the corresponding end of the frames.
    # in the case that a window (shifted or not) is identical to the previous one, it gets skipped.
    if num_frames <= opts.context_length:
        return [list(range(num_frames))]

    max_levels = int(np.ceil(np.log2(num_frames / opts.context_length))) + 1
    context_stride = min(opts.context_stride, max_levels)

    windows = []
    # first, obtain uniform windows as normal, looping and all
    for context_step in 1 << np.arange(context_stride):
        fraction = ordered_halving(opts.step)
        pad = int(round(num_frames * fraction))
        # number of frames spanned by one window at this step size
        span = opts.context_length * context_step
        for j in range(
            int(fraction * context_step) + pad,
            num_frames + pad - opts.context_overlap,
            span - opts.context_overlap,
        ):
            windows.append([e % num_frames for e in range(j, j + span, context_step)])

    # now that windows are created, shift any windows that loop, and note duplicate windows
    duplicate_positions = []
    pos = 0
    while pos < len(windows):
        current = windows[pos]
        # if the window rolls over itself, it needs to be shifted (done in place)
        rolled, roll_at = does_window_roll_over(current, num_frames)
        if rolled:
            # roll_val might not be 0 for windows of higher strides
            roll_val = current[roll_at]
            shift_window_to_end(current, num_frames=num_frames)
            # check if next window (cyclical) is missing roll_val
            following = windows[(pos + 1) % len(windows)]
            if roll_val not in following:
                # need to insert new window here - just insert window starting at roll_val
                windows.insert(pos + 1, list(range(roll_val, roll_val + opts.context_length)))
        # remember this window for deletion if an earlier window is identical
        if any(current == windows[earlier] for earlier in range(pos)):
            duplicate_positions.append(pos)
        pos += 1

    # delete from the back so the remaining recorded positions stay valid
    for doomed in reversed(duplicate_positions):
        del windows[doomed]
    return windows
def create_windows_static_standard(num_frames: int, opts: Union[ContextOptionsGroup, ContextOptions]):
    """Build fixed windows of ``opts.context_length`` frames overlapping by ``opts.context_overlap``.

    The result does not depend on the sampling step. The last window is moved
    back so that it ends exactly on the final frame while keeping full length.

    Returns a list of windows, each a list of frame indexes.
    """
    length = opts.context_length
    # everything fits into one window
    if num_frames <= length:
        return [list(range(num_frames))]

    windows = []
    # always return the same set of windows
    for start in range(0, num_frames, length - opts.context_overlap):
        stop = start + length
        if stop < num_frames:
            windows.append(list(range(start, stop)))
            continue
        # reached or passed the end of frames: use the last `length` frames instead and finish
        windows.append(list(range(num_frames - length, num_frames)))
        break
    return windows
def create_windows_batched(num_frames: int, opts: Union[ContextOptionsGroup, ContextOptions]):
    """Cut the frames into consecutive, non-overlapping windows of ``opts.context_length``.

    The result does not depend on the sampling step. The final window is
    shorter when ``num_frames`` is not a multiple of ``opts.context_length``.

    Returns a list of windows, each a list of frame indexes.
    """
    size = opts.context_length
    # everything fits into one window
    if num_frames <= size:
        return [list(range(num_frames))]
    # no overlap, just cut up based on context_length
    return [
        list(range(start, min(start + size, num_frames)))
        for start in range(0, num_frames, size)
    ]
def create_windows_default(num_frames: int, opts: Union[ContextOptionsGroup, ContextOptions]):
    """Return one window containing every frame index; ``opts`` is not consulted."""
    every_frame = list(range(num_frames))
    return [every_frame]
def get_context_windows(num_frames: int, opts: Union[ContextOptionsGroup, ContextOptions]):
    """Build the context windows for ``num_frames`` using the schedule named by ``opts.context_schedule``.

    The window-building function is looked up in CONTEXT_MAPPING.

    Raises:
        ValueError: if no function is registered for the schedule name.
    """
    builder = CONTEXT_MAPPING.get(opts.context_schedule)
    if builder:
        return builder(num_frames, opts)
    raise ValueError(f"Unknown context_schedule '{opts.context_schedule}'.")
# Dispatch table read by get_context_windows: schedule name -> function that builds the windows.
# Every function here is called as func(num_frames, opts).
# ContextSchedules.LEGACY_UNIFORM_LOOPED is not a key of this table.
CONTEXT_MAPPING = {
    ContextSchedules.UNIFORM_LOOPED: create_windows_uniform_looped,
    ContextSchedules.UNIFORM_STANDARD: create_windows_uniform_standard,
    ContextSchedules.STATIC_STANDARD: create_windows_static_standard,
    ContextSchedules.BATCHED: create_windows_batched,
    ContextSchedules.VIEW_AS_CONTEXT: create_windows_default, # just return all to allow Views to do all the work
}
def get_context_weights(num_frames: int, fuse_method: str):
    """Return the weights for ``num_frames`` positions under ``fuse_method``.

    The weight-building function is looked up in FUSE_MAPPING.

    Raises:
        ValueError: if no function is registered for ``fuse_method``.
    """
    builder = FUSE_MAPPING.get(fuse_method)
    if builder:
        return builder(num_frames)
    raise ValueError(f"Unknown fuse_method '{fuse_method}'.")
def create_weights_flat(length: int, **kwargs) -> list[float]:
    """Return ``length`` weights that are all 1.0; extra keyword arguments are ignored."""
    # weight is the same for all
    return [1.0 for _ in range(length)]
def create_weights_pyramid(length: int, **kwargs) -> list[float]:
    """Return ``length`` weights that rise from 1 towards the middle and fall back to 1.

    An even length has two equal peak values in the middle; an odd length has
    a single peak. Extra keyword arguments are ignored.
    """
    # weight is based on the distance away from the edge of the context window;
    # based on weighted average concept in FreeNoise paper
    half = length // 2
    rising = list(range(1, half + 1))
    falling = rising[::-1]
    if length % 2 == 0:
        return rising + falling
    # odd length: one extra peak value sits between the two halves
    return rising + [half + 1] + falling
# Dispatch table read by get_context_weights: fuse method name -> function that builds the weights.
# Every function here is called as func(num_frames).
# NOTE: RELATIVE maps to the same function as PYRAMID.
FUSE_MAPPING = {
    ContextFuseMethod.FLAT: create_weights_flat,
    ContextFuseMethod.PYRAMID: create_weights_pyramid,
    ContextFuseMethod.RELATIVE: create_weights_pyramid,
}
# Returns fraction that has denominator that is a power of 2
def ordered_halving(val):
    """Mirror the 64-bit binary form of ``val`` and return it divided by 2**64.

    For example 1 -> 0.5, 2 -> 0.25, 3 -> 0.75. For 0 <= val < 2**64 the
    result lies in [0, 1).
    """
    # binary digits of val, left-padded with zeros to 64 characters
    bits = format(val, "064b")
    # reading the digits back to front mirrors the bit pattern, padding included
    mirrored = int("".join(reversed(bits)), 2)
    # scale by 2**64 (18446744073709551616) to obtain the fraction
    return mirrored / 2 ** 64
def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
    """Return, in ascending order, the frame indexes in ``range(num_frames)`` that no window contains.

    Values inside the windows that fall outside ``range(num_frames)`` are
    ignored, as are repeated values.
    """
    # gather every covered index once; set membership keeps the filter linear
    # instead of calling list.remove for each value
    covered = set().union(*windows)
    return [index for index in range(num_frames) if index not in covered]
def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
    """Report whether ``window`` wraps around, comparing its values modulo ``num_frames``.

    Returns ``(True, i)`` where ``i`` is the position of the first value that
    is smaller than the value before it, or ``(False, -1)`` when the values
    never decrease.
    """
    previous = -1
    for position, raw in enumerate(window):
        current = raw % num_frames
        if current >= previous:
            previous = current
            continue
        # first descent found: this is where the window wraps
        return True, position
    return False, -1
def shift_window_to_start(window: list[int], num_frames: int):
    """Shift ``window`` in place so its first value becomes 0, wrapping modulo ``num_frames``.

    Nothing is returned; the list passed in is modified.
    """
    anchor = window[0]
    for position, frame in enumerate(window):
        # 1) subtract the first value to make the frame relative to the start of all frames
        # 2) add num_frames and take the modulus to get the adjusted value
        window[position] = (frame - anchor + num_frames) % num_frames
def shift_window_to_end(window: list[int], num_frames: int):
    """Shift ``window`` in place so its last value becomes ``num_frames - 1``.

    Nothing is returned; the list passed in is modified.
    """
    # 1) shift window to start
    shift_window_to_start(window, num_frames)
    # distance between the window's last value and the last frame index
    offset = num_frames - window[-1] - 1
    for position, frame in enumerate(window):
        # 2) add the offset to each value to slide the window to the end
        window[position] = frame + offset