Skip to content

Commit

Permalink
axes_factor now automatically calculated based on intended usage, rem…
Browse files Browse the repository at this point in the history
…oved axes_factor-related vars from the AnimateDiff Loader node
  • Loading branch information
Kosinkadink committed Sep 3, 2023
1 parent f1c9367 commit a8fbc6e
Showing 1 changed file with 28 additions and 47 deletions.
75 changes: 28 additions & 47 deletions animatediff/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@
MM_INJECTED_ATTR = "_mm_injected"

class InjectionParams:
def __init__(self, axes_factor: int, video_length: int, ignore_factor_on_trunc: bool, unlimited_area_hack: bool) -> None:
self.axes_factor = axes_factor
def __init__(self, video_length: int, unlimited_area_hack: bool) -> None:
self.video_length = video_length
self.ignore_factor_on_trunc = ignore_factor_on_trunc
self.unlimited_area_hack = unlimited_area_hack

def is_mm_injected_into_model(model: ModelPatcher):
Expand Down Expand Up @@ -78,13 +76,8 @@ def unlimited_batch_area():

def groupnorm_mm_factory(params: InjectionParams):
def groupnorm_mm_forward(self, input: Tensor) -> Tensor:
# if video_length same length as input tensor and ignore_on_trunc, use 1
if input.size()[0] == params.video_length and params.ignore_factor_on_trunc:
axes_factor = 1
# otherwise, use set b_factor_inp if can divide input size
else:
axes_factor = params.axes_factor if input.size()[0]%params.axes_factor==0 else 1

axes_factor = input.size(0)//params.video_length

input = rearrange(input, "(b f) c h w -> b c f h w", b=axes_factor)
input = group_norm(input, self.num_groups, self.weight, self.bias, self.eps)
input = rearrange(input, "b c f h w -> (b f) c h w", b=axes_factor)
Expand Down Expand Up @@ -241,16 +234,9 @@ def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"latents": ("LATENT",),
"model_name": (get_available_models(),),
"width": ("INT", {"default": 512, "min": 64, "max": 1024, "step": 8}),
"height": ("INT", {"default": 512, "min": 64, "max": 1024, "step": 8}),
"frame_number": (
"INT",
{"default": 16, "min": 2, "max": 24, "step": 1},
),
},
"optional": {
"init_latent": ("LATENT",),
"unlimited_area_hack": ([False, True],),
},
}

Expand All @@ -264,28 +250,33 @@ def IS_CHANGED(s, model: ModelPatcher):
FUNCTION = "inject_motion_modules"

def inject_motion_modules(
self,
model: ModelPatcher,
model_name: str,
width: int,
height: int,
frame_number=16,
init_latent: Dict[str, torch.Tensor] = None,
):
model = model.clone()

self,
model: ModelPatcher,
latents: Dict[str, torch.Tensor],
model_name: str, unlimited_area_hack: bool
):
if model_name not in motion_modules:
motion_modules[model_name] = load_motion_module(model_name)

motion_module = motion_modules[model_name]
# check that latents don't exceed max frame size
init_frames_len = len(latents["samples"])
if init_frames_len > 24:
# TODO: warning and cutoff frames instead of error
raise ValueError(f"AnimateDiff has upper limit of 24 frames, but received {init_frames_len} latents.")
# set motion_module's video_length to match latent length
motion_module.set_video_length(frame_number)
motion_module.set_video_length(init_frames_len)

model = model.clone()
unet = model.model.diffusion_model
unet_hash = calculate_model_hash(unet)

need_inject = unet_hash not in injected_model_hashs

injection_params = InjectionParams(
video_length=init_frames_len,
unlimited_area_hack=unlimited_area_hack,
)

if unet_hash in injected_model_hashs:
(mm_type, version) = injected_model_hashs[unet_hash]
if version != self.version or mm_type != motion_module.mm_type:
Expand All @@ -294,23 +285,17 @@ def inject_motion_modules(
ejectors[version](unet)
need_inject = True
else:
logger.info(f"Motion module already injected, skipping injection.")
logger.info(f"Motion module already injected, only injecting params.")
set_mm_injected_params(model, injection_params)

if need_inject:
logger.info(f"Injecting motion module {model_name} version {self.version}.")
injectors[self.version](unet, motion_module)

injectors[self.version](unet, motion_module, injection_params)
unet_hash = calculate_model_hash(unet)
injected_model_hashs[unet_hash] = (motion_module.mm_type, self.version)

if init_latent is None:
latent = torch.zeros([frame_number, 4, height // 8, width // 8]).cpu()
else:
# clone value of first frame
latent = init_latent["samples"][:1, :, :, :].clone().cpu()
# repeat for all frames
latent = latent.repeat(frame_number, 1, 1, 1)

return (model, {"samples": latent})
return (model, latents)


class AnimateDiffLoader:
Expand All @@ -324,8 +309,6 @@ def INPUT_TYPES(s):
"model": ("MODEL",),
"latents": ("LATENT",),
"model_name": (get_available_models(),),
"axes_factor": ("INT", {"default": 2, "min": 1, "max": 24, "step": 1},),
"ignore_factor_on_trunc": ([True, False],),
"unlimited_area_hack": ([False, True],),
},
}
Expand All @@ -343,7 +326,7 @@ def inject_motion_modules(
self,
model: ModelPatcher,
latents: Dict[str, torch.Tensor],
model_name: str, axes_factor: int, ignore_factor_on_trunc: bool, unlimited_area_hack: bool
model_name: str, unlimited_area_hack: bool
):
if model_name not in motion_modules:
motion_modules[model_name] = load_motion_module(model_name)
Expand All @@ -363,9 +346,7 @@ def inject_motion_modules(
need_inject = unet_hash not in injected_model_hashs

injection_params = InjectionParams(
axes_factor=axes_factor,
video_length=init_frames_len,
ignore_factor_on_trunc=ignore_factor_on_trunc,
unlimited_area_hack=unlimited_area_hack,
)

Expand Down

0 comments on commit a8fbc6e

Please sign in to comment.