Add multi-partition Scan support to cuDF-Polars (#17494)
Adds multi-partition `Scan` support following the same design as #17441

Authors:
  - Richard (Rick) Zamora (https://github.com/rjzamora)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)

URL: #17494
rjzamora authored Dec 19, 2024
1 parent 989fac4 commit 253b0d8
Showing 3 changed files with 362 additions and 4 deletions.
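
For context, a minimal usage sketch of the feature this commit enables (the executor name, the `parquet_blocksize` option, and its 1 GiB default are taken from the diff below; the dataset path is hypothetical):

import polars as pl

q = pl.scan_parquet("dataset/")  # hypothetical multi-file parquet dataset
engine = pl.GPUEngine(
    raise_on_fail=True,
    executor="dask-experimental",
    # Target partition size in bytes: files larger than this are split into
    # multiple partitions, smaller files are fused together (default is 1 GiB).
    executor_options={"parquet_blocksize": 512 * 1024**2},
)
result = q.collect(engine=engine)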
3 changes: 2 additions & 1 deletion python/cudf_polars/cudf_polars/callback.py
@@ -231,7 +231,8 @@ def validate_config_options(config: dict) -> None:
    executor = config.get("executor", "pylibcudf")
    if executor == "dask-experimental":
        unsupported = config.get("executor_options", {}).keys() - {
-             "max_rows_per_partition"
+             "max_rows_per_partition",
+             "parquet_blocksize",
        }
    else:
        unsupported = config.get("executor_options", {}).keys()
283 changes: 280 additions & 3 deletions python/cudf_polars/cudf_polars/experimental/io.py
@@ -4,18 +4,24 @@

from __future__ import annotations

+ import enum
import math
- from typing import TYPE_CHECKING
+ import random
+ from enum import IntEnum
+ from typing import TYPE_CHECKING, Any

- from cudf_polars.dsl.ir import DataFrameScan, Union
+ import pylibcudf as plc
+
+ from cudf_polars.dsl.ir import IR, DataFrameScan, Scan, Union
from cudf_polars.experimental.base import PartitionInfo
from cudf_polars.experimental.dispatch import lower_ir_node

if TYPE_CHECKING:
    from collections.abc import MutableMapping

-     from cudf_polars.dsl.ir import IR
+     from cudf_polars.dsl.expr import NamedExpr
    from cudf_polars.experimental.dispatch import LowerIRTransformer
+     from cudf_polars.typing import Schema


@lower_ir_node.register(DataFrameScan)
Expand Down Expand Up @@ -47,3 +53,274 @@ def _(
        }

    return ir, {ir: PartitionInfo(count=1)}


class ScanPartitionFlavor(IntEnum):
    """Flavor of Scan partitioning."""

    SINGLE_FILE = enum.auto()  # 1:1 mapping between files and partitions
    SPLIT_FILES = enum.auto()  # Split each file into >1 partition
    FUSED_FILES = enum.auto()  # Fuse multiple files into each partition


class ScanPartitionPlan:
    """
    Scan partitioning plan.

    Notes
    -----
    The meaning of `factor` depends on the value of `flavor`:
    - SINGLE_FILE: `factor` must be `1`.
    - SPLIT_FILES: `factor` is the number of partitions per file.
    - FUSED_FILES: `factor` is the number of files per partition.
    """

    __slots__ = ("factor", "flavor")
    factor: int
    flavor: ScanPartitionFlavor

    def __init__(self, factor: int, flavor: ScanPartitionFlavor) -> None:
        if (
            flavor == ScanPartitionFlavor.SINGLE_FILE and factor != 1
        ):  # pragma: no cover
            raise ValueError(f"Expected factor == 1 for {flavor}, got: {factor}")
        self.factor = factor
        self.flavor = flavor

    @staticmethod
    def from_scan(ir: Scan) -> ScanPartitionPlan:
        """Extract the partitioning plan of a Scan operation."""
        if ir.typ == "parquet":
            # TODO: Use system info to set default blocksize
            parallel_options = ir.config_options.get("executor_options", {})
            blocksize: int = parallel_options.get("parquet_blocksize", 1024**3)
            stats = _sample_pq_statistics(ir)
            file_size = sum(float(stats[column]) for column in ir.schema)
            if file_size > 0:
                if file_size > blocksize:
                    # Split large files
                    return ScanPartitionPlan(
                        math.ceil(file_size / blocksize),
                        ScanPartitionFlavor.SPLIT_FILES,
                    )
                else:
                    # Fuse small files
                    return ScanPartitionPlan(
                        max(blocksize // int(file_size), 1),
                        ScanPartitionFlavor.FUSED_FILES,
                    )

        # TODO: Use file sizes for csv and json
        return ScanPartitionPlan(1, ScanPartitionFlavor.SINGLE_FILE)
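
A rough worked example of how `from_scan` picks a plan (illustrative sizes, not part of the commit):

# Assuming the default parquet_blocksize of 1 GiB:
#   blocksize = 1024**3
# Sampled per-file size ~2.5 GiB -> split each file into 3 partitions:
#   math.ceil(2.5 * 1024**3 / blocksize) == 3   -> SPLIT_FILES, factor=3
# Sampled per-file size ~100 MiB -> fuse 10 files per partition:
#   max(blocksize // int(100 * 1024**2), 1) == 10  -> FUSED_FILES, factor=10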


class SplitScan(IR):
    """
    Input from a split file.

    This class wraps a single-file `Scan` object. At
    IO/evaluation time, this class will only perform
    a partial read of the underlying file. The range
    (skip_rows and n_rows) is calculated at IO time.
    """

    __slots__ = (
        "base_scan",
        "schema",
        "split_index",
        "total_splits",
    )
    _non_child = (
        "schema",
        "base_scan",
        "split_index",
        "total_splits",
    )
    base_scan: Scan
    """Scan operation this node is based on."""
    split_index: int
    """Index of the current split."""
    total_splits: int
    """Total number of splits."""

    def __init__(
        self, schema: Schema, base_scan: Scan, split_index: int, total_splits: int
    ):
        self.schema = schema
        self.base_scan = base_scan
        self.split_index = split_index
        self.total_splits = total_splits
        self._non_child_args = (
            split_index,
            total_splits,
            *base_scan._non_child_args,
        )
        self.children = ()
        if base_scan.typ not in ("parquet",):  # pragma: no cover
            raise NotImplementedError(
                f"Unhandled Scan type for file splitting: {base_scan.typ}"
            )

    @classmethod
    def do_evaluate(
        cls,
        split_index: int,
        total_splits: int,
        schema: Schema,
        typ: str,
        reader_options: dict[str, Any],
        config_options: dict[str, Any],
        paths: list[str],
        with_columns: list[str] | None,
        skip_rows: int,
        n_rows: int,
        row_index: tuple[str, int] | None,
        predicate: NamedExpr | None,
    ):
        """Evaluate and return a dataframe."""
        if typ not in ("parquet",):  # pragma: no cover
            raise NotImplementedError(f"Unhandled Scan type for file splitting: {typ}")

        if len(paths) > 1:  # pragma: no cover
            raise ValueError(f"Expected a single path, got: {paths}")

        # Parquet logic:
        # - We are one of "total_splits" SplitScan nodes
        #   assigned to the same file.
        # - We know our index within this file ("split_index")
        # - We can also use parquet metadata to query the
        #   total number of rows in each row-group of the file.
        # - We can use all this information to calculate the
        #   "skip_rows" and "n_rows" options to use locally.

        rowgroup_metadata = plc.io.parquet_metadata.read_parquet_metadata(
            plc.io.SourceInfo(paths)
        ).rowgroup_metadata()
        total_row_groups = len(rowgroup_metadata)
        if total_splits <= total_row_groups:
            # We have enough row-groups in the file to align
            # all "total_splits" of our reads with row-group
            # boundaries. Calculate which row-groups to include
            # in the current read, and use metadata to translate
            # the row-group indices to "skip_rows" and "n_rows".
            rg_stride = total_row_groups // total_splits
            skip_rgs = rg_stride * split_index
            skip_rows = sum(rg["num_rows"] for rg in rowgroup_metadata[:skip_rgs])
            n_rows = sum(
                rg["num_rows"]
                for rg in rowgroup_metadata[skip_rgs : skip_rgs + rg_stride]
            )
        else:
            # There are not enough row-groups to align
            # all "total_splits" of our reads with row-group
            # boundaries. Use metadata to directly calculate
            # "skip_rows" and "n_rows" for the current read.
            total_rows = sum(rg["num_rows"] for rg in rowgroup_metadata)
            n_rows = total_rows // total_splits
            skip_rows = n_rows * split_index

        # Last split should always read to end of file
        if split_index == (total_splits - 1):
            n_rows = -1

        # Perform the partial read
        return Scan.do_evaluate(
            schema,
            typ,
            reader_options,
            config_options,
            paths,
            with_columns,
            skip_rows,
            n_rows,
            row_index,
            predicate,
        )
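
A concrete illustration of the row-group arithmetic above (hypothetical file layouts):

# File with 6 row-groups of 500 rows each, read by total_splits = 3 SplitScans:
#   rg_stride = 6 // 3 = 2 row-groups per split
#   split_index = 1 -> skip_rgs = 2, so skip_rows = 1000 and n_rows = 1000
# File with only 2 row-groups of 1500 rows each and total_splits = 3:
#   total_rows = 3000, n_rows = 3000 // 3 = 1000, skip_rows = 1000 * split_index
# In both cases the last split (split_index == total_splits - 1) sets n_rows = -1
# so it reads through the end of the file.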


def _sample_pq_statistics(ir: Scan) -> dict[str, float]:
    import numpy as np
    import pyarrow.dataset as pa_ds

    # Use average total_uncompressed_size of three files
    # TODO: Use plc.io.parquet_metadata.read_parquet_metadata
    n_sample = 3
    column_sizes = {}
    ds = pa_ds.dataset(random.sample(ir.paths, n_sample), format="parquet")
    for i, frag in enumerate(ds.get_fragments()):
        md = frag.metadata
        for rg in range(md.num_row_groups):
            row_group = md.row_group(rg)
            for col in range(row_group.num_columns):
                column = row_group.column(col)
                name = column.path_in_schema
                if name not in column_sizes:
                    column_sizes[name] = np.zeros(n_sample, dtype="int64")
                column_sizes[name][i] += column.total_uncompressed_size

    return {name: np.mean(sizes) for name, sizes in column_sizes.items()}
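
For intuition, a hypothetical return value and how `ScanPartitionPlan.from_scan` consumes it (numbers are made up):

# _sample_pq_statistics might return, for a dataset with columns x, y, z:
#   {"x": 8_000.0, "y": 7_500.0, "z": 8_000.0}
# from_scan then estimates the uncompressed size of a typical file as
#   file_size = 8_000.0 + 7_500.0 + 8_000.0 = 23_500.0 bytes
# and compares it against parquet_blocksize to choose SPLIT_FILES or FUSED_FILES.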


@lower_ir_node.register(Scan)
def _(
    ir: Scan, rec: LowerIRTransformer
) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
    partition_info: MutableMapping[IR, PartitionInfo]
    if ir.typ in ("csv", "parquet", "ndjson") and ir.n_rows == -1 and ir.skip_rows == 0:
        plan = ScanPartitionPlan.from_scan(ir)
        paths = list(ir.paths)
        if plan.flavor == ScanPartitionFlavor.SPLIT_FILES:
            # Disable chunked reader when splitting files
            config_options = ir.config_options.copy()
            config_options["parquet_options"] = config_options.get(
                "parquet_options", {}
            ).copy()
            config_options["parquet_options"]["chunked"] = False

            slices: list[SplitScan] = []
            for path in paths:
                base_scan = Scan(
                    ir.schema,
                    ir.typ,
                    ir.reader_options,
                    ir.cloud_options,
                    config_options,
                    [path],
                    ir.with_columns,
                    ir.skip_rows,
                    ir.n_rows,
                    ir.row_index,
                    ir.predicate,
                )
                slices.extend(
                    SplitScan(ir.schema, base_scan, sindex, plan.factor)
                    for sindex in range(plan.factor)
                )
            new_node = Union(ir.schema, None, *slices)
            partition_info = {slice: PartitionInfo(count=1) for slice in slices} | {
                new_node: PartitionInfo(count=len(slices))
            }
        else:
            groups: list[Scan] = [
                Scan(
                    ir.schema,
                    ir.typ,
                    ir.reader_options,
                    ir.cloud_options,
                    ir.config_options,
                    paths[i : i + plan.factor],
                    ir.with_columns,
                    ir.skip_rows,
                    ir.n_rows,
                    ir.row_index,
                    ir.predicate,
                )
                for i in range(0, len(paths), plan.factor)
            ]
            new_node = Union(ir.schema, None, *groups)
            partition_info = {group: PartitionInfo(count=1) for group in groups} | {
                new_node: PartitionInfo(count=len(groups))
            }
        return new_node, partition_info

    return ir, {ir: PartitionInfo(count=1)}  # pragma: no cover
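
An illustrative summary of the lowering outcomes above for a Scan over 4 files (hypothetical plans):

# - SPLIT_FILES, factor=3: each file yields 3 SplitScan nodes, so the Union
#   has 4 * 3 = 12 partitions.
# - FUSED_FILES, factor=2: paths are grouped two at a time, so the Union
#   has 2 partitions.
# - SINGLE_FILE (csv/ndjson for now): factor=1 falls into the grouping branch,
#   giving one partition per file (4 partitions).
# - Any other scan type, or a Scan with skip_rows/n_rows set, keeps a single
#   partition via the fallback at the end of the function.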
80 changes: 80 additions & 0 deletions python/cudf_polars/tests/experimental/test_scan.py
@@ -0,0 +1,80 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import pytest

import polars as pl

from cudf_polars import Translator
from cudf_polars.experimental.parallel import lower_ir_graph
from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.fixture(scope="module")
def df():
    return pl.DataFrame(
        {
            "x": range(3_000),
            "y": ["cat", "dog", "fish"] * 1_000,
            "z": [1.0, 2.0, 3.0, 4.0, 5.0] * 600,
        }
    )


def make_source(df, path, fmt, n_files=3):
    n_rows = len(df)
    stride = int(n_rows / n_files)
    for i in range(n_files):
        offset = stride * i
        part = df.slice(offset, stride)
        if fmt == "csv":
            part.write_csv(path / f"part.{i}.csv")
        elif fmt == "ndjson":
            part.write_ndjson(path / f"part.{i}.ndjson")
        else:
            part.write_parquet(
                path / f"part.{i}.parquet",
                row_group_size=int(stride / 2),
            )


@pytest.mark.parametrize(
    "fmt, scan_fn",
    [
        ("csv", pl.scan_csv),
        ("ndjson", pl.scan_ndjson),
        ("parquet", pl.scan_parquet),
    ],
)
def test_parallel_scan(tmp_path, df, fmt, scan_fn):
    make_source(df, tmp_path, fmt)
    q = scan_fn(tmp_path)
    engine = pl.GPUEngine(
        raise_on_fail=True,
        executor="dask-experimental",
    )
    assert_gpu_result_equal(q, engine=engine)


@pytest.mark.parametrize("blocksize", [1_000, 10_000, 1_000_000])
def test_parquet_blocksize(tmp_path, df, blocksize):
n_files = 3
make_source(df, tmp_path, "parquet", n_files)
q = pl.scan_parquet(tmp_path)
engine = pl.GPUEngine(
raise_on_fail=True,
executor="dask-experimental",
executor_options={"parquet_blocksize": blocksize},
)
assert_gpu_result_equal(q, engine=engine)

# Check partitioning
qir = Translator(q._ldf.visit(), engine).translate_ir()
ir, info = lower_ir_graph(qir)
count = info[ir].count
if blocksize <= 12_000:
assert count > n_files
else:
assert count < n_files
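
Rough arithmetic behind those assertions (approximate, assuming about 8 KB each for the int64 and float64 columns per 1,000-row file plus a few KB for the strings):

# Each of the 3 files holds 1,000 rows, so its uncompressed size is roughly
# 20-25 KB. With blocksize 1_000 or 10_000 that exceeds the blocksize and each
# file is split (count > n_files); with blocksize 1_000_000 several files fit
# in one partition and they are fused (count < n_files).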
