Skip to content

Commit

Permalink
unit/runtime/*.py fix for windows
Browse files Browse the repository at this point in the history
  • Loading branch information
wkpark committed Dec 5, 2023
1 parent 2501812 commit c2552c9
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/test/unit/runtime/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_nested1_change():


def write_and_load_module(code, num_extra_lines):
with tempfile.NamedTemporaryFile(mode='w+', suffix='.py') as f:
with tempfile.NamedTemporaryFile(mode='w+', suffix='.py', delete=False) as f:
f.write(('# extra line\n' * num_extra_lines) + code)
f.flush()
spec = importlib.util.spec_from_file_location("module.name", f.name)
Expand Down
8 changes: 6 additions & 2 deletions python/test/unit/runtime/test_subproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,15 @@ def kernel_sub(a, b, o, N: tl.constexpr):


def test_compile_in_subproc() -> None:
import os
major, minor = torch.cuda.get_device_capability(0)
cc = major * 10 + minor
config = triton.compiler.AttrsDescriptor(tuple(range(4)), (), (), ())

multiprocessing.set_start_method('fork')
if os.name == "nt":
multiprocessing.set_start_method('spawn')
else:
multiprocessing.set_start_method('fork')
proc = multiprocessing.Process(target=compile_fn, args=(config, cc))
proc.start()
proc.join()
Expand All @@ -64,7 +68,7 @@ def test_compile_in_forked_subproc() -> None:
capability = major * 10 + minor
config = triton.compiler.AttrsDescriptor(tuple(range(1)), (), (), ())

assert multiprocessing.get_start_method() == 'fork'
assert multiprocessing.get_start_method() in ['fork', 'spawn']
proc = multiprocessing.Process(target=compile_fn_dot, args=(config, capability))
proc.start()
proc.join()
Expand Down

0 comments on commit c2552c9

Please sign in to comment.