diff --git a/tests/onnx_friendly_tome.py b/tests/onnx_friendly_tome.py
new file mode 100644
index 0000000..218ae58
--- /dev/null
+++ b/tests/onnx_friendly_tome.py
@@ -0,0 +1,21 @@
+import unittest
+
+import torch
+
+from tomesd.merge import bipartite_soft_matching_random2d
+
+
+class TestOnnxFriendlyToMeOperations(unittest.TestCase):
+    def test_component_correctness(self):
+        c = 320
+        w = h = 64
+        r = 0.2
+        sx = sy = 2
+        x = torch.rand(2, w * h, c)
+        m_orig, u_orig = bipartite_soft_matching_random2d(x, w, h, sx, sy, int(w * h * r), no_rand=True)
+        m_onnx, u_onnx = bipartite_soft_matching_random2d(x, w, h, sx, sy, int(w * h * r), no_rand=True, onnx_friendly=True)
+        torch.testing.assert_close(u_orig(m_orig(x)), u_onnx(m_onnx(x)))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tomesd/merge.py b/tomesd/merge.py
index 6e8a513..64993ba 100644
--- a/tomesd/merge.py
+++ b/tomesd/merge.py
@@ -2,7 +2,7 @@
 from typing import Tuple, Callable
 
 
-def do_nothing(x: torch.Tensor, mode:str=None):
+def do_nothing(x: torch.Tensor, mode: str = None):
     return x
 
 
@@ -20,7 +20,8 @@ def mps_gather_workaround(input, dim, index):
 def bipartite_soft_matching_random2d(metric: torch.Tensor,
                                      w: int, h: int, sx: int, sy: int, r: int,
                                      no_rand: bool = False,
-                                     generator: torch.Generator = None) -> Tuple[Callable, Callable]:
+                                     generator: torch.Generator = None,
+                                     onnx_friendly: bool = False) -> Tuple[Callable, Callable]:
     """
     Partitions the tokens into src and dst and merges r tokens from src to dst.
     Dst tokens are partitioned by choosing one randomy in each (sx, sy) region.
@@ -34,6 +35,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
      - r: number of tokens to remove (by merging)
      - no_rand: if true, disable randomness (use top left corner only)
      - rand_seed: if no_rand is false, and if not None, sets random seed.
+     - onnx_friendly: if True, replace `torch.scatter_reduce` with ONNX-friendly operators (scatter and bincount).
     """
     B, N, _ = metric.shape
 
@@ -41,7 +43,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
         return do_nothing, do_nothing
 
     gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
-    
+
     with torch.no_grad():
         hsy, wsx = h // sy, w // sx
 
@@ -49,10 +51,10 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
         if no_rand:
             rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
         else:
-            rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=generator.device, generator=generator).to(metric.device)
-        
+            rand_idx = torch.randint(sy * sx, size=(hsy, wsx, 1), device=generator.device, generator=generator).to(metric.device)
+
         # The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead
-        idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
+        idx_buffer_view = torch.zeros(hsy, wsx, sy * sx, device=metric.device, dtype=torch.int64)
         idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
         idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)
 
@@ -71,8 +73,8 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
 
     # rand_idx is currently dst|src, so split them
    num_dst = hsy * wsx
-    a_idx = rand_idx[:, num_dst:, :] # src
-    b_idx = rand_idx[:, :num_dst, :] # dst
+    a_idx = rand_idx[:, num_dst:, :]  # src
+    b_idx = rand_idx[:, :num_dst, :]  # dst
 
     def split(x):
         C = x.shape[-1]
@@ -99,10 +101,21 @@ def split(x):
 
     def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
         src, dst = split(x)
         n, t1, c = src.shape
-        
+
         unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
         src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
-        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
+        if not onnx_friendly:
+            dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
+        else:
+            if mode not in ("mean", "sum"):
+                raise NotImplementedError(f"ONNX friendly currently supports 'mean' and 'sum' modes, got {mode}")
+            dst = dst.scatter(-2, dst_idx.expand(n, r, c), src, reduce='add')
+            if mode == "mean":
+                counts = torch.stack([
+                    torch.bincount(dst_idx[i, :, 0], minlength=dst.size(-2))
+                    for i in range(dst_idx.size(0))
+                ], dim=0) + 1
+                dst = dst / counts[..., None]
 
         return torch.cat([unm, dst], dim=1)
diff --git a/tomesd/patch.py b/tomesd/patch.py
index c4d42ef..e5d45c1 100644
--- a/tomesd/patch.py
+++ b/tomesd/patch.py
@@ -6,7 +6,6 @@ from .utils import isinstance_str, init_generator
 
 
 
-
 def compute_merge(x: torch.Tensor, tome_info: Dict[str, Any]) -> Tuple[Callable, ...]:
     original_h, original_w = tome_info["size"]
     original_tokens = original_h * original_w
@@ -24,27 +23,22 @@ def compute_merge(x: torch.Tensor, tome_info: Dict[str, Any]) -> Tuple[Callable,
            args["generator"] = init_generator(x.device)
         elif args["generator"].device != x.device:
            args["generator"] = init_generator(x.device, fallback=args["generator"])
-        
+
         # If the batch size is odd, then it's not possible for prompted and unprompted images to be in the same
         # batch, which causes artifacts with use_rand, so force it to be off.
         use_rand = False if x.shape[0] % 2 == 1 else args["use_rand"]
-        m, u = merge.bipartite_soft_matching_random2d(x, w, h, args["sx"], args["sy"], r, 
-                                                      no_rand=not use_rand, generator=args["generator"])
+        m, u = merge.bipartite_soft_matching_random2d(x, w, h, args["sx"], args["sy"], r,
+                                                      no_rand=not use_rand, generator=args["generator"], onnx_friendly=tome_info["onnx_friendly"])
     else:
         m, u = (merge.do_nothing, merge.do_nothing)
 
-    m_a, u_a = (m, u) if args["merge_attn"]      else (merge.do_nothing, merge.do_nothing)
+    m_a, u_a = (m, u) if args["merge_attn"] else (merge.do_nothing, merge.do_nothing)
     m_c, u_c = (m, u) if args["merge_crossattn"] else (merge.do_nothing, merge.do_nothing)
-    m_m, u_m = (m, u) if args["merge_mlp"]       else (merge.do_nothing, merge.do_nothing)
+    m_m, u_m = (m, u) if args["merge_mlp"] else (merge.do_nothing, merge.do_nothing)
 
     return m_a, m_c, m_m, u_a, u_c, u_m  # Okay this is probably not very good
 
 
-
-
-
-
-
 def make_tome_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
     """
     Make a patched class on the fly so we don't have to import any specific modules.
@@ -64,12 +58,8 @@ def _forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tenso
             x = u_m(self.ff(m_m(self.norm3(x)))) + x
 
             return x
-
-    return ToMeBlock
-
-
-
+    return ToMeBlock
 
 
 def make_diffusers_tome_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
@@ -77,19 +67,20 @@ def make_diffusers_tome_block(block_class: Type[torch.nn.Module]) -> Type[torch.
     Make a patched class for a diffusers model.
     This patch applies ToMe to the forward function of the block.
     """
+
     class ToMeBlock(block_class):
         # Save for unpatching later
         _parent = block_class
 
         def forward(
-            self,
-            hidden_states,
-            attention_mask=None,
-            encoder_hidden_states=None,
-            encoder_attention_mask=None,
-            timestep=None,
-            cross_attention_kwargs=None,
-            class_labels=None,
+                self,
+                hidden_states,
+                attention_mask=None,
+                encoder_hidden_states=None,
+                encoder_attention_mask=None,
+                timestep=None,
+                cross_attention_kwargs=None,
+                class_labels=None,
         ) -> torch.Tensor:
             # (1) ToMe
             m_a, m_c, m_m, u_a, u_c, u_m = compute_merge(hidden_states, self._tome_info)
@@ -139,7 +130,7 @@ def forward(
 
             # 3. Feed-forward
             norm_hidden_states = self.norm3(hidden_states)
-            
+
             if self.use_ada_layer_norm_zero:
                 norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
 
@@ -159,12 +150,9 @@ def forward(
     return ToMeBlock
 
 
-
-
-
-
 def hook_tome_model(model: torch.nn.Module):
     """ Adds a forward pre hook to get the image size. This hook can be removed with remove_patch. """
+
     def hook(module, args):
         module._tome_info["size"] = (args[0].shape[2], args[0].shape[3])
         return None
@@ -172,12 +160,6 @@ def hook(module, args):
     model._tome_info["hooks"].append(model.register_forward_pre_hook(hook))
 
 
-
-
-
-
-
 def apply_patch(
         model: torch.nn.Module,
         ratio: float = 0.5,
@@ -186,7 +168,8 @@ def apply_patch(
         use_rand: bool = True,
         merge_attn: bool = True,
         merge_crossattn: bool = False,
-        merge_mlp: bool = False):
+        merge_mlp: bool = False,
+        onnx_friendly: bool = False):
     """
     Patches a stable diffusion model with ToMe.
     Apply this to the highest level stable diffusion object (i.e., it should have a .model.diffusion_model).
@@ -208,6 +191,7 @@ def apply_patch(
     - merge_attn: Whether or not to merge tokens for attention (recommended).
     - merge_crossattn: Whether or not to merge tokens for cross attention (not recommended).
     - merge_mlp: Whether or not to merge tokens for the mlp layers (very not recommended).
+     - onnx_friendly: Whether or not to replace scatter_reduce with ONNX-friendly ops.
     """
 
     # Make sure the module is not currently patched
@@ -235,8 +219,9 @@ def apply_patch(
             "generator": None,
             "merge_attn": merge_attn,
             "merge_crossattn": merge_crossattn,
-            "merge_mlp": merge_mlp
-        }
+            "merge_mlp": merge_mlp,
+        },
+        "onnx_friendly": onnx_friendly
     }
 
     hook_tome_model(diffusion_model)
@@ -259,9 +244,6 @@ def apply_patch(
     return model
 
 
-
-
-
 def remove_patch(model: torch.nn.Module):
     """ Removes a patch from a ToMe Diffusion module if it was already patched. """
     # For diffusers
@@ -275,5 +257,5 @@ def remove_patch(model: torch.nn.Module):
 
         if module.__class__.__name__ == "ToMeBlock":
             module.__class__ = module._parent
-    
+
     return model
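
Note (not part of the patch): the onnx_friendly branch relies on scatter(reduce='add') plus a per-batch bincount reproducing scatter_reduce for the 'sum' and 'mean' modes; for 'mean' with the default include_self=True, each destination row is divided by the number of tokens merged into it plus one for the original dst token. A minimal, self-contained sketch of that equivalence follows; the tensor names and shapes here are illustrative only and are not taken from the diff.

import torch

# Illustrative shapes: batch, dst tokens, merged src tokens, channels
n, t_dst, r, c = 2, 6, 4, 3
dst = torch.rand(n, t_dst, c)
src = torch.rand(n, r, c)
dst_idx = torch.randint(t_dst, (n, r, 1))

# Reference path (onnx_friendly=False): scatter_reduce averages self plus everything merged in.
ref = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce="mean")

# ONNX-friendly path: plain scatter-add, then divide by per-row counts (+1 for the original dst token).
out = dst.scatter(-2, dst_idx.expand(n, r, c), src, reduce="add")
counts = torch.stack([torch.bincount(dst_idx[i, :, 0], minlength=t_dst) for i in range(n)], dim=0) + 1
out = out / counts[..., None]

torch.testing.assert_close(ref, out)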