diff --git a/cubed/array_api/array_object.py b/cubed/array_api/array_object.py
index d7363241f..5dc78c9b3 100644
--- a/cubed/array_api/array_object.py
+++ b/cubed/array_api/array_object.py
@@ -65,7 +65,7 @@ def _repr_html_(self):
             grid=grid,
             nbytes=nbytes,
             cbytes=cbytes,
-            arrs_in_plan=f"{self.plan.num_arrays()} arrays in Plan",
+            arrs_in_plan=f"{self.plan._finalize().num_arrays()} arrays in Plan",
             arrtype="np.ndarray",
         )
diff --git a/cubed/core/plan.py b/cubed/core/plan.py
index 9dc9f22b7..e36503bc5 100644
--- a/cubed/core/plan.py
+++ b/cubed/core/plan.py
@@ -207,7 +207,9 @@ def _compile_blockwise(self, dag, compile_function: Decorator) -> nx.MultiDiGrap
         """Compiles functions from all blockwise ops by mutating the input dag."""
         # Recommended: make a copy of the dag before calling this function.
-        compile_with_config = 'config' in inspect.getfullargspec(compile_function).kwonlyargs
+        compile_with_config = (
+            "config" in inspect.getfullargspec(compile_function).kwonlyargs
+        )
 
         for n in dag.nodes:
             node = dag.nodes[n]
@@ -219,7 +221,9 @@ def _compile_blockwise(self, dag, compile_function: Decorator) -> nx.MultiDiGrap
                 continue
 
             if compile_with_config:
-                compiled = compile_function(node["pipeline"].config.function, config=node["pipeline"].config)
+                compiled = compile_function(
+                    node["pipeline"].config.function, config=node["pipeline"].config
+                )
             else:
                 compiled = compile_function(node["pipeline"].config.function)
@@ -227,23 +231,26 @@ def _compile_blockwise(self, dag, compile_function: Decorator) -> nx.MultiDiGrap
             # maybe we should investigate some sort of optics library for frozen dataclasses...
             new_pipeline = dataclasses.replace(
                 node["pipeline"],
-                config=dataclasses.replace(node["pipeline"].config, function=compiled)
+                config=dataclasses.replace(node["pipeline"].config, function=compiled),
             )
             node["pipeline"] = new_pipeline
 
         return dag
 
     @lru_cache
-    def _finalize_dag(
-        self, optimize_graph: bool = True, optimize_function=None, compile_function: Optional[Decorator] = None,
-    ) -> nx.MultiDiGraph:
+    def _finalize(
+        self,
+        optimize_graph: bool = True,
+        optimize_function=None,
+        compile_function: Optional[Decorator] = None,
+    ) -> "FinalizedPlan":
         dag = self.optimize(optimize_function).dag if optimize_graph else self.dag
         # create a copy since _create_lazy_zarr_arrays mutates the dag
         dag = dag.copy()
         if callable(compile_function):
             dag = self._compile_blockwise(dag, compile_function)
         dag = self._create_lazy_zarr_arrays(dag)
-        return nx.freeze(dag)
+        return FinalizedPlan(nx.freeze(dag))
 
     def execute(
         self,
@@ -256,7 +263,10 @@
         spec=None,
         **kwargs,
     ):
-        dag = self._finalize_dag(optimize_graph, optimize_function, compile_function)
+        finalized_plan = self._finalize(
+            optimize_graph, optimize_function, compile_function
+        )
+        dag = finalized_plan.dag
 
         compute_id = f"compute-{datetime.now().strftime('%Y%m%dT%H%M%S.%f')}"
@@ -275,43 +285,6 @@
         event = ComputeEndEvent(compute_id, dag)
         [callback.on_compute_end(event) for callback in callbacks]
 
-    def num_tasks(self, optimize_graph=True, optimize_function=None, resume=None):
-        """Return the number of tasks needed to execute this plan."""
-        dag = self._finalize_dag(optimize_graph, optimize_function)
-        tasks = 0
-        for _, node in visit_nodes(dag, resume=resume):
-            tasks += node["primitive_op"].num_tasks
-        return tasks
-
-    def num_arrays(self, optimize_graph: bool = True, optimize_function=None) -> int:
-        """Return the number of arrays in this plan."""
-        dag = self._finalize_dag(optimize_graph, optimize_function)
-        return sum(d.get("type") == "array" for _, d in dag.nodes(data=True))
-
-    def max_projected_mem(
-        self, optimize_graph=True, optimize_function=None, resume=None
-    ):
-        """Return the maximum projected memory across all tasks to execute this plan."""
-        dag = self._finalize_dag(optimize_graph, optimize_function)
-        projected_mem_values = [
-            node["primitive_op"].projected_mem
-            for _, node in visit_nodes(dag, resume=resume)
-        ]
-        return max(projected_mem_values) if len(projected_mem_values) > 0 else 0
-
-    def total_nbytes_written(
-        self, optimize_graph: bool = True, optimize_function=None
-    ) -> int:
-        """Return the total number of bytes written for all materialized arrays in this plan."""
-        dag = self._finalize_dag(optimize_graph, optimize_function)
-        nbytes = 0
-        for _, d in dag.nodes(data=True):
-            if d.get("type") == "array":
-                target = d["target"]
-                if isinstance(target, LazyZarrArray):
-                    nbytes += target.nbytes
-        return nbytes
-
     def visualize(
         self,
         filename="cubed",
@@ -321,7 +294,8 @@
         optimize_function=None,
         show_hidden=False,
     ):
-        dag = self._finalize_dag(optimize_graph, optimize_function)
+        finalized_plan = self._finalize(optimize_graph, optimize_function)
+        dag = finalized_plan.dag
         dag = dag.copy()  # make a copy since we mutate the DAG below
 
         # remove edges from create-arrays output node to avoid cluttering the diagram
@@ -336,9 +310,9 @@
             "rankdir": rankdir,
             "label": (
                 # note that \l is used to left-justify each line (see https://www.graphviz.org/docs/attrs/nojustify/)
-                rf"num tasks: {self.num_tasks(optimize_graph, optimize_function)}\l"
-                rf"max projected memory: {memory_repr(self.max_projected_mem(optimize_graph, optimize_function))}\l"
-                rf"total nbytes written: {memory_repr(self.total_nbytes_written(optimize_graph, optimize_function))}\l"
+                rf"num tasks: {finalized_plan.num_tasks()}\l"
+                rf"max projected memory: {memory_repr(finalized_plan.max_projected_mem())}\l"
+                rf"total nbytes written: {memory_repr(finalized_plan.total_nbytes_written())}\l"
                 rf"optimized: {optimize_graph}\l"
             ),
             "labelloc": "bottom",
@@ -474,6 +448,49 @@
         return None
 
 
+class FinalizedPlan:
+    """A plan that is ready to be run.
+
+    Finalizing a plan involves the following steps:
+    1. optimization (optional)
+    2. adding housekeeping nodes to create arrays
+    3. compiling functions (optional)
+    4. freezing the final DAG so it can't be changed
+    """
+
+    def __init__(self, dag):
+        self.dag = dag
+
+    def max_projected_mem(self, resume=None):
+        """Return the maximum projected memory across all tasks to execute this plan."""
+        projected_mem_values = [
+            node["primitive_op"].projected_mem
+            for _, node in visit_nodes(self.dag, resume=resume)
+        ]
+        return max(projected_mem_values) if len(projected_mem_values) > 0 else 0
+
+    def num_arrays(self) -> int:
+        """Return the number of arrays in this plan."""
+        return sum(d.get("type") == "array" for _, d in self.dag.nodes(data=True))
+
+    def num_tasks(self, resume=None):
+        """Return the number of tasks needed to execute this plan."""
+        tasks = 0
+        for _, node in visit_nodes(self.dag, resume=resume):
+            tasks += node["primitive_op"].num_tasks
+        return tasks
+
+    def total_nbytes_written(self) -> int:
+        """Return the total number of bytes written for all materialized arrays in this plan."""
+        nbytes = 0
+        for _, d in self.dag.nodes(data=True):
+            if d.get("type") == "array":
+                target = d["target"]
+                if isinstance(target, LazyZarrArray):
+                    nbytes += target.nbytes
+        return nbytes
+
+
 def arrays_to_dag(*arrays):
     from .array import check_array_specs
diff --git a/cubed/tests/test_core.py b/cubed/tests/test_core.py
index fdf90ca41..4e092e96d 100644
--- a/cubed/tests/test_core.py
+++ b/cubed/tests/test_core.py
@@ -373,13 +373,14 @@ def test_reduction_multiple_rounds(tmp_path, executor):
     a = xp.ones((100, 10), dtype=np.uint8, chunks=(1, 10), spec=spec)
     b = xp.sum(a, axis=0, dtype=np.uint8)
     # check that there is > 1 blockwise step (after optimization)
+    finalized_plan = b.plan._finalize()
     blockwises = [
         n
-        for (n, d) in b.plan.dag.nodes(data=True)
+        for (n, d) in finalized_plan.dag.nodes(data=True)
        if d.get("op_name", None) == "blockwise"
     ]
     assert len(blockwises) > 1
-    assert b.plan.max_projected_mem() <= 1000
+    assert finalized_plan.max_projected_mem() <= 1000
     assert_array_equal(b.compute(executor=executor), np.ones((100, 10)).sum(axis=0))
@@ -555,7 +556,7 @@ def test_plan_scaling(tmp_path, factor):
     )
     c = xp.matmul(a, b)
 
-    assert c.plan.num_tasks() > 0
+    assert c.plan._finalize().num_tasks() > 0
     c.visualize(filename=tmp_path / "c")
@@ -568,7 +569,7 @@ def test_plan_quad_means(tmp_path, t_length):
     uv = u * v
     m = xp.mean(uv, axis=0, split_every=10, use_new_impl=True)
 
-    assert m.plan.num_tasks() > 0
+    assert m.plan._finalize().num_tasks() > 0
     m.visualize(
         filename=tmp_path / "quad_means_unoptimized",
         optimize_graph=False,
diff --git a/cubed/tests/test_executor_features.py b/cubed/tests/test_executor_features.py
index 6a8dad86a..dfc4e1ac8 100644
--- a/cubed/tests/test_executor_features.py
+++ b/cubed/tests/test_executor_features.py
@@ -181,7 +181,7 @@ def test_resume(spec, executor):
     d = xp.negative(c)
 
     num_created_arrays = 2  # c, d
-    assert d.plan.num_tasks(optimize_graph=False) == num_created_arrays + 8
+    assert d.plan._finalize(optimize_graph=False).num_tasks() == num_created_arrays + 8
 
     task_counter = TaskCounter()
     c.compute(executor=executor, callbacks=[task_counter], optimize_graph=False)
@@ -321,13 +321,15 @@ def test_check_runtime_memory_processes(spec, executor):
 try:
     from numba import jit as numba_jit
+
     COMPILE_FUNCTIONS.append(numba_jit)
 except ModuleNotFoundError:
     pass
 
 try:
-    if 'jax' in os.environ.get('CUBED_BACKEND_ARRAY_API_MODULE', ''):
+    if "jax" in os.environ.get("CUBED_BACKEND_ARRAY_API_MODULE", ""):
         from jax import jit as jax_jit
+
         COMPILE_FUNCTIONS.append(jax_jit)
 except ModuleNotFoundError:
     pass
@@ -339,7 +341,8 @@ def test_check_compilation(spec, executor, compile_function):
     b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)
     c = xp.add(a, b)
     assert_array_equal(
-        c.compute(executor=executor, compile_function=compile_function), np.array([[2, 3, 4], [5, 6, 7], [8, 9, 10]])
+        c.compute(executor=executor, compile_function=compile_function),
+        np.array([[2, 3, 4], [5, 6, 7], [8, 9, 10]]),
     )
@@ -352,7 +355,7 @@ def compile_function(func):
     c = xp.add(a, b)
     with pytest.raises(NotImplementedError) as excinfo:
         c.compute(executor=executor, compile_function=compile_function)
-    
+
     assert "add" in str(excinfo.value), "Compile function was applied to add operation."
@@ -365,5 +368,7 @@ def compile_function(func, *, config=None):
     c = xp.add(a, b)
     with pytest.raises(NotImplementedError) as excinfo:
         c.compute(executor=executor, compile_function=compile_function)
-    
-    assert "BlockwiseSpec" in str(excinfo.value), "Compile function was applied with a config argument."
+
+    assert "BlockwiseSpec" in str(
+        excinfo.value
+    ), "Compile function was applied with a config argument."
diff --git a/cubed/tests/test_optimization.py b/cubed/tests/test_optimization.py
index e2f77cc1f..7fee9a725 100644
--- a/cubed/tests/test_optimization.py
+++ b/cubed/tests/test_optimization.py
@@ -38,25 +38,17 @@ def test_fusion(spec, opt_fn):
 
     num_arrays = 4  # a, b, c, d
     num_created_arrays = 3  # b, c, d (a is not created on disk)
-    assert d.plan.num_arrays(optimize_graph=False) == num_arrays
-    assert d.plan.num_tasks(optimize_graph=False) == num_created_arrays + 12
-    assert (
-        d.plan.total_nbytes_written(optimize_graph=False)
-        == b.nbytes + c.nbytes + d.nbytes
-    )
+    plan_unopt = d.plan._finalize(optimize_graph=False)
+    assert plan_unopt.num_arrays() == num_arrays
+    assert plan_unopt.num_tasks() == num_created_arrays + 12
+    assert plan_unopt.total_nbytes_written() == b.nbytes + c.nbytes + d.nbytes
+
     num_arrays = 2  # a, d
     num_created_arrays = 1  # d (a is not created on disk)
-    assert (
-        d.plan.num_arrays(optimize_graph=True, optimize_function=opt_fn) == num_arrays
-    )
-    assert (
-        d.plan.num_tasks(optimize_graph=True, optimize_function=opt_fn)
-        == num_created_arrays + 4
-    )
-    assert (
-        d.plan.total_nbytes_written(optimize_graph=True, optimize_function=opt_fn)
-        == d.nbytes
-    )
+    plan_opt = d.plan._finalize(optimize_graph=True, optimize_function=opt_fn)
+    assert plan_opt.num_arrays() == num_arrays
+    assert plan_opt.num_tasks() == num_created_arrays + 4
+    assert plan_opt.total_nbytes_written() == d.nbytes
 
     task_counter = TaskCounter()
     result = d.compute(optimize_function=opt_fn, callbacks=[task_counter])
@@ -78,10 +70,10 @@ def test_fusion_transpose(spec, opt_fn):
     d = c.T
 
     num_created_arrays = 3  # b, c, d
-    assert d.plan.num_tasks(optimize_graph=False) == num_created_arrays + 12
+    assert d.plan._finalize(optimize_graph=False).num_tasks() == num_created_arrays + 12
     num_created_arrays = 1  # d
     assert (
-        d.plan.num_tasks(optimize_graph=True, optimize_function=opt_fn)
+        d.plan._finalize(optimize_graph=True, optimize_function=opt_fn).num_tasks()
         == num_created_arrays + 4
     )
@@ -104,9 +96,9 @@ def test_fusion_map_direct(spec):
     c = xp.negative(b)  # should be fused with b
 
     num_created_arrays = 2  # b, c
-    assert c.plan.num_tasks(optimize_graph=False) == num_created_arrays + 4
+    assert c.plan._finalize(optimize_graph=False).num_tasks() == num_created_arrays + 4
     num_created_arrays = 1  # c
-    assert c.plan.num_tasks(optimize_graph=True) == num_created_arrays + 2
+    assert c.plan._finalize(optimize_graph=True).num_tasks() == num_created_arrays + 2
 
     task_counter = TaskCounter()
     result = c.compute(callbacks=[task_counter])
@@ -129,8 +121,10 @@ def test_no_fusion(spec):
     opt_fn = simple_optimize_dag
 
     num_created_arrays = 3  # b, c, d
-    assert d.plan.num_tasks(optimize_graph=False) == num_created_arrays + 3
-    assert d.plan.num_tasks(optimize_function=opt_fn) == num_created_arrays + 3
+    assert d.plan._finalize(optimize_graph=False).num_tasks() == num_created_arrays + 3
+    assert (
+        d.plan._finalize(optimize_function=opt_fn).num_tasks() == num_created_arrays + 3
+    )
 
     task_counter = TaskCounter()
     result = d.compute(optimize_function=opt_fn, callbacks=[task_counter])
@@ -151,8 +145,10 @@ def test_no_fusion_multiple_edges(spec):
     opt_fn = simple_optimize_dag
 
     num_created_arrays = 2  # c, d
-    assert d.plan.num_tasks(optimize_graph=False) == num_created_arrays + 2
-    assert d.plan.num_tasks(optimize_function=opt_fn) == num_created_arrays + 2
+    assert d.plan._finalize(optimize_graph=False).num_tasks() == num_created_arrays + 2
+    assert (
+        d.plan._finalize(optimize_function=opt_fn).num_tasks() == num_created_arrays + 2
+    )
 
     task_counter = TaskCounter()
     result = d.compute(optimize_function=opt_fn, callbacks=[task_counter])
@@ -167,18 +163,19 @@ def test_custom_optimize_function(spec):
     c = xp.astype(b, np.float32)
     d = xp.negative(c)
 
-    num_tasks_with_no_optimization = d.plan.num_tasks(optimize_graph=False)
+    num_tasks_with_no_optimization = d.plan._finalize(optimize_graph=False).num_tasks()
 
-    assert d.plan.num_tasks(optimize_graph=True) < num_tasks_with_no_optimization
+    assert (
+        d.plan._finalize(optimize_graph=True).num_tasks()
+        < num_tasks_with_no_optimization
+    )
 
     def custom_optimize_function(dag):
         # leave DAG unchanged
         return dag
 
     assert (
-        d.plan.num_tasks(
-            optimize_graph=True, optimize_function=custom_optimize_function
-        )
+        d.plan._finalize(optimize_function=custom_optimize_function).num_tasks()
         == num_tasks_with_no_optimization
     )
@@ -291,9 +288,11 @@ def test_fuse_unary_op(spec):
     assert get_num_input_blocks(optimized_dag, c.name) == (1,)
 
     num_created_arrays = 2  # b, c
-    assert c.plan.num_tasks(optimize_graph=False) == num_created_arrays + 2
+    assert c.plan._finalize(optimize_graph=False).num_tasks() == num_created_arrays + 2
     num_created_arrays = 1  # c
-    assert c.plan.num_tasks(optimize_function=opt_fn) == num_created_arrays + 1
+    assert (
+        c.plan._finalize(optimize_function=opt_fn).num_tasks() == num_created_arrays + 1
+    )
 
     task_counter = TaskCounter()
     result = c.compute(callbacks=[task_counter], optimize_function=opt_fn)
@@ -332,9 +331,11 @@ def test_fuse_binary_op(spec):
     assert get_num_input_blocks(optimized_dag, e.name) == (1, 1)
 
     num_created_arrays = 3  # c, d, e
-    assert e.plan.num_tasks(optimize_graph=False) == num_created_arrays + 3
+    assert e.plan._finalize(optimize_graph=False).num_tasks() == num_created_arrays + 3
     num_created_arrays = 1  # e
-    assert e.plan.num_tasks(optimize_function=opt_fn) == num_created_arrays + 1
+    assert (
+        e.plan._finalize(optimize_function=opt_fn).num_tasks() == num_created_arrays + 1
+    )
 
     task_counter = TaskCounter()
     result = e.compute(callbacks=[task_counter], optimize_function=opt_fn)
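Note (not part of the patch): a minimal sketch of how the relocated plan statistics are intended to be called after this change. It assumes cubed is installed with its array API namespace; the spec value and array shapes below are illustrative only and are not taken from the patch.

    import cubed
    import cubed.array_api as xp

    spec = cubed.Spec(allowed_mem=100_000)  # illustrative spec, not from the patch
    a = xp.ones((4, 4), chunks=(2, 2), spec=spec)
    b = xp.negative(a)

    # Statistics now live on FinalizedPlan rather than on Plan itself:
    # finalize once (optimization is on by default), then query the result.
    finalized_plan = b.plan._finalize()
    print(finalized_plan.num_tasks())
    print(finalized_plan.num_arrays())
    print(finalized_plan.max_projected_mem())
    print(finalized_plan.total_nbytes_written())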