[WIP] Support file splitting in ReadParquetPyarrowFS #1139

Draft · wants to merge 10 commits into main
12 changes: 10 additions & 2 deletions dask_expr/_collection.py
@@ -5220,6 +5220,12 @@ def read_parquet(

>>> dask.config.set({"dataframe.parquet.minimum-partition-size": "100MB"})

When ``filesystem="arrow"``, the Optimizer will also use a maximum size
per partition (default 256MB) to avoid over-sized partitions. This
configuration can be set with

>>> dask.config.set({"dataframe.parquet.maximum-partition-size": "512MB"})

.. note::
Specifying ``filesystem="arrow"`` leverages a complete reimplementation of
the Parquet reader that is solely based on PyArrow. It is significantly faster
@@ -5434,11 +5440,13 @@ def read_parquet(
)
if blocksize is not None and blocksize != "default":
raise NotImplementedError(
"blocksize is not supported when using the pyarrow filesystem."
"blocksize is not supported when using the pyarrow filesystem. "
"Please use the 'dataframe.parquet.maximim-partition-size' config."
)
if aggregate_files is not None:
raise NotImplementedError(
"aggregate_files is not supported when using the pyarrow filesystem."
"aggregate_files is not supported when using the pyarrow filesystem. "
"Please use the 'dataframe.parquet.minimim-partition-size' config."
)
if parquet_file_extension != (".parq", ".parquet", ".pq"):
raise NotImplementedError(
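Not part of the diff: a minimal usage sketch of the new knob. The dataset path is hypothetical; dask.config.set, read_parquet, and filesystem="arrow" are existing APIs, the rest is illustrative.

import dask
from dask_expr import read_parquet

# Cap each output partition at roughly 512MB; fragments whose projected size
# exceeds the cap are split across several partitions during optimization.
with dask.config.set({"dataframe.parquet.maximum-partition-size": "512MB"}):
    df = read_parquet("s3://bucket/dataset/", filesystem="arrow")  # hypothetical path
    df = df.map_partitions(lambda part: part)  # a blockwise op lets _tune_up apply
    print(df.optimize().npartitions)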
86 changes: 86 additions & 0 deletions dask_expr/io/io.py
@@ -200,6 +200,92 @@ def _task(self, index: int):
)


class SplitParquetIO(PartitionsFiltered, BlockwiseIO):
_parameters = ["_expr", "_partitions"]
_defaults = {"_partitions": None}

@functools.cached_property
def _name(self):
return (
self.operand("_expr")._funcname
+ "-split-"
+ _tokenize_deterministic(*self.operands)
)

@functools.cached_property
def _meta(self):
return self.operand("_expr")._meta

def dependencies(self):
return []

@property
def npartitions(self):
if self._filtered:
return len(self._partitions)
return len(self._split_mapping)

def _divisions(self):
# TODO: Handle this?
return (None,) * (len(self._split_mapping) + 1)

@staticmethod
def _load_partial_fragment(
local_split_index,
local_split_count,
frag,
filter,
columns,
schema,
*to_pandas_args,
):
from dask_expr.io.parquet import ReadParquetPyarrowFS

return ReadParquetPyarrowFS._table_to_pandas(
ReadParquetPyarrowFS._partial_fragment_to_table(
frag,
local_split_index,
local_split_count,
filter,
columns,
schema,
),
*to_pandas_args,
)

def _filtered_task(self, index: int):
expr = self.operand("_expr")
original_index, local_split_index = self._split_mapping[index]
_, frag_to_table, *to_pandas_args = expr._task(original_index)
return (
self._load_partial_fragment,
local_split_index,
self._local_split_count,
frag_to_table[1], # frag
frag_to_table[2], # filter
frag_to_table[3], # columns
frag_to_table[4], # schema
*to_pandas_args,
)

@functools.cached_property
def _local_split_count(self):
return self.operand("_expr")._split_division_factor

@functools.cached_property
def _split_mapping(self):
count = 0
mapping = {}
for op in self.operand("_expr")._partitions:
for s in range(self._local_split_count):
mapping[count] = (op, s) # original partition id, local split index
count += 1
return mapping

def _tune_up(self, parent):
return


class FromMap(PartitionsFiltered, BlockwiseIO):
_parameters = [
"func",
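For reference (not part of the diff), a small self-contained sketch of how SplitParquetIO._split_mapping enumerates output partitions, assuming two original partitions and a split factor of three; the numbers are made up.

# Illustrative only: mirrors the loop in SplitParquetIO._split_mapping.
original_partitions = [0, 1]  # hypothetical partition ids of the wrapped expr
local_split_count = 3         # hypothetical _split_division_factor

mapping, count = {}, 0
for op in original_partitions:
    for s in range(local_split_count):
        mapping[count] = (op, s)  # (original partition id, local split index)
        count += 1

assert mapping == {
    0: (0, 0), 1: (0, 1), 2: (0, 2),
    3: (1, 0), 4: (1, 1), 5: (1, 2),
}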
61 changes: 56 additions & 5 deletions dask_expr/io/parquet.py
@@ -2,6 +2,7 @@

import contextlib
import itertools
import math
import operator
import os
import pickle
@@ -60,7 +61,7 @@
from dask_expr._reductions import Len
from dask_expr._util import _convert_to_list, _tokenize_deterministic
from dask_expr.io import BlockwiseIO, PartitionsFiltered
from dask_expr.io.io import FusedParquetIO
from dask_expr.io.io import FusedParquetIO, SplitParquetIO


@normalize_token.register(pa.fs.FileInfo)
@@ -1037,6 +1038,7 @@ def _dataset_info(self):
dataset_info["using_metadata_file"] = True
dataset_info["fragments"] = _frags = list(dataset.get_fragments())
dataset_info["file_sizes"] = [None for fi in _frags]
dataset_info["all_files"] = all_files

if checksum is None:
checksum = tokenize(all_files)
@@ -1094,11 +1096,13 @@ def _divisions(self):
return self._division_from_stats[0]

def _tune_up(self, parent):
if self._fusion_compression_factor >= 1:
if isinstance(parent, (FusedParquetIO, SplitParquetIO)):
return
if isinstance(parent, FusedParquetIO):
return
return parent.substitute(self, FusedParquetIO(self))
if self._split_division_factor > 1:
return parent.substitute(self, SplitParquetIO(self))
if self._fusion_compression_factor < 1:
return parent.substitute(self, FusedParquetIO(self))
return

@cached_property
def fragments(self):
@@ -1150,6 +1154,24 @@ def _fusion_compression_factor(self):
total_uncompressed = max(total_uncompressed, min_size)
return max(after_projection / total_uncompressed, 0.001)

@property
def _split_division_factor(self) -> int:
approx_stats = self.approx_statistics()
after_projection = 0
col_op = self.operand("columns") or self.columns
for col in approx_stats["columns"]:
if col["path_in_schema"] in col_op:
after_projection += col["total_uncompressed_size"]

max_size = parse_bytes(
dask.config.get("dataframe.parquet.maximum-partition-size", "256 MB")
)
if after_projection <= max_size:
return 1

max_splits = max(math.floor(approx_stats["num_row_groups"]), 1)
return min(math.ceil(after_projection / max_size), max_splits)

def _filtered_task(self, index: int):
columns = self.columns.copy()
index_name = self.index.name
@@ -1175,6 +1197,35 @@ def _filtered_task(self, index: int):
self.pyarrow_strings_enabled,
)

@classmethod
def _partial_fragment_to_table(
cls,
fragment_wrapper,
local_split_index,
local_split_count,
filters,
columns,
schema,
):
if isinstance(fragment_wrapper, FragmentWrapper):
fragment = fragment_wrapper.fragment
else:
fragment = fragment_wrapper

num_row_groups = fragment.num_row_groups
stride = max(math.floor(num_row_groups / local_split_count), 1)
offset = local_split_index * stride
row_groups = list(range(offset, min(offset + stride, num_row_groups)))
assert row_groups # TODO: Handle empty partition case
fragment = fragment.format.make_fragment(
fragment.path,
fragment.filesystem,
fragment.partition_expression,
row_groups=row_groups,
)

return cls._fragment_to_table(fragment, filters, columns, schema)

@staticmethod
def _fragment_to_table(fragment_wrapper, filters, columns, schema):
_maybe_adjust_cpu_count()
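To make the sizing arithmetic concrete, a worked sketch (not part of the diff) with made-up numbers: 1 GB of projected uncompressed data, the 256 MB default cap (parse_bytes("256 MB") == 256_000_000), and 8 row groups, mirroring _split_division_factor and the row-group slicing in _partial_fragment_to_table.

import math

after_projection = 1_000_000_000  # hypothetical projected uncompressed bytes
max_size = 256_000_000            # parse_bytes("256 MB")
num_row_groups = 8                # hypothetical row groups in the fragment

# Number of output partitions for this fragment, capped by the row-group count.
splits = min(math.ceil(after_projection / max_size), max(num_row_groups, 1))  # -> 4

# Each split reads a contiguous slice of row groups.
stride = max(math.floor(num_row_groups / splits), 1)  # -> 2
for local_split_index in range(splits):
    offset = local_split_index * stride
    print(local_split_index, list(range(offset, min(offset + stride, num_row_groups))))
# 0 [0, 1]
# 1 [2, 3]
# 2 [4, 5]
# 3 [6, 7]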
14 changes: 13 additions & 1 deletion dask_expr/io/tests/test_parquet.py
@@ -119,7 +119,19 @@ def test_pyarrow_filesystem(parquet_file):

df_pa = read_parquet(parquet_file, filesystem=filesystem)
df = read_parquet(parquet_file)
assert assert_eq(df, df_pa)
assert_eq(df, df_pa)


def test_pyarrow_filesystem_max_partition_size(tmpdir):
with dask.config.set({"dataframe.parquet.maximum-partition-size": 1}):
pdf = pd.DataFrame({c: range(10) for c in "abcde"})
fn = _make_file(tmpdir, df=pdf, engine="pyarrow", row_group_size=1)
df = read_parquet(fn, filesystem="pyarrow")

# Trigger "_tune_up" optimization
df = df.map_partitions(lambda x: x)
assert df.optimize().npartitions == len(pdf)
assert_eq(df, pdf, check_index=False)


@pytest.mark.parametrize("dtype_backend", ["pyarrow", "numpy_nullable", None])