diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py
index 326f1f49..9d7f7c2b 100644
--- a/dask_expr/_collection.py
+++ b/dask_expr/_collection.py
@@ -5220,6 +5220,12 @@ def read_parquet(

     >>> dask.config.set({"dataframe.parquet.minimum-partition-size": "100MB"})

+    When ``filesystem="arrow"``, the optimizer will also enforce a maximum size
+    per partition (default 256MB) to avoid oversized partitions. This
+    configuration can be set with
+
+    >>> dask.config.set({"dataframe.parquet.maximum-partition-size": "512MB"})
+
     .. note::
         Specifying ``filesystem="arrow"`` leverages a complete reimplementation of
         the Parquet reader that is solely based on PyArrow. It is significantly faster
@@ -5434,11 +5440,13 @@ def read_parquet(
         )
         if blocksize is not None and blocksize != "default":
             raise NotImplementedError(
-                "blocksize is not supported when using the pyarrow filesystem."
+                "blocksize is not supported when using the pyarrow filesystem. "
+                "Please use the 'dataframe.parquet.maximum-partition-size' config."
             )
         if aggregate_files is not None:
             raise NotImplementedError(
-                "aggregate_files is not supported when using the pyarrow filesystem."
+                "aggregate_files is not supported when using the pyarrow filesystem. "
+                "Please use the 'dataframe.parquet.minimum-partition-size' config."
             )
         if parquet_file_extension != (".parq", ".parquet", ".pq"):
             raise NotImplementedError(
diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py
index a28b8762..e226af40 100644
--- a/dask_expr/io/io.py
+++ b/dask_expr/io/io.py
@@ -200,6 +200,92 @@ def _task(self, index: int):
         )


+class SplitParquetIO(PartitionsFiltered, BlockwiseIO):
+    _parameters = ["_expr", "_partitions"]
+    _defaults = {"_partitions": None}
+
+    @functools.cached_property
+    def _name(self):
+        return (
+            self.operand("_expr")._funcname
+            + "-split-"
+            + _tokenize_deterministic(*self.operands)
+        )
+
+    @functools.cached_property
+    def _meta(self):
+        return self.operand("_expr")._meta
+
+    def dependencies(self):
+        return []
+
+    @property
+    def npartitions(self):
+        if self._filtered:
+            return len(self._partitions)
+        return len(self._split_mapping)
+
+    def _divisions(self):
+        # TODO: Handle this?
+        return (None,) * (len(self._split_mapping) + 1)
+
+    @staticmethod
+    def _load_partial_fragment(
+        local_split_index,
+        local_split_count,
+        frag,
+        filter,
+        columns,
+        schema,
+        *to_pandas_args,
+    ):
+        from dask_expr.io.parquet import ReadParquetPyarrowFS
+
+        return ReadParquetPyarrowFS._table_to_pandas(
+            ReadParquetPyarrowFS._partial_fragment_to_table(
+                frag,
+                local_split_index,
+                local_split_count,
+                filter,
+                columns,
+                schema,
+            ),
+            *to_pandas_args,
+        )
+
+    def _filtered_task(self, index: int):
+        expr = self.operand("_expr")
+        original_index, local_split_index = self._split_mapping[index]
+        _, frag_to_table, *to_pandas_args = expr._task(original_index)
+        return (
+            self._load_partial_fragment,
+            local_split_index,
+            self._local_split_count,
+            frag_to_table[1],  # frag
+            frag_to_table[2],  # filter
+            frag_to_table[3],  # columns
+            frag_to_table[4],  # schema
+            *to_pandas_args,
+        )
+
+    @functools.cached_property
+    def _local_split_count(self):
+        return self.operand("_expr")._split_division_factor
+
+    @functools.cached_property
+    def _split_mapping(self):
+        count = 0
+        mapping = {}
+        for op in self.operand("_expr")._partitions:
+            for s in range(self._local_split_count):
+                mapping[count] = (op, s)  # original partition id, local split index
+                count += 1
+        return mapping
+
+    def _tune_up(self, parent):
+        return
+
+
 class FromMap(PartitionsFiltered, BlockwiseIO):
     _parameters = [
         "func",
diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py
index 60c7400c..69c6299a 100644
--- a/dask_expr/io/parquet.py
+++ b/dask_expr/io/parquet.py
@@ -2,6 +2,7 @@

 import contextlib
 import itertools
+import math
 import operator
 import os
 import pickle
@@ -60,7 +61,7 @@
 from dask_expr._reductions import Len
 from dask_expr._util import _convert_to_list, _tokenize_deterministic
 from dask_expr.io import BlockwiseIO, PartitionsFiltered
-from dask_expr.io.io import FusedParquetIO
+from dask_expr.io.io import FusedParquetIO, SplitParquetIO


 @normalize_token.register(pa.fs.FileInfo)
@@ -1037,6 +1038,7 @@ def _dataset_info(self):
             dataset_info["using_metadata_file"] = True
         dataset_info["fragments"] = _frags = list(dataset.get_fragments())
         dataset_info["file_sizes"] = [None for fi in _frags]
+        dataset_info["all_files"] = all_files

         if checksum is None:
             checksum = tokenize(all_files)
@@ -1094,11 +1096,13 @@ def _divisions(self):
         return self._division_from_stats[0]

     def _tune_up(self, parent):
-        if self._fusion_compression_factor >= 1:
+        if isinstance(parent, (FusedParquetIO, SplitParquetIO)):
             return
-        if isinstance(parent, FusedParquetIO):
-            return
-        return parent.substitute(self, FusedParquetIO(self))
+        if self._split_division_factor > 1:
+            return parent.substitute(self, SplitParquetIO(self))
+        if self._fusion_compression_factor < 1:
+            return parent.substitute(self, FusedParquetIO(self))
+        return

     @cached_property
     def fragments(self):
@@ -1150,6 +1154,24 @@ def _fusion_compression_factor(self):
             total_uncompressed = max(total_uncompressed, min_size)
         return max(after_projection / total_uncompressed, 0.001)

+    @property
+    def _split_division_factor(self) -> int:
+        approx_stats = self.approx_statistics()
+        after_projection = 0
+        col_op = self.operand("columns") or self.columns
+        for col in approx_stats["columns"]:
+            if col["path_in_schema"] in col_op:
+                after_projection += col["total_uncompressed_size"]
+
+        max_size = parse_bytes(
+            dask.config.get("dataframe.parquet.maximum-partition-size", "256 MB")
+        )
+        if after_projection <= max_size:
+            return 1
+
+        max_splits = max(math.floor(approx_stats["num_row_groups"]), 1)
+        return min(math.ceil(after_projection / max_size), max_splits)
+
     def _filtered_task(self, index: int):
         columns = self.columns.copy()
         index_name = self.index.name
@@ -1175,6 +1197,35 @@ def _filtered_task(self, index: int):
             self.pyarrow_strings_enabled,
         )

+    @classmethod
+    def _partial_fragment_to_table(
+        cls,
+        fragment_wrapper,
+        local_split_index,
+        local_split_count,
+        filters,
+        columns,
+        schema,
+    ):
+        if isinstance(fragment_wrapper, FragmentWrapper):
+            fragment = fragment_wrapper.fragment
+        else:
+            fragment = fragment_wrapper
+
+        num_row_groups = fragment.num_row_groups
+        stride = max(math.floor(num_row_groups / local_split_count), 1)
+        offset = local_split_index * stride
+        row_groups = list(range(offset, min(offset + stride, num_row_groups)))
+        assert row_groups  # TODO: Handle empty partition case
+        fragment = fragment.format.make_fragment(
+            fragment.path,
+            fragment.filesystem,
+            fragment.partition_expression,
+            row_groups=row_groups,
+        )
+
+        return cls._fragment_to_table(fragment, filters, columns, schema)
+
     @staticmethod
     def _fragment_to_table(fragment_wrapper, filters, columns, schema):
         _maybe_adjust_cpu_count()
diff --git a/dask_expr/io/tests/test_parquet.py b/dask_expr/io/tests/test_parquet.py
index 3d111d5d..e2c7e85d 100644
--- a/dask_expr/io/tests/test_parquet.py
+++ b/dask_expr/io/tests/test_parquet.py
@@ -119,7 +119,19 @@ def test_pyarrow_filesystem(parquet_file):
     df_pa = read_parquet(parquet_file, filesystem=filesystem)
     df = read_parquet(parquet_file)

-    assert assert_eq(df, df_pa)
+    assert_eq(df, df_pa)
+
+
+def test_pyarrow_filesystem_max_partition_size(tmpdir):
+    with dask.config.set({"dataframe.parquet.maximum-partition-size": 1}):
+        pdf = pd.DataFrame({c: range(10) for c in "abcde"})
+        fn = _make_file(tmpdir, df=pdf, engine="pyarrow", row_group_size=1)
+        df = read_parquet(fn, filesystem="pyarrow")
+
+        # Trigger "_tune_up" optimization
+        df = df.map_partitions(lambda x: x)
+        assert df.optimize().npartitions == len(pdf)
+        assert_eq(df, pdf, check_index=False)


 @pytest.mark.parametrize("dtype_backend", ["pyarrow", "numpy_nullable", None])
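
Usage sketch for the new ``dataframe.parquet.maximum-partition-size`` setting, mirroring test_pyarrow_filesystem_max_partition_size above; the dataset path "example.parquet" and the tiny example frame are hypothetical, not part of the patch.

import dask
import pandas as pd
from dask_expr import read_parquet

# Hypothetical dataset: one row group per row, written with pyarrow, so that
# an oversized partition can be split along row-group boundaries.
pdf = pd.DataFrame({c: range(10) for c in "abcde"})
pdf.to_parquet("example.parquet", engine="pyarrow", row_group_size=1)

# Cap the (uncompressed, post-projection) size of each output partition.
# Partitions above the cap are split by SplitParquetIO, roughly into
# min(ceil(size / cap), num_row_groups) pieces (see _split_division_factor).
with dask.config.set({"dataframe.parquet.maximum-partition-size": "512MB"}):
    df = read_parquet("example.parquet", filesystem="arrow")
    df = df.map_partitions(lambda x: x)  # triggers the "_tune_up" optimization
    print(df.optimize().npartitions)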