diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py index fa15a9e6..df782304 100644 --- a/dask_expr/_merge.py +++ b/dask_expr/_merge.py @@ -50,6 +50,10 @@ class Merge(Expr): "shuffle_backend": None, } + def __init__(self, *args, _precomputed_meta=None, **kwargs): + super().__init__(*args, **kwargs) + self._precomputed_meta = _precomputed_meta + def __str__(self): return f"Merge({self._name[-7:]})" @@ -70,6 +74,8 @@ def kwargs(self): @functools.cached_property def _meta(self): + if self._precomputed_meta is not None: + return self._precomputed_meta left = meta_nonempty(self.left._meta) right = meta_nonempty(self.right._meta) return make_meta(left.merge(right, **self.kwargs)) @@ -105,7 +111,9 @@ def _lower(self): or right.npartitions == 1 and how in ("left", "inner") ): - return BlockwiseMerge(left, right, **self.kwargs) + return BlockwiseMerge( + left, right, **self.kwargs, _precomputed_meta=self._meta + ) # Check if we are merging on indices with known divisions merge_indexed_left = ( @@ -166,6 +174,7 @@ def _lower(self): indicator=self.indicator, left_index=left_index, right_index=right_index, + _precomputed_meta=self._meta, ) if shuffle_left_on: @@ -187,7 +196,7 @@ def _lower(self): ) # Blockwise merge - return BlockwiseMerge(left, right, **self.kwargs) + return BlockwiseMerge(left, right, **self.kwargs, _precomputed_meta=self._meta) def _simplify_up(self, parent): if isinstance(parent, (Projection, Index)): @@ -240,8 +249,13 @@ def _simplify_up(self, parent): if set(project_left) < set(left.columns) or set(project_right) < set( right.columns ): + columns = left_on + right_on + projection + meta_cols = [col for col in self.columns if col in columns] result = type(self)( - left[project_left], right[project_right], *self.operands[2:] + left[project_left], + right[project_right], + *self.operands[2:], + _precomputed_meta=self._meta[meta_cols], ) if parent_columns is None: return type(parent)(result) @@ -277,17 +291,7 @@ def _lower(self): @functools.cached_property def _meta(self): - left = self.left._meta.drop(columns=_HASH_COLUMN_NAME) - right = self.right._meta.drop(columns=_HASH_COLUMN_NAME) - return left.merge( - right, - left_on=self.left_on, - right_on=self.right_on, - indicator=self.indicator, - suffixes=self.suffixes, - left_index=self.left_index, - right_index=self.right_index, - ) + return self._precomputed_meta def _layer(self) -> dict: dsk = {} diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index 20b15514..7d5a4f6b 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -557,7 +557,7 @@ def _dataset_info(self): return dataset_info - @property + @cached_property def _meta(self): meta = self._dataset_info["meta"] if self._series: