diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py
index c5f0afa118f87..1859ca391e02a 100644
--- a/test/inductor/test_max_autotune.py
+++ b/test/inductor/test_max_autotune.py
@@ -463,6 +463,27 @@ def fn(a, b, c):
         self.assertEqual(fn(*inputs), fn_c(*inputs), atol=1e-2, rtol=1e-2)
         self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0)
 
+    @skipIfRocm
+    @fresh_inductor_cache()
+    @config.patch(max_autotune=True, max_fusion_size=2)
+    def test_jit_fusion_matches_aot_fusion(self):
+        # In this example, AOTInductor's JIT-compile will fuse(buf1, buf2) due
+        # to proximity, we want to make sure AOT-compile pass does the same.
+        # AOT could do fuse(buf2, buf4) instead if buf3 was pushed to the end
+        # of the V.graph.buffers list because fuse(buf2, buf4) would have a
+        # better proximity score than fuse(buf1, buf2). This scenario is possible
+        # since finalizing MultiTemplateBuffers needs to replace buffers.
+        def fn(x, number):
+            buf0 = x + x
+            buf1 = number.item()
+            buf2 = x * x
+            buf3 = x @ x  # MultiTemplateBuffer
+            buf4 = x**2
+            return buf0, buf1, buf2, buf3, buf4
+
+        inputs = (torch.rand([256, 256], device="cuda"), torch.tensor(3, device="cuda"))
+        torch._export.aot_compile(fn, args=inputs)
+
     @config.patch(autotune_local_cache=False, autotune_remote_cache=False)
     def test_precompilations(self):
         def fn(a, b, c):
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index 456e0c50567d5..ec4763160a7b6 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -1752,7 +1752,9 @@ def replace_buffer(orig_node: ir.MultiTemplateBuffer, new_node: ir.Buffer):
             del V.graph.name_to_buffer[replaced_name]
             new_node.name = orig_name
 
-            V.graph.buffers.remove(orig_node)
+            orig = V.graph.buffers.index(orig_node)
+            V.graph.buffers.remove(new_node)
+            V.graph.buffers[orig] = new_node
             V.graph.name_to_buffer[orig_name] = new_node
 
         for i, node in enumerate(self.nodes):