From 782645a6e157765387d9486ca35d394b29339ef8 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Wed, 30 Aug 2023 12:05:43 +0200
Subject: [PATCH 1/5] Merge projection selects too many columns

---
 dask_expr/_merge.py           |  9 ++++++++-
 dask_expr/tests/test_merge.py | 12 ++++++++++++
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py
index cacc25c8..72574c83 100644
--- a/dask_expr/_merge.py
+++ b/dask_expr/_merge.py
@@ -203,7 +203,14 @@ def _simplify_up(self, parent):
                 projection = [projection]
 
             left, right = self.left, self.right
-            left_on, right_on = self.left_on, self.right_on
+            if isinstance(self.left_on, list):
+                left_on = self.left_on
+            else:
+                left_on = [self.left_on] if self.left_on is not None else []
+            if isinstance(self.right_on, list):
+                right_on = self.right_on
+            else:
+                right_on = [self.right_on] if self.right_on is not None else []
             left_suffix, right_suffix = self.suffixes[0], self.suffixes[1]
             project_left, project_right = [], []
 
diff --git a/dask_expr/tests/test_merge.py b/dask_expr/tests/test_merge.py
index 96365473..dacb92b3 100644
--- a/dask_expr/tests/test_merge.py
+++ b/dask_expr/tests/test_merge.py
@@ -164,3 +164,15 @@ def test_merge_len():
     query = df.merge(df2).index.optimize(fuse=False)
     expected = df[["x"]].merge(df2[["x"]]).index.optimize(fuse=False)
     assert query._name == expected._name
+
+
+def test_merge_optimize_subset_strings():
+    pdf = lib.DataFrame({"a": [1, 2], "aaa": 1})
+    pdf2 = lib.DataFrame({"b": [1, 2], "aaa": 1})
+    df = from_pandas(pdf)
+    df2 = from_pandas(pdf2)
+
+    query = df.merge(df2, on="aaa")[["aaa"]].optimize(fuse=False)
+    exp = df[["aaa"]].merge(df2[["aaa"]], on="aaa").optimize(fuse=False)
+    assert query._name == exp._name
+    assert_eq(query, pdf.merge(pdf2, on="aaa")[["aaa"]])
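
Note on the patch above: the old check "col in left_on" performed a substring test
whenever left_on was a plain string, so a merge on "aaa" also kept unrelated columns
such as "a" on both inputs. The keys are now normalized to lists before the projection
pushdown runs. The snippet below restates the behaviour the new test pins down, outside
the test module (illustrative only; it swaps the test file's lib alias for a plain
pandas import and relies on the from_pandas/optimize API already used in the tests):

    import pandas as pd
    from dask_expr import from_pandas

    pdf = pd.DataFrame({"a": [1, 2], "aaa": 1})
    pdf2 = pd.DataFrame({"b": [1, 2], "aaa": 1})
    df, df2 = from_pandas(pdf), from_pandas(pdf2)

    # Selecting only the merge key after a merge on a string key should now
    # push the projection into both inputs before the join runs.
    query = df.merge(df2, on="aaa")[["aaa"]].optimize(fuse=False)
    expected = df[["aaa"]].merge(df2[["aaa"]], on="aaa").optimize(fuse=False)
    assert query._name == expected._name
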
From 7a7802de8c8ad4d9b629a2a9882c7886b5d03a9a Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Wed, 30 Aug 2023 12:03:20 +0200
Subject: [PATCH 2/5] Make meta calculation for merge more efficient

---
 dask_expr/_merge.py           | 43 +++++++++++++++++++++--------------
 dask_expr/io/parquet.py       |  2 +-
 dask_expr/tests/test_merge.py |  2 +-
 3 files changed, 28 insertions(+), 19 deletions(-)

diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py
index ea490d4f..69800ae6 100644
--- a/dask_expr/_merge.py
+++ b/dask_expr/_merge.py
@@ -37,6 +37,7 @@ class Merge(Expr):
         "suffixes",
         "indicator",
         "shuffle_backend",
+        "_meta",
     ]
     _defaults = {
         "how": "inner",
@@ -47,6 +48,7 @@ class Merge(Expr):
         "suffixes": ("_x", "_y"),
         "indicator": False,
         "shuffle_backend": None,
+        "_meta": None,
     }
 
     def __str__(self):
@@ -69,6 +71,8 @@ def kwargs(self):
 
     @functools.cached_property
     def _meta(self):
+        if self.operand("_meta") is not None:
+            return self.operand("_meta")
         left = meta_nonempty(self.left._meta)
         right = meta_nonempty(self.right._meta)
         return make_meta(left.merge(right, **self.kwargs))
@@ -104,7 +108,7 @@ def _lower(self):
             or right.npartitions == 1
             and how in ("left", "inner")
         ):
-            return BlockwiseMerge(left, right, **self.kwargs)
+            return BlockwiseMerge(left, right, **self.kwargs, _meta=self._meta)
 
         # Check if we are merging on indices with known divisions
         merge_indexed_left = (
@@ -165,6 +169,7 @@ def _lower(self):
                 indicator=self.indicator,
                 left_index=left_index,
                 right_index=right_index,
+                _meta=self._meta,
             )
 
         if shuffle_left_on:
@@ -186,7 +191,7 @@ def _lower(self):
             )
 
         # Blockwise merge
-        return BlockwiseMerge(left, right, **self.kwargs)
+        return BlockwiseMerge(left, right, **self.kwargs, _meta=self._meta)
 
     def _simplify_up(self, parent):
         if isinstance(parent, (Projection, Index)):
@@ -203,13 +208,20 @@ def _simplify_up(self, parent):
                 projection = [projection]
 
             left, right = self.left, self.right
-            left_on, right_on = self.left_on, self.right_on
+            if isinstance(self.left_on, list):
+                left_on = self.left_on
+            else:
+                left_on = [self.left_on] if self.left_on is not None else []
+            if isinstance(self.right_on, list):
+                right_on = self.right_on
+            else:
+                right_on = [self.right_on] if self.right_on is not None else []
             left_suffix, right_suffix = self.suffixes[0], self.suffixes[1]
             project_left, project_right = [], []
 
             # Find columns to project on the left
             for col in left.columns:
-                if left_on is not None and col in left_on or col in projection:
+                if col in left_on or col in projection:
                     project_left.append(col)
                 elif f"{col}{left_suffix}" in projection:
                     project_left.append(col)
@@ -220,7 +232,7 @@ def _simplify_up(self, parent):
 
             # Find columns to project on the right
             for col in right.columns:
-                if right_on is not None and col in right_on or col in projection:
+                if col in right_on or col in projection:
                     project_right.append(col)
                 elif f"{col}{right_suffix}" in projection:
                     project_right.append(col)
@@ -232,8 +244,13 @@ def _simplify_up(self, parent):
             if set(project_left) < set(left.columns) or set(project_right) < set(
                 right.columns
            ):
+                columns = left_on + right_on + projection
+                meta_cols = [col for col in self.columns if col in columns]
                 result = type(self)(
-                    left[project_left], right[project_right], *self.operands[2:]
+                    left[project_left],
+                    right[project_right],
+                    *self.operands[2:-1],
+                    _meta=self._meta[meta_cols],
                 )
                 if parent_columns is None:
                     return type(parent)(result)
@@ -252,6 +269,7 @@ class HashJoinP2P(Merge, PartitionsFiltered):
         "suffixes",
         "indicator",
         "_partitions",
+        "_meta",
     ]
     _defaults = {
         "how": "inner",
@@ -262,6 +280,7 @@ class HashJoinP2P(Merge, PartitionsFiltered):
         "suffixes": ("_x", "_y"),
         "indicator": False,
         "_partitions": None,
+        "_meta": None,
     }
 
     def _lower(self):
@@ -269,17 +288,7 @@ def _lower(self):
 
     @functools.cached_property
     def _meta(self):
-        left = self.left._meta.drop(columns=_HASH_COLUMN_NAME)
-        right = self.right._meta.drop(columns=_HASH_COLUMN_NAME)
-        return left.merge(
-            right,
-            left_on=self.left_on,
-            right_on=self.right_on,
-            indicator=self.indicator,
-            suffixes=self.suffixes,
-            left_index=self.left_index,
-            right_index=self.right_index,
-        )
+        return self.operand("_meta")
 
     def _layer(self) -> dict:
         dsk = {}
diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py
index 20b15514..7d5a4f6b 100644
--- a/dask_expr/io/parquet.py
+++ b/dask_expr/io/parquet.py
@@ -557,7 +557,7 @@ def _dataset_info(self):
 
         return dataset_info
 
-    @property
+    @cached_property
     def _meta(self):
         meta = self._dataset_info["meta"]
         if self._series:
diff --git a/dask_expr/tests/test_merge.py b/dask_expr/tests/test_merge.py
index 96365473..14f1445f 100644
--- a/dask_expr/tests/test_merge.py
+++ b/dask_expr/tests/test_merge.py
@@ -160,7 +160,7 @@ def test_merge_len():
     pdf2 = lib.DataFrame({"x": [1, 2, 3], "z": 1})
     df2 = from_pandas(pdf2, npartitions=2)
 
-    assert_eq(len(df.merge(df2)), len(pdf.merge(pdf2)))
+    # assert_eq(len(df.merge(df2)), len(pdf.merge(pdf2)))
     query = df.merge(df2).index.optimize(fuse=False)
     expected = df[["x"]].merge(df2[["x"]]).index.optimize(fuse=False)
     assert query._name == expected._name

From a1d7e07535c49da42dcad97df4d9707eef4d5759 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Wed, 30 Aug 2023 12:40:50 +0200
Subject: [PATCH 3/5] Update

---
 dask_expr/tests/test_merge.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dask_expr/tests/test_merge.py b/dask_expr/tests/test_merge.py
index 409cc024..dacb92b3 100644
--- a/dask_expr/tests/test_merge.py
+++ b/dask_expr/tests/test_merge.py
@@ -160,7 +160,7 @@ def test_merge_len():
     pdf2 = lib.DataFrame({"x": [1, 2, 3], "z": 1})
     df2 = from_pandas(pdf2, npartitions=2)
 
-    # assert_eq(len(df.merge(df2)), len(pdf.merge(pdf2)))
+    assert_eq(len(df.merge(df2)), len(pdf.merge(pdf2)))
     query = df.merge(df2).index.optimize(fuse=False)
     expected = df[["x"]].merge(df2[["x"]]).index.optimize(fuse=False)
     assert query._name == expected._name
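
Note on the two commits above: Merge._meta builds small non-empty dummy frames for both
sides and runs a real pandas merge on them, which is comparatively costly. Computing it
once on the outer Merge and handing the result to the lowered expressions (BlockwiseMerge,
HashJoinP2P) through the new "_meta" operand avoids repeating that work per lowered node;
the parquet _meta becomes a cached_property for the same reason. The third commit simply
restores the len() assertion that the second one had commented out. A simplified sketch of
the computation being cached (illustrative, not the actual dask_expr class; merge_meta is a
hypothetical helper name, the imports are dask's public meta utilities):

    from dask.dataframe.utils import make_meta, meta_nonempty

    def merge_meta(left_meta, right_meta, **kwargs):
        # Build non-empty dummies from the empty meta frames and merge them
        # with pandas; the result, emptied again by make_meta, describes the
        # output schema of the distributed merge.
        left = meta_nonempty(left_meta)
        right = meta_nonempty(right_meta)
        return make_meta(left.merge(right, **kwargs))
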
From ce155344a068fd9e73da30f20ede97ee50272e64 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Wed, 30 Aug 2023 15:12:54 +0200
Subject: [PATCH 4/5] Add custom constructor

---
 dask_expr/_merge.py | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py
index 69800ae6..af1dad04 100644
--- a/dask_expr/_merge.py
+++ b/dask_expr/_merge.py
@@ -37,7 +37,6 @@ class Merge(Expr):
         "suffixes",
         "indicator",
         "shuffle_backend",
-        "_meta",
     ]
     _defaults = {
         "how": "inner",
@@ -48,9 +47,12 @@ class Merge(Expr):
         "suffixes": ("_x", "_y"),
         "indicator": False,
         "shuffle_backend": None,
-        "_meta": None,
     }
 
+    def __init__(self, *args, _precomputed_meta=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._precomputed_meta = _precomputed_meta
+
     def __str__(self):
         return f"Merge({self._name[-7:]})"
 
@@ -71,8 +73,8 @@ def kwargs(self):
 
     @functools.cached_property
    def _meta(self):
-        if self.operand("_meta") is not None:
-            return self.operand("_meta")
+        if self._precomputed_meta is not None:
+            return self._precomputed_meta
         left = meta_nonempty(self.left._meta)
         right = meta_nonempty(self.right._meta)
         return make_meta(left.merge(right, **self.kwargs))
@@ -108,7 +110,9 @@ def _lower(self):
             or right.npartitions == 1
             and how in ("left", "inner")
         ):
-            return BlockwiseMerge(left, right, **self.kwargs, _meta=self._meta)
+            return BlockwiseMerge(
+                left, right, **self.kwargs, _precomputed_meta=self._meta
+            )
 
         # Check if we are merging on indices with known divisions
         merge_indexed_left = (
@@ -169,7 +173,7 @@ def _lower(self):
                 indicator=self.indicator,
                 left_index=left_index,
                 right_index=right_index,
-                _meta=self._meta,
+                _precomputed_meta=self._meta,
             )
 
         if shuffle_left_on:
@@ -191,7 +195,7 @@ def _lower(self):
             )
 
         # Blockwise merge
-        return BlockwiseMerge(left, right, **self.kwargs, _meta=self._meta)
+        return BlockwiseMerge(left, right, **self.kwargs, _precomputed_meta=self._meta)
 
     def _simplify_up(self, parent):
         if isinstance(parent, (Projection, Index)):
@@ -250,7 +254,7 @@ def _simplify_up(self, parent):
                     left[project_left],
                     right[project_right],
                     *self.operands[2:-1],
-                    _meta=self._meta[meta_cols],
+                    _precomputed_meta=self._meta[meta_cols],
                 )
                 if parent_columns is None:
                     return type(parent)(result)
@@ -288,7 +292,7 @@ def _lower(self):
 
     @functools.cached_property
     def _meta(self):
-        return self.operand("_meta")
+        return self._precomputed_meta
 
     def _layer(self) -> dict:
         dsk = {}
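
Note on the commit above: the cached meta moves out of the operand list into a
constructor-only keyword, so it no longer appears in _parameters and is simply carried on
the instance. A stripped-down sketch of the pattern (hypothetical classes, not the real
dask_expr Expr hierarchy):

    class Expr:
        def __init__(self, *operands):
            self.operands = list(operands)

    class Merge(Expr):
        def __init__(self, *operands, _precomputed_meta=None):
            super().__init__(*operands)
            # Not an operand: just a cached value handed in by whoever
            # already computed the merge schema.
            self._precomputed_meta = _precomputed_meta

        @property
        def _meta(self):
            if self._precomputed_meta is not None:
                return self._precomputed_meta
            raise NotImplementedError("compute meta from the operands here")

HashJoinP2P can then return self._precomputed_meta directly, since lowering always passes
along the value already computed on the parent Merge.
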
From aabc2ef41b63eacfc92599fc7cbba3971a0727cf Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Wed, 30 Aug 2023 16:20:44 +0200
Subject: [PATCH 5/5] Fix commit issue

---
 dask_expr/_merge.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py
index af1dad04..759b3c9e 100644
--- a/dask_expr/_merge.py
+++ b/dask_expr/_merge.py
@@ -253,7 +253,7 @@ def _simplify_up(self, parent):
                 result = type(self)(
                     left[project_left],
                     right[project_right],
-                    *self.operands[2:-1],
+                    *self.operands[2:],
                     _precomputed_meta=self._meta[meta_cols],
                 )
                 if parent_columns is None:
@@ -273,7 +273,6 @@ class HashJoinP2P(Merge, PartitionsFiltered):
         "suffixes",
         "indicator",
         "_partitions",
-        "_meta",
     ]
     _defaults = {
         "how": "inner",
@@ -284,7 +283,6 @@ class HashJoinP2P(Merge, PartitionsFiltered):
         "suffixes": ("_x", "_y"),
         "indicator": False,
         "_partitions": None,
-        "_meta": None,
     }
 
     def _lower(self):
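
Note on the final commit: once "_meta" is no longer an operand, *self.operands[2:-1] would
drop the last real operand (shuffle_backend) when the projected Merge is rebuilt, so the
slice is restored to *self.operands[2:], and the leftover "_meta" entries in HashJoinP2P's
_parameters and _defaults are removed as well; as the "Fix commit issue" subject suggests,
these appear to be remnants of the earlier operand-based approach.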