diff --git a/dask_cuda/explicit_comms/dataframe/shuffle.py b/dask_cuda/explicit_comms/dataframe/shuffle.py index ba596dd0..d9154e00 100644 --- a/dask_cuda/explicit_comms/dataframe/shuffle.py +++ b/dask_cuda/explicit_comms/dataframe/shuffle.py @@ -583,24 +583,61 @@ def get_default_shuffle_method() -> str: def patch_shuffle_expression() -> None: """Patch Dasks Shuffle expression. - This changes ``Shuffle._lower`` to apply explicit-comms - shuffling when the 'explicit-comms' config is enabled. + Notice, this is monkey patched into Dask at dask_cuda + import, and it changes `Shuffle._lower` to wrap the + original shuffle expression in `ExplicitCommsShuffle`. """ from dask_expr._collection import new_collection + from dask_expr._expr import Expr from dask_expr._shuffle import Shuffle as DXShuffle + class ExplicitCommsShuffle(Expr): + """Explicit Comms Shuffle.""" + + _parameters = ["wrapped"] + + @property + def original(self): + assert len(self.wrapped) == 1, f"Unexpected parameters: {self.wrapped[1:]}" + return self.wrapped[0] + + @property + def _meta(self): + return self.original.frame._meta + + def _lower(self): + return None + + def _divisions(self): + return (None,) * (self.original.frame.npartitions + 1) + + def _layer(self): + if not hasattr(self, "_shuffle_cache"): + self._shuffle_cache = {} + try: + expr = self._shuffle_cache[self._name] + except KeyError: + on = self.original.partitioning_index + expr = shuffle( + new_collection(self.original.frame), + [on] if isinstance(on, str) else on, + self.original.npartitions_out, + self.original.ignore_index, + ) + self._shuffle_cache[self._name] = expr + graph = expr.dask.copy() + graph.update( + {(self._name, i): (expr._name, i) for i in range(self.npartitions)} + ) + return graph + _base_lower = DXShuffle._lower def _lower(self): if self.method in ("tasks", None) and _use_explicit_comms(): - on = self.partitioning_index - on = [on] if isinstance(on, str) else on - return shuffle( - new_collection(self.frame), - on, - self.npartitions_out, - self.ignore_index, - ).expr + # Wrap the original Shuffle in an ExplicitCommsShuffle + # (Use list argument to encapsulate dependencies) + return ExplicitCommsShuffle([self]) else: # Use upstream lowering logic return _base_lower(self)