diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py
index 1782c903..360320fc 100644
--- a/dask_expr/_merge.py
+++ b/dask_expr/_merge.py
@@ -1,4 +1,5 @@
 import functools
+import itertools
 import math
 import operator
 
@@ -352,6 +353,46 @@ def _lower(self):
         return BlockwiseMerge(left, right, **self.kwargs)
 
     def _simplify_up(self, parent, dependents):
+        if isinstance(parent, Merge):
+            # TODO: only inner/inner is handled; other rewrites still need figuring out
+            if parent.how == self.how == "inner":
+                parent = (
+                    (parent.left_on, parent.left, 0)
+                    if parent.right is self
+                    else (parent.right_on, parent.right, 0)
+                )
+                all_frames = [
+                    parent,
+                    (self.left_on, self.left, 1),
+                    (self.right_on, self.right, 2),
+                ]
+
+                def cost(join_tuple):
+                    # This should be a more general cardinality estimate of the
+                    # resulting join, but for now just rank pairs by npartitions
+                    (_, left, _), (_, right, _) = join_tuple
+                    return (right.npartitions, left.npartitions)
+
+                first_merge = min(itertools.permutations(all_frames, 2), key=cost)
+                picked = {first_merge[0][2], first_merge[1][2]}
+                last_merge = next(
+                    (f for f in all_frames if f[2] not in picked), first_merge[0]
+                )
+                new_right = Merge(
+                    first_merge[1][1],
+                    first_merge[0][1],
+                    how=self.how,
+                    left_on=first_merge[1][0],
+                    right_on=first_merge[0][0],
+                    # FIXME: we lose all other merge params here
+                )
+                return Merge(
+                    last_merge[1],
+                    new_right,
+                    how=self.how,
+                    left_on=last_merge[0],
+                    right_on=first_merge[1][0],
+                )
         if isinstance(parent, (Projection, Index)):
             # Reorder the column projection to
             # occur before the Merge
diff --git a/dask_expr/tests/test_merge.py b/dask_expr/tests/test_merge.py
index 1e99449a..ddba9db1 100644
--- a/dask_expr/tests/test_merge.py
+++ b/dask_expr/tests/test_merge.py
@@ -643,3 +643,24 @@ def test_pairwise_merge_results_in_identical_output_df(
 
     # recursive join doesn't yet respect divisions in dask-expr
     assert_eq(ddf_pairwise, ddf_loop)
+
+
+def test_join_reorder():
+    pdf1 = pd.DataFrame({"x": range(100), "a": range(100)})
+    df1 = from_pandas(pdf1, 10)
+    pdf2 = pd.DataFrame({"x": range(50), "c": range(50)})
+    df2 = from_pandas(pdf2, 4)
+    pdf3 = pd.DataFrame({"x": range(40, 60), "b": range(20)})
+    df3 = from_pandas(pdf3, 2)
+
+    expected_pdf = pdf1.merge(pdf2).merge(pdf3)
+    expected_pdf2 = pdf3.merge(pdf2).merge(pdf1)
+    actual = df1.merge(df2).merge(df3)
+    expected = df3.merge(df2).merge(df1)
+
+    assert actual.simplify()._name == expected.simplify()._name
+    cols = expected_pdf.columns
+    assert_eq(expected_pdf2[cols], expected_pdf)
+    # FIXME: Col order is optimized away. Therefore compute and reorder the columns
+    assert_eq(actual.compute()[cols], expected_pdf, check_index=False)
+    assert_eq(expected.compute()[cols], expected_pdf, check_index=False)