Add triton_op test for user defined triton caching (pytorch#141407)
Fix failing internal codecache test

Pull Request resolved: pytorch#141407
Approved by: https://github.com/aorenste
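
For reference, the new test should run standalone on a GPU machine with Triton installed. The invocation below is illustrative of how these unittest-based Inductor test files are typically launched, not part of the commit itself:

    python test/inductor/test_codecache.py -k test_triton_op

The -k substring filter selects both bundle_triton parametrizations (False and True).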
oulgen authored and pytorchmergebot committed Nov 23, 2024
1 parent 8b4ae29 commit 3473dfa
Showing 1 changed file with 50 additions and 0 deletions.

test/inductor/test_codecache.py
@@ -28,6 +28,7 @@
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import clear_inductor_caches, fresh_inductor_cache
from torch._library import capture_triton
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_device_type import largeTensorTest
from torch.testing._internal.common_utils import (
@@ -586,6 +587,55 @@ def fn2(x, y):
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)

    @requires_gpu()
    @requires_triton()
    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    @parametrize("bundle_triton", (False, True))
    def test_triton_op(self, bundle_triton):
        libname = "my_cool_namespace"
        opname = "my_triton_operator"

        @torch._library.triton_op(f"{libname}::{opname}", mutates_args={})
        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            output = torch.empty_like(x)
            n_elements = output.numel()

            def grid(meta):
                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

            capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
            return output

        def f(x, y):
            return add(x, y)

        with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
            compiled_fn = torch.compile(f, fullgraph=True)

            x = torch.randn(4, device=GPU_TYPE)
            y = torch.randn(4, device=GPU_TYPE)

            result = compiled_fn(x, y)

            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)

            # A second call should hit. (First reset so in-memory guards
            # don't prevent compilation).
            self.reset()

            # Clean PyCodeCache and triton kernels
            PyCodeCache.cache_clear()
            shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)

            result = compiled_fn(x, y)

            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)

    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    def test_generated_kernel_count(self):
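
Note that add_kernel is not defined in this hunk; it comes from elsewhere in the test file (the Triton test utilities imported near the top). A minimal sketch of such a kernel, assuming the standard elementwise-add pattern these tests use, would look like:

    import triton
    import triton.language as tl

    @triton.jit
    def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        # One program instance handles one BLOCK_SIZE-wide slice of the inputs.
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements  # guard the tail block
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)

The grid function in the test launches triton.cdiv(n_elements, BLOCK_SIZE) program instances, and the trailing 16 in the capture_triton(...) call supplies BLOCK_SIZE.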
