diff --git a/.asf.yaml b/.asf.yaml index bd063d4bbf4a..366c719597aa 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -38,7 +38,7 @@ github: features: issues: true protected_branches: - master: + main: required_status_checks: # require branches to be up-to-date before merging strict: true diff --git a/.github/actions/setup-builder/action.yaml b/.github/actions/setup-builder/action.yaml index aa1d1d9c14da..20da777ec0e5 100644 --- a/.github/actions/setup-builder/action.yaml +++ b/.github/actions/setup-builder/action.yaml @@ -30,7 +30,7 @@ runs: using: "composite" steps: - name: Cache Cargo - uses: actions/cache@v3 + uses: actions/cache@v4 with: # these represent dependencies downloaded by cargo # and thus do not depend on the OS, arch nor rust version. diff --git a/.github/dependabot.yml b/.github/dependabot.yml index ffde5378da93..b22c01f8a1b9 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -5,14 +5,14 @@ updates: schedule: interval: daily open-pull-requests-limit: 10 - target-branch: master + target-branch: main labels: [ auto-dependencies, arrow ] - package-ecosystem: cargo directory: "/object_store" schedule: interval: daily open-pull-requests-limit: 10 - target-branch: master + target-branch: main labels: [ auto-dependencies, object_store ] - package-ecosystem: "github-actions" directory: "/" diff --git a/.github/workflows/README.md b/.github/workflows/README.md index 679ccc956a20..08bdf123f4d6 100644 --- a/.github/workflows/README.md +++ b/.github/workflows/README.md @@ -20,6 +20,6 @@ The CI is structured so most tests are run in specific workflows: `arrow.yml` for `arrow`, `parquet.yml` for `parquet` and so on. -The basic idea is to run all tests on pushes to master (to ensure we -keep master green) but run only the individual workflows on PRs that +The basic idea is to run all tests on pushes to main (to ensure we +keep main green) but run only the individual workflows on PRs that change files that could affect them. 
diff --git a/.github/workflows/arrow.yml b/.github/workflows/arrow.yml index d3b2526740fa..daf38f2523fc 100644 --- a/.github/workflows/arrow.yml +++ b/.github/workflows/arrow.yml @@ -26,7 +26,7 @@ on: # always trigger push: branches: - - master + - main pull_request: paths: - .github/** @@ -61,39 +61,39 @@ jobs: submodules: true - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - - name: Test arrow-buffer with all features + - name: Test arrow-buffer run: cargo test -p arrow-buffer --all-features - - name: Test arrow-data with all features + - name: Test arrow-data run: cargo test -p arrow-data --all-features - - name: Test arrow-schema with all features + - name: Test arrow-schema run: cargo test -p arrow-schema --all-features - - name: Test arrow-array with all features + - name: Test arrow-array run: cargo test -p arrow-array --all-features - - name: Test arrow-select with all features + - name: Test arrow-select run: cargo test -p arrow-select --all-features - - name: Test arrow-cast with all features + - name: Test arrow-cast run: cargo test -p arrow-cast --all-features - - name: Test arrow-ipc with all features + - name: Test arrow-ipc run: cargo test -p arrow-ipc --all-features - - name: Test arrow-csv with all features + - name: Test arrow-csv run: cargo test -p arrow-csv --all-features - - name: Test arrow-json with all features + - name: Test arrow-json run: cargo test -p arrow-json --all-features - - name: Test arrow-avro with all features + - name: Test arrow-avro run: cargo test -p arrow-avro --all-features - - name: Test arrow-string with all features + - name: Test arrow-string run: cargo test -p arrow-string --all-features - - name: Test arrow-ord with all features + - name: Test arrow-ord run: cargo test -p arrow-ord --all-features - - name: Test arrow-arith with all features + - name: Test arrow-arith run: cargo test -p arrow-arith --all-features - - name: Test arrow-row with all features + - name: Test arrow-row run: cargo test 
-p arrow-row --all-features - - name: Test arrow-integration-test with all features + - name: Test arrow-integration-test run: cargo test -p arrow-integration-test --all-features - name: Test arrow with default features run: cargo test -p arrow - - name: Test arrow with all features except pyarrow + - name: Test arrow except pyarrow run: cargo test -p arrow --features=force_validate,prettyprint,ipc_compression,ffi,chrono-tz - name: Run examples run: | @@ -163,37 +163,139 @@ jobs: uses: ./.github/actions/setup-builder - name: Setup Clippy run: rustup component add clippy - - name: Clippy arrow-buffer with all features - run: cargo clippy -p arrow-buffer --all-targets --all-features -- -D warnings - - name: Clippy arrow-data with all features - run: cargo clippy -p arrow-data --all-targets --all-features -- -D warnings - - name: Clippy arrow-schema with all features - run: cargo clippy -p arrow-schema --all-targets --all-features -- -D warnings - - name: Clippy arrow-array with all features - run: cargo clippy -p arrow-array --all-targets --all-features -- -D warnings - - name: Clippy arrow-select with all features - run: cargo clippy -p arrow-select --all-targets --all-features -- -D warnings - - name: Clippy arrow-cast with all features - run: cargo clippy -p arrow-cast --all-targets --all-features -- -D warnings - - name: Clippy arrow-ipc with all features - run: cargo clippy -p arrow-ipc --all-targets --all-features -- -D warnings - - name: Clippy arrow-csv with all features - run: cargo clippy -p arrow-csv --all-targets --all-features -- -D warnings - - name: Clippy arrow-json with all features - run: cargo clippy -p arrow-json --all-targets --all-features -- -D warnings - - name: Clippy arrow-avro with all features - run: cargo clippy -p arrow-avro --all-targets --all-features -- -D warnings - - name: Clippy arrow-string with all features - run: cargo clippy -p arrow-string --all-targets --all-features -- -D warnings - - name: Clippy arrow-ord with all features 
- run: cargo clippy -p arrow-ord --all-targets --all-features -- -D warnings - - name: Clippy arrow-arith with all features - run: cargo clippy -p arrow-arith --all-targets --all-features -- -D warnings - - name: Clippy arrow-row with all features - run: cargo clippy -p arrow-row --all-targets --all-features -- -D warnings - - name: Clippy arrow with all features - run: cargo clippy -p arrow --all-features --all-targets -- -D warnings - - name: Clippy arrow-integration-test with all features - run: cargo clippy -p arrow-integration-test --all-targets --all-features -- -D warnings - - name: Clippy arrow-integration-testing with all features - run: cargo clippy -p arrow-integration-testing --all-targets --all-features -- -D warnings + - name: Clippy arrow-buffer + run: | + mod=arrow-buffer + cargo clippy -p "$mod" --all-targets --all-features -- -D warnings + # Dependency checks excluding tests & benches. + cargo clippy -p "$mod" -- -D unused_crate_dependencies + cargo clippy -p "$mod" --all-features -- -D unused_crate_dependencies + cargo clippy -p "$mod" --no-default-features -- -D unused_crate_dependencies + - name: Clippy arrow-data + run: | + mod=arrow-data + cargo clippy -p "$mod" --all-targets --all-features -- -D warnings + # Dependency checks excluding tests & benches. + cargo clippy -p "$mod" -- -D unused_crate_dependencies + cargo clippy -p "$mod" --all-features -- -D unused_crate_dependencies + cargo clippy -p "$mod" --no-default-features -- -D unused_crate_dependencies + - name: Clippy arrow-schema + run: | + mod=arrow-schema + cargo clippy -p "$mod" --all-targets --all-features -- -D warnings + # Dependency checks excluding tests & benches. 
+ cargo clippy -p "$mod" -- -D unused_crate_dependencies + cargo clippy -p "$mod" --all-features -- -D unused_crate_dependencies + cargo clippy -p "$mod" --no-default-features -- -D unused_crate_dependencies + - name: Clippy arrow-array + run: | + mod=arrow-array + cargo clippy -p "$mod" --all-targets --all-features -- -D warnings + # Dependency checks excluding tests & benches. + cargo clippy -p "$mod" -- -D unused_crate_dependencies + cargo clippy -p "$mod" --all-features -- -D unused_crate_dependencies + cargo clippy -p "$mod" --no-default-features -- -D unused_crate_dependencies + - name: Clippy arrow-select + run: | + mod=arrow-select + cargo clippy -p "$mod" --all-targets --all-features -- -D warnings + # Dependency checks excluding tests & benches. + cargo clippy -p "$mod" -- -D unused_crate_dependencies + cargo clippy -p "$mod" --all-features -- -D unused_crate_dependencies + cargo clippy -p "$mod" --no-default-features -- -D unused_crate_dependencies + - name: Clippy arrow-cast + run: | + mod=arrow-cast + cargo clippy -p "$mod" --all-targets --all-features -- -D warnings + # Dependency checks excluding tests & benches. + cargo clippy -p "$mod" -- -D unused_crate_dependencies + cargo clippy -p "$mod" --all-features -- -D unused_crate_dependencies + cargo clippy -p "$mod" --no-default-features -- -D unused_crate_dependencies + - name: Clippy arrow-ipc + run: | + mod=arrow-ipc + cargo clippy -p "$mod" --all-targets --all-features -- -D warnings + # Dependency checks excluding tests & benches. + cargo clippy -p "$mod" -- -D unused_crate_dependencies + cargo clippy -p "$mod" --all-features -- -D unused_crate_dependencies + cargo clippy -p "$mod" --no-default-features -- -D unused_crate_dependencies + - name: Clippy arrow-csv + run: | + mod=arrow-csv + cargo clippy -p "$mod" --all-targets --all-features -- -D warnings + # Dependency checks excluding tests & benches. 
+ cargo clippy -p "$mod" -- -D unused_crate_dependencies + cargo clippy -p "$mod" --all-features -- -D unused_crate_dependencies + cargo clippy -p "$mod" --no-default-features -- -D unused_crate_dependencies + - name: Clippy arrow-json + run: | + mod=arrow-json + cargo clippy -p "$mod" --all-targets --all-features -- -D warnings + # Dependency checks excluding tests & benches. + cargo clippy -p "$mod" -- -D unused_crate_dependencies + cargo clippy -p "$mod" --all-features -- -D unused_crate_dependencies + cargo clippy -p "$mod" --no-default-features -- -D unused_crate_dependencies + - name: Clippy arrow-avro + run: | + mod=arrow-avro + cargo clippy -p "$mod" --all-targets --all-features -- -D warnings + # Dependency checks excluding tests & benches. + cargo clippy -p "$mod" -- -D unused_crate_dependencies + cargo clippy -p "$mod" --all-features -- -D unused_crate_dependencies + cargo clippy -p "$mod" --no-default-features -- -D unused_crate_dependencies + - name: Clippy arrow-string + run: | + mod=arrow-string + cargo clippy -p "$mod" --all-targets --all-features -- -D warnings + # Dependency checks excluding tests & benches. + cargo clippy -p "$mod" -- -D unused_crate_dependencies + cargo clippy -p "$mod" --all-features -- -D unused_crate_dependencies + cargo clippy -p "$mod" --no-default-features -- -D unused_crate_dependencies + - name: Clippy arrow-ord + run: | + mod=arrow-ord + cargo clippy -p "$mod" --all-targets --all-features -- -D warnings + # Dependency checks excluding tests & benches. + cargo clippy -p "$mod" -- -D unused_crate_dependencies + cargo clippy -p "$mod" --all-features -- -D unused_crate_dependencies + cargo clippy -p "$mod" --no-default-features -- -D unused_crate_dependencies + - name: Clippy arrow-arith + run: | + mod=arrow-arith + cargo clippy -p "$mod" --all-targets --all-features -- -D warnings + # Dependency checks excluding tests & benches. 
+ cargo clippy -p "$mod" -- -D unused_crate_dependencies + cargo clippy -p "$mod" --all-features -- -D unused_crate_dependencies + cargo clippy -p "$mod" --no-default-features -- -D unused_crate_dependencies + - name: Clippy arrow-row + run: | + mod=arrow-row + cargo clippy -p "$mod" --all-targets --all-features -- -D warnings + # Dependency checks excluding tests & benches. + cargo clippy -p "$mod" -- -D unused_crate_dependencies + cargo clippy -p "$mod" --all-features -- -D unused_crate_dependencies + cargo clippy -p "$mod" --no-default-features -- -D unused_crate_dependencies + - name: Clippy arrow + run: | + mod=arrow + cargo clippy -p "$mod" --all-targets --all-features -- -D warnings + # Dependency checks excluding tests & benches. + cargo clippy -p "$mod" -- -D unused_crate_dependencies + cargo clippy -p "$mod" --all-features -- -D unused_crate_dependencies + cargo clippy -p "$mod" --no-default-features -- -D unused_crate_dependencies + - name: Clippy arrow-integration-test + run: | + mod=arrow-integration-test + cargo clippy -p "$mod" --all-targets --all-features -- -D warnings + # Dependency checks excluding tests & benches. + cargo clippy -p "$mod" -- -D unused_crate_dependencies + cargo clippy -p "$mod" --all-features -- -D unused_crate_dependencies + cargo clippy -p "$mod" --no-default-features -- -D unused_crate_dependencies + - name: Clippy arrow-integration-testing + run: | + mod=arrow-integration-testing + cargo clippy -p "$mod" --all-targets --all-features -- -D warnings + # Dependency checks excluding tests & benches. 
+ cargo clippy -p "$mod" -- -D unused_crate_dependencies + cargo clippy -p "$mod" --all-features -- -D unused_crate_dependencies + cargo clippy -p "$mod" --no-default-features -- -D unused_crate_dependencies diff --git a/.github/workflows/arrow_flight.yml b/.github/workflows/arrow_flight.yml index 242e0f2a3b0d..79627448ca40 100644 --- a/.github/workflows/arrow_flight.yml +++ b/.github/workflows/arrow_flight.yml @@ -23,11 +23,11 @@ concurrency: group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} cancel-in-progress: true -# trigger for all PRs that touch certain files and changes to master +# trigger for all PRs that touch certain files and changes to main on: push: branches: - - master + - main pull_request: paths: - arrow-array/** diff --git a/.github/workflows/audit.yml b/.github/workflows/audit.yml index 2c1dcdfd2100..e6254ea24a58 100644 --- a/.github/workflows/audit.yml +++ b/.github/workflows/audit.yml @@ -21,11 +21,11 @@ concurrency: group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} cancel-in-progress: true -# trigger for all PRs that touch certain files and changes to master +# trigger for all PRs that touch certain files and changes to main on: push: branches: - - master + - main pull_request: paths: - '**/Cargo.toml' diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index 2026e257ab29..b28e8c20cfe7 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -21,11 +21,11 @@ concurrency: group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} cancel-in-progress: true -# trigger for all PRs and changes to master +# trigger for all PRs and changes to main on: push: branches: - - master + - main pull_request: env: diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 08d287bcceb2..d6ec0622f6ed 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -21,11 +21,11 @@ 
concurrency: group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} cancel-in-progress: true -# trigger for all PRs and changes to master +# trigger for all PRs and changes to main on: push: branches: - - master + - main pull_request: jobs: @@ -70,8 +70,8 @@ jobs: path: target/doc deploy: - # Only deploy if a push to master - if: github.ref_name == 'master' && github.event_name == 'push' + # Only deploy if a push to main + if: github.ref_name == 'main' && github.event_name == 'push' needs: docs permissions: contents: write @@ -90,7 +90,7 @@ jobs: cp .asf.yaml ./website/build/.asf.yaml - name: Deploy to gh-pages uses: peaceiris/actions-gh-pages@v4.0.0 - if: github.event_name == 'push' && github.ref_name == 'master' + if: github.event_name == 'push' && github.ref_name == 'main' with: github_token: ${{ secrets.GITHUB_TOKEN }} publish_dir: website/build diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 868729a168e8..9b23b1b5ad2e 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -21,11 +21,11 @@ concurrency: group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} cancel-in-progress: true -# trigger for all PRs that touch certain files and changes to master +# trigger for all PRs that touch certain files and changes to main on: push: branches: - - master + - main pull_request: paths: - .github/** diff --git a/.github/workflows/miri.yaml b/.github/workflows/miri.yaml index 19b432121b6f..ce67546a104b 100644 --- a/.github/workflows/miri.yaml +++ b/.github/workflows/miri.yaml @@ -21,11 +21,11 @@ concurrency: group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} cancel-in-progress: true -# trigger for all PRs that touch certain files and changes to master +# trigger for all PRs that touch certain files and changes to main on: push: branches: - - master + - main pull_request: paths: - 
.github/** diff --git a/.github/workflows/object_store.yml b/.github/workflows/object_store.yml index 1857b330326a..899318f01324 100644 --- a/.github/workflows/object_store.yml +++ b/.github/workflows/object_store.yml @@ -23,11 +23,11 @@ concurrency: group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} cancel-in-progress: true -# trigger for all PRs that touch certain files and changes to master +# trigger for all PRs that touch certain files and changes to main on: push: branches: - - master + - main pull_request: paths: - object_store/** @@ -54,6 +54,10 @@ jobs: # targets. - name: Run clippy with default features run: cargo clippy -- -D warnings + - name: Run clippy without default features + run: cargo clippy --no-default-features -- -D warnings + - name: Run clippy with fs features + run: cargo clippy --no-default-features --features fs -- -D warnings - name: Run clippy with aws feature run: cargo clippy --features aws -- -D warnings - name: Run clippy with gcp feature @@ -138,9 +142,10 @@ jobs: - name: Setup LocalStack (AWS emulation) run: | - echo "LOCALSTACK_CONTAINER=$(docker run -d -p 4566:4566 localstack/localstack:3.8.1)" >> $GITHUB_ENV + echo "LOCALSTACK_CONTAINER=$(docker run -d -p 4566:4566 localstack/localstack:4.0.3)" >> $GITHUB_ENV echo "EC2_METADATA_CONTAINER=$(docker run -d -p 1338:1338 amazon/amazon-ec2-metadata-mock:v1.9.2 --imdsv2)" >> $GITHUB_ENV aws --endpoint-url=http://localhost:4566 s3 mb s3://test-bucket + aws --endpoint-url=http://localhost:4566 s3api create-bucket --bucket test-object-lock --object-lock-enabled-for-bucket aws --endpoint-url=http://localhost:4566 dynamodb create-table --table-name test-table --key-schema AttributeName=path,KeyType=HASH AttributeName=etag,KeyType=RANGE --attribute-definitions AttributeName=path,AttributeType=S AttributeName=etag,AttributeType=S --provisioned-throughput ReadCapacityUnits=5,WriteCapacityUnits=5 KMS_KEY=$(aws --endpoint-url=http://localhost:4566 kms 
create-key --description "test key") @@ -164,7 +169,7 @@ jobs: - name: Run object_store tests (AWS native conditional put) run: cargo test --features=aws env: - AWS_CONDITIONAL_PUT: etag-put-if-not-exists + AWS_CONDITIONAL_PUT: etag AWS_COPY_IF_NOT_EXISTS: multipart - name: GCS Output diff --git a/.github/workflows/parquet.yml b/.github/workflows/parquet.yml index a4e654892662..2269950fd235 100644 --- a/.github/workflows/parquet.yml +++ b/.github/workflows/parquet.yml @@ -23,11 +23,11 @@ concurrency: group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} cancel-in-progress: true -# trigger for all PRs that touch certain files and changes to master +# trigger for all PRs that touch certain files and changes to main on: push: branches: - - master + - main pull_request: paths: - arrow/** diff --git a/.github/workflows/parquet_derive.yml b/.github/workflows/parquet_derive.yml index d8b02f73a8aa..17aec724a820 100644 --- a/.github/workflows/parquet_derive.yml +++ b/.github/workflows/parquet_derive.yml @@ -23,11 +23,11 @@ concurrency: group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} cancel-in-progress: true -# trigger for all PRs that touch certain files and changes to master +# trigger for all PRs that touch certain files and changes to main on: push: branches: - - master + - main pull_request: paths: - parquet/** diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 1b65c5057de1..044250b70435 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -22,11 +22,11 @@ concurrency: group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} cancel-in-progress: true -# trigger for all PRs and changes to master +# trigger for all PRs and changes to main on: push: branches: - - master + - main pull_request: jobs: @@ -101,18 +101,19 @@ jobs: - name: Format arrow run: cargo fmt --all -- --check - name: Format parquet - # Many 
modules in parquet are skipped, so check parquet separately. If this check fails, run: - # cargo fmt -p parquet -- --config skip_children=true `find ./parquet -name "*.rs" \! -name format.rs` - # from the top level arrow-rs directory and check in the result. + # Many modules in parquet are skipped, so check parquet separately # https://github.com/apache/arrow-rs/issues/6179 working-directory: parquet - run: cargo fmt -p parquet -- --check --config skip_children=true `find . -name "*.rs" \! -name format.rs` + run: | + # if this fails, run this from the parquet directory: + # cargo fmt -p parquet -- --config skip_children=true `find . -name "*.rs" \! -name format.rs` + cargo fmt -p parquet -- --check --config skip_children=true `find . -name "*.rs" \! -name format.rs` - name: Format object_store working-directory: object_store run: cargo fmt --all -- --check msrv: - name: Verify MSRV + name: Verify MSRV (Minimum Supported Rust Version) runs-on: ubuntu-latest container: image: amd64/rust @@ -126,13 +127,19 @@ jobs: run: cargo update -p ahash --precise 0.8.7 - name: Check arrow working-directory: arrow - run: cargo msrv --log-target stdout verify + run: | + # run `cd arrow; cargo msrv verify` to see problematic dependencies + cargo msrv verify --output-format=json - name: Check parquet working-directory: parquet - run: cargo msrv --log-target stdout verify + run: | + # run `cd parquet; cargo msrv verify` to see problematic dependencies + cargo msrv verify --output-format=json - name: Check arrow-flight working-directory: arrow-flight - run: cargo msrv --log-target stdout verify + run: | + # run `cd arrow-flight; cargo msrv verify` to see problematic dependencies + cargo msrv verify --output-format=json - name: Downgrade object_store dependencies working-directory: object_store # Necessary because tokio 1.30.0 updates MSRV to 1.63 @@ -142,4 +149,6 @@ jobs: cargo update -p url --precise 2.5.0 - name: Check object_store working-directory: object_store - run: cargo msrv 
--log-target stdout verify + run: | + # run `cd object_store; cargo msrv verify` to see problematic dependencies + cargo msrv verify --output-format=json diff --git a/CHANGELOG-old.md b/CHANGELOG-old.md index 5b3a3255ffcd..3fb17b390ac1 100644 --- a/CHANGELOG-old.md +++ b/CHANGELOG-old.md @@ -19,6 +19,239 @@ # Historical Changelog +## [53.3.0](https://github.com/apache/arrow-rs/tree/53.3.0) (2024-11-17) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/53.2.0...53.3.0) + +- Signed decimal e-notation parsing bug [\#6728](https://github.com/apache/arrow-rs/issues/6728) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add support for Utf8View -\> numeric in can\_cast\_types [\#6715](https://github.com/apache/arrow-rs/issues/6715) +- IPC file writer produces incorrect footer when not preserving dict ID [\#6710](https://github.com/apache/arrow-rs/issues/6710) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- parquet from\_thrift\_helper incorrectly checks index [\#6693](https://github.com/apache/arrow-rs/issues/6693) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Primitive REPEATED fields not contained in LIST annotated groups aren't read as lists by record reader [\#6648](https://github.com/apache/arrow-rs/issues/6648) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- DictionaryHandling does not recurse into Map fields [\#6644](https://github.com/apache/arrow-rs/issues/6644) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Array writer output empty when no record is written [\#6613](https://github.com/apache/arrow-rs/issues/6613) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Archery Integration Test with c\# failing on main [\#6577](https://github.com/apache/arrow-rs/issues/6577) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Potential unsoundness in `filter_run_end_array` 
[\#6569](https://github.com/apache/arrow-rs/issues/6569) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Parquet reader can generate incorrect validity buffer information for nested structures [\#6510](https://github.com/apache/arrow-rs/issues/6510) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- arrow-array ffi: FFI\_ArrowArray.null\_count is always interpreted as unsigned and initialized during conversion from C to Rust. [\#6497](https://github.com/apache/arrow-rs/issues/6497) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Documentation updates:** + +- Minor: Document pattern for accessing views in StringView [\#6673](https://github.com/apache/arrow-rs/pull/6673) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Improve Array::is\_nullable documentation [\#6615](https://github.com/apache/arrow-rs/pull/6615) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- Minor: improve docs for ByteViewArray-\>ByteArray From impl [\#6610](https://github.com/apache/arrow-rs/pull/6610) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) + +**Performance improvements:** + +- Speed up `filter_run_end_array` [\#6712](https://github.com/apache/arrow-rs/pull/6712) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) + +**Closed issues:** + +- Incorrect like results for pattern starting/ending with `%` percent and containing escape characters [\#6702](https://github.com/apache/arrow-rs/issues/6702) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Fix signed decimal e-notation parsing [\#6729](https://github.com/apache/arrow-rs/pull/6729) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gruuya](https://github.com/gruuya)) +- Clean up some arrow-flight tests and duplicated code 
[\#6725](https://github.com/apache/arrow-rs/pull/6725) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([itsjunetime](https://github.com/itsjunetime)) +- Update PR template section about API breaking changes [\#6723](https://github.com/apache/arrow-rs/pull/6723) ([findepi](https://github.com/findepi)) +- Support for casting `StringViewArray` to `DecimalArray` [\#6720](https://github.com/apache/arrow-rs/pull/6720) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tlm365](https://github.com/tlm365)) +- File writer preserve dict bug [\#6711](https://github.com/apache/arrow-rs/pull/6711) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- Add filter\_kernel benchmark for run array [\#6706](https://github.com/apache/arrow-rs/pull/6706) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([delamarch3](https://github.com/delamarch3)) +- Fix string view ILIKE checks with NULL values [\#6705](https://github.com/apache/arrow-rs/pull/6705) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- Implement logical\_null\_count for more array types [\#6704](https://github.com/apache/arrow-rs/pull/6704) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- Fix LIKE with escapes [\#6703](https://github.com/apache/arrow-rs/pull/6703) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- Speed up `filter_bytes` [\#6699](https://github.com/apache/arrow-rs/pull/6699) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Minor: fix misleading comment in byte view [\#6695](https://github.com/apache/arrow-rs/pull/6695) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jayzhan211](https://github.com/jayzhan211)) +- minor fix on checking 
index [\#6694](https://github.com/apache/arrow-rs/pull/6694) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jp0317](https://github.com/jp0317)) +- Undo run end filter performance regression [\#6691](https://github.com/apache/arrow-rs/pull/6691) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([delamarch3](https://github.com/delamarch3)) +- Reimplement `PartialEq` of `GenericByteViewArray` compares by logical value [\#6689](https://github.com/apache/arrow-rs/pull/6689) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tlm365](https://github.com/tlm365)) +- feat: expose known\_schema from FlightDataEncoder [\#6688](https://github.com/apache/arrow-rs/pull/6688) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([nathanielc](https://github.com/nathanielc)) +- Update hashbrown requirement from 0.14.2 to 0.15.1 [\#6684](https://github.com/apache/arrow-rs/pull/6684) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Support Duration in JSON Reader [\#6683](https://github.com/apache/arrow-rs/pull/6683) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([simonvandel](https://github.com/simonvandel)) +- Check predicate and values are the same length for run end array filter safety [\#6675](https://github.com/apache/arrow-rs/pull/6675) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([delamarch3](https://github.com/delamarch3)) +- \[ffi\] Fix arrow-array null\_count error during conversion from C to Rust [\#6674](https://github.com/apache/arrow-rs/pull/6674) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([adbmal](https://github.com/adbmal)) +- Support `Utf8View` for `bit_length` kernel [\#6671](https://github.com/apache/arrow-rs/pull/6671) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([austin362667](https://github.com/austin362667)) +- Fix string view LIKE checks with NULL values [\#6662](https://github.com/apache/arrow-rs/pull/6662) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- Improve documentation for `nullif` kernel [\#6658](https://github.com/apache/arrow-rs/pull/6658) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Improve test\_auth error message when contains\(\) fails [\#6657](https://github.com/apache/arrow-rs/pull/6657) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([findepi](https://github.com/findepi)) +- Let std::fmt::Debug for StructArray output Null/Validity info [\#6655](https://github.com/apache/arrow-rs/pull/6655) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XinyuZeng](https://github.com/XinyuZeng)) +- Include offending line number when processing CSV file fails [\#6653](https://github.com/apache/arrow-rs/pull/6653) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- feat: add write\_bytes for GenericBinaryBuilder [\#6652](https://github.com/apache/arrow-rs/pull/6652) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tisonkun](https://github.com/tisonkun)) +- feat: Support Utf8View in JSON serialization [\#6651](https://github.com/apache/arrow-rs/pull/6651) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jonmmease](https://github.com/jonmmease)) +- fix: include chrono-tz in flight sql cli [\#6650](https://github.com/apache/arrow-rs/pull/6650) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- Handle primitive REPEATED field not contained in LIST annotated group 
[\#6649](https://github.com/apache/arrow-rs/pull/6649) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zeevm](https://github.com/zeevm)) +- Implement `append_n` for `BooleanBuilder` [\#6646](https://github.com/apache/arrow-rs/pull/6646) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([delamarch3](https://github.com/delamarch3)) +- fix: recurse into Map datatype when hydrating dictionaries [\#6645](https://github.com/apache/arrow-rs/pull/6645) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([nathanielc](https://github.com/nathanielc)) +- fix: enable TLS roots for flight CLI client [\#6640](https://github.com/apache/arrow-rs/pull/6640) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- doc: Clarify take kernel semantics [\#6632](https://github.com/apache/arrow-rs/pull/6632) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Return error rather than panic when too many row groups are written [\#6629](https://github.com/apache/arrow-rs/pull/6629) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Fix test feature selection so all feature combinations work as expected [\#6626](https://github.com/apache/arrow-rs/pull/6626) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([itsjunetime](https://github.com/itsjunetime)) +- Add Parquet RowSelection benchmark [\#6623](https://github.com/apache/arrow-rs/pull/6623) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Optimize `take_bits` to optimize `take_boolean` / `take_primitive` / `take_byte_view`: up to -25% [\#6622](https://github.com/apache/arrow-rs/pull/6622) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Make downcast macros hygenic \(\#6400\) [\#6620](https://github.com/apache/arrow-rs/pull/6620) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update proc-macro2 requirement from =1.0.88 to =1.0.89 [\#6618](https://github.com/apache/arrow-rs/pull/6618) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Fix arrow-json writer empty [\#6614](https://github.com/apache/arrow-rs/pull/6614) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gwik](https://github.com/gwik)) +- Add `ParquetObjectReader::with_runtime` [\#6612](https://github.com/apache/arrow-rs/pull/6612) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([itsjunetime](https://github.com/itsjunetime)) +- Re-enable `C#` arrow flight integration test [\#6611](https://github.com/apache/arrow-rs/pull/6611) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) + +## [53.3.0](https://github.com/apache/arrow-rs/tree/53.3.0) (2024-11-17) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/53.2.0...53.3.0) + +**Implemented enhancements:** + +- `PartialEq` of GenericByteViewArray \(StringViewArray / ByteViewArray\) that compares on equality rather than logical value [\#6679](https://github.com/apache/arrow-rs/issues/6679) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Need a mechanism to handle schema changes due to dictionary hydration in FlightSQL server implementations [\#6672](https://github.com/apache/arrow-rs/issues/6672) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Support encoding 
Utf8View columns to JSON [\#6642](https://github.com/apache/arrow-rs/issues/6642) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Implement `append_n` for `BooleanBuilder` [\#6634](https://github.com/apache/arrow-rs/issues/6634) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Some take optimizations [\#6621](https://github.com/apache/arrow-rs/issues/6621) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Error Instead of Panic On Attempting to Write More Than 32769 Row Groups [\#6591](https://github.com/apache/arrow-rs/issues/6591) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Make casting from a timestamp without timezone to a timestamp with timezone configurable [\#6555](https://github.com/apache/arrow-rs/issues/6555) +- Add `record_batch!` macro for easy record batch creation [\#6553](https://github.com/apache/arrow-rs/issues/6553) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `Binary` --\> `Utf8View` casting [\#6531](https://github.com/apache/arrow-rs/issues/6531) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `downcast_primitive_array` and `downcast_dictionary_array` are not hygienic wrt imports [\#6400](https://github.com/apache/arrow-rs/issues/6400) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Implement interleave\_record\_batch [\#6731](https://github.com/apache/arrow-rs/pull/6731) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([waynexia](https://github.com/waynexia)) +- feat: `record_batch!` macro [\#6588](https://github.com/apache/arrow-rs/pull/6588) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ByteBaker](https://github.com/ByteBaker)) + +**Fixed bugs:** + +- Signed decimal e-notation parsing bug [\#6728](https://github.com/apache/arrow-rs/issues/6728) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add support for Utf8View -\> numeric in can\_cast\_types 
[\#6715](https://github.com/apache/arrow-rs/issues/6715) +- IPC file writer produces incorrect footer when not preserving dict ID [\#6710](https://github.com/apache/arrow-rs/issues/6710) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- parquet from\_thrift\_helper incorrectly checks index [\#6693](https://github.com/apache/arrow-rs/issues/6693) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Primitive REPEATED fields not contained in LIST annotated groups aren't read as lists by record reader [\#6648](https://github.com/apache/arrow-rs/issues/6648) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- DictionaryHandling does not recurse into Map fields [\#6644](https://github.com/apache/arrow-rs/issues/6644) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Array writer output empty when no record is written [\#6613](https://github.com/apache/arrow-rs/issues/6613) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Archery Integration Test with c\# failing on main [\#6577](https://github.com/apache/arrow-rs/issues/6577) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Potential unsoundness in `filter_run_end_array` [\#6569](https://github.com/apache/arrow-rs/issues/6569) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Parquet reader can generate incorrect validity buffer information for nested structures [\#6510](https://github.com/apache/arrow-rs/issues/6510) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- arrow-array ffi: FFI\_ArrowArray.null\_count is always interpreted as unsigned and initialized during conversion from C to Rust. 
[\#6497](https://github.com/apache/arrow-rs/issues/6497) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Documentation updates:** + +- Minor: Document pattern for accessing views in StringView [\#6673](https://github.com/apache/arrow-rs/pull/6673) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Improve Array::is\_nullable documentation [\#6615](https://github.com/apache/arrow-rs/pull/6615) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- Minor: improve docs for ByteViewArray-\>ByteArray From impl [\#6610](https://github.com/apache/arrow-rs/pull/6610) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) + +**Performance improvements:** + +- Speed up `filter_run_end_array` [\#6712](https://github.com/apache/arrow-rs/pull/6712) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) + +**Closed issues:** + +- Incorrect like results for pattern starting/ending with `%` percent and containing escape characters [\#6702](https://github.com/apache/arrow-rs/issues/6702) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Fix signed decimal e-notation parsing [\#6729](https://github.com/apache/arrow-rs/pull/6729) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gruuya](https://github.com/gruuya)) +- Clean up some arrow-flight tests and duplicated code [\#6725](https://github.com/apache/arrow-rs/pull/6725) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([itsjunetime](https://github.com/itsjunetime)) +- Update PR template section about API breaking changes [\#6723](https://github.com/apache/arrow-rs/pull/6723) ([findepi](https://github.com/findepi)) +- Support for casting `StringViewArray` to `DecimalArray` 
[\#6720](https://github.com/apache/arrow-rs/pull/6720) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tlm365](https://github.com/tlm365)) +- File writer preserve dict bug [\#6711](https://github.com/apache/arrow-rs/pull/6711) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- Add filter\_kernel benchmark for run array [\#6706](https://github.com/apache/arrow-rs/pull/6706) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([delamarch3](https://github.com/delamarch3)) +- Fix string view ILIKE checks with NULL values [\#6705](https://github.com/apache/arrow-rs/pull/6705) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- Implement logical\_null\_count for more array types [\#6704](https://github.com/apache/arrow-rs/pull/6704) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- Fix LIKE with escapes [\#6703](https://github.com/apache/arrow-rs/pull/6703) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- Speed up `filter_bytes` [\#6699](https://github.com/apache/arrow-rs/pull/6699) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Minor: fix misleading comment in byte view [\#6695](https://github.com/apache/arrow-rs/pull/6695) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jayzhan211](https://github.com/jayzhan211)) +- minor fix on checking index [\#6694](https://github.com/apache/arrow-rs/pull/6694) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jp0317](https://github.com/jp0317)) +- Undo run end filter performance regression [\#6691](https://github.com/apache/arrow-rs/pull/6691) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([delamarch3](https://github.com/delamarch3)) +- Reimplement `PartialEq` of `GenericByteViewArray` compares by logical value 
[\#6689](https://github.com/apache/arrow-rs/pull/6689) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tlm365](https://github.com/tlm365)) +- feat: expose known\_schema from FlightDataEncoder [\#6688](https://github.com/apache/arrow-rs/pull/6688) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([nathanielc](https://github.com/nathanielc)) +- Update hashbrown requirement from 0.14.2 to 0.15.1 [\#6684](https://github.com/apache/arrow-rs/pull/6684) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Support Duration in JSON Reader [\#6683](https://github.com/apache/arrow-rs/pull/6683) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([simonvandel](https://github.com/simonvandel)) +- Check predicate and values are the same length for run end array filter safety [\#6675](https://github.com/apache/arrow-rs/pull/6675) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([delamarch3](https://github.com/delamarch3)) +- \[ffi\] Fix arrow-array null\_count error during conversion from C to Rust [\#6674](https://github.com/apache/arrow-rs/pull/6674) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([adbmal](https://github.com/adbmal)) +- Support `Utf8View` for `bit_length` kernel [\#6671](https://github.com/apache/arrow-rs/pull/6671) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([austin362667](https://github.com/austin362667)) +- Fix string view LIKE checks with NULL values [\#6662](https://github.com/apache/arrow-rs/pull/6662) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- Improve documentation for `nullif` kernel [\#6658](https://github.com/apache/arrow-rs/pull/6658) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([alamb](https://github.com/alamb)) +- Improve test\_auth error message when contains\(\) fails [\#6657](https://github.com/apache/arrow-rs/pull/6657) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([findepi](https://github.com/findepi)) +- Let std::fmt::Debug for StructArray output Null/Validity info [\#6655](https://github.com/apache/arrow-rs/pull/6655) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XinyuZeng](https://github.com/XinyuZeng)) +- Include offending line number when processing CSV file fails [\#6653](https://github.com/apache/arrow-rs/pull/6653) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- feat: add write\_bytes for GenericBinaryBuilder [\#6652](https://github.com/apache/arrow-rs/pull/6652) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tisonkun](https://github.com/tisonkun)) +- feat: Support Utf8View in JSON serialization [\#6651](https://github.com/apache/arrow-rs/pull/6651) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jonmmease](https://github.com/jonmmease)) +- fix: include chrono-tz in flight sql cli [\#6650](https://github.com/apache/arrow-rs/pull/6650) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- Handle primitive REPEATED field not contained in LIST annotated group [\#6649](https://github.com/apache/arrow-rs/pull/6649) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zeevm](https://github.com/zeevm)) +- Implement `append_n` for `BooleanBuilder` [\#6646](https://github.com/apache/arrow-rs/pull/6646) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([delamarch3](https://github.com/delamarch3)) +- fix: recurse into Map datatype when hydrating dictionaries 
[\#6645](https://github.com/apache/arrow-rs/pull/6645) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([nathanielc](https://github.com/nathanielc)) +- fix: enable TLS roots for flight CLI client [\#6640](https://github.com/apache/arrow-rs/pull/6640) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- doc: Clarify take kernel semantics [\#6632](https://github.com/apache/arrow-rs/pull/6632) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Return error rather than panic when too many row groups are written [\#6629](https://github.com/apache/arrow-rs/pull/6629) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Fix test feature selection so all feature combinations work as expected [\#6626](https://github.com/apache/arrow-rs/pull/6626) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([itsjunetime](https://github.com/itsjunetime)) +- Add Parquet RowSelection benchmark [\#6623](https://github.com/apache/arrow-rs/pull/6623) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Optimize `take_bits` to optimize `take_boolean` / `take_primitive` / `take_byte_view`: up to -25% [\#6622](https://github.com/apache/arrow-rs/pull/6622) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Make downcast macros hygenic \(\#6400\) [\#6620](https://github.com/apache/arrow-rs/pull/6620) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update proc-macro2 requirement from =1.0.88 to =1.0.89 [\#6618](https://github.com/apache/arrow-rs/pull/6618) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Fix arrow-json writer empty [\#6614](https://github.com/apache/arrow-rs/pull/6614) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gwik](https://github.com/gwik)) +- Add `ParquetObjectReader::with_runtime` [\#6612](https://github.com/apache/arrow-rs/pull/6612) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([itsjunetime](https://github.com/itsjunetime)) +- Re-enable `C#` arrow flight integration test [\#6611](https://github.com/apache/arrow-rs/pull/6611) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add Array::logical\_null\_count for inspecting number of null values [\#6608](https://github.com/apache/arrow-rs/pull/6608) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- Added casting from Binary/LargeBinary to Utf8View [\#6592](https://github.com/apache/arrow-rs/pull/6592) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ngli-me](https://github.com/ngli-me)) +- Parquet AsyncReader: Don't panic when empty offset\_index is Some\(\[\]\) [\#6582](https://github.com/apache/arrow-rs/pull/6582) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jroddev](https://github.com/jroddev)) +- Skip writing down null buffers for non-nullable primitive arrays [\#6524](https://github.com/apache/arrow-rs/pull/6524) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([bkirwi](https://github.com/bkirwi)) +## [53.2.0](https://github.com/apache/arrow-rs/tree/53.2.0) (2024-10-21) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/53.1.0...53.2.0) + +**Implemented enhancements:** + +- Implement 
arrow\_json encoder for Decimal128 & Decimal256 DataTypes [\#6605](https://github.com/apache/arrow-rs/issues/6605) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support DataType::FixedSizeList in make\_builder within struct\_builder.rs [\#6594](https://github.com/apache/arrow-rs/issues/6594) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support DataType::Dictionary in `make_builder` within struct\_builder.rs [\#6589](https://github.com/apache/arrow-rs/issues/6589) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Interval parsing from string - accept "mon" and "mons" token [\#6548](https://github.com/apache/arrow-rs/issues/6548) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `AsyncArrowWriter` API to get the total size of a written parquet file [\#6530](https://github.com/apache/arrow-rs/issues/6530) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `append_many` for Dictionary builders [\#6529](https://github.com/apache/arrow-rs/issues/6529) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Missing tonic `GRPC_STATUS` with tonic 0.12.1 [\#6515](https://github.com/apache/arrow-rs/issues/6515) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Add example of how to use parquet metadata reader APIs for a local cache [\#6504](https://github.com/apache/arrow-rs/issues/6504) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Remove reliance on `raw-entry` feature of Hashbrown [\#6498](https://github.com/apache/arrow-rs/issues/6498) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Improve page index metadata loading in `SerializedFileReader::new_with_options` [\#6491](https://github.com/apache/arrow-rs/issues/6491) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Release arrow-rs / parquet minor version `53.1.0` \(October 2024\) [\#6340](https://github.com/apache/arrow-rs/issues/6340) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Compilation fail where `c_char = u8` [\#6571](https://github.com/apache/arrow-rs/issues/6571) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Arrow flight CI test failing on `master` [\#6568](https://github.com/apache/arrow-rs/issues/6568) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Documentation updates:** + +- Minor: Document SIMD rationale and tips [\#6554](https://github.com/apache/arrow-rs/pull/6554) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) + +**Closed issues:** + +- Casting to and from unions [\#6247](https://github.com/apache/arrow-rs/issues/6247) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Minor: more comments for `RecordBatch.get_array_memory_size()` [\#6607](https://github.com/apache/arrow-rs/pull/6607) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([2010YOUY01](https://github.com/2010YOUY01)) +- Implement arrow\_json encoder for Decimal128 & Decimal256 [\#6606](https://github.com/apache/arrow-rs/pull/6606) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([phillipleblanc](https://github.com/phillipleblanc)) +- Add support for building FixedSizeListBuilder in struct\_builder's mak… [\#6595](https://github.com/apache/arrow-rs/pull/6595) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kszlim](https://github.com/kszlim)) +- Add limited support for dictionary builders in `make_builders` for stru… [\#6593](https://github.com/apache/arrow-rs/pull/6593) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kszlim](https://github.com/kszlim)) +- 
Fix CI with new valid certificates and add script for future usage [\#6585](https://github.com/apache/arrow-rs/pull/6585) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([itsjunetime](https://github.com/itsjunetime)) +- Update proc-macro2 requirement from =1.0.87 to =1.0.88 [\#6579](https://github.com/apache/arrow-rs/pull/6579) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Fix clippy complaints [\#6573](https://github.com/apache/arrow-rs/pull/6573) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([itsjunetime](https://github.com/itsjunetime)) +- Use c\_char instead of i8 to compile on platforms where c\_char = u8 [\#6572](https://github.com/apache/arrow-rs/pull/6572) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([itsjunetime](https://github.com/itsjunetime)) +- Bump pyspark from 3.3.1 to 3.3.2 in /parquet/pytest [\#6564](https://github.com/apache/arrow-rs/pull/6564) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- `unsafe` improvements [\#6551](https://github.com/apache/arrow-rs/pull/6551) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ssbr](https://github.com/ssbr)) +- Update README.md [\#6550](https://github.com/apache/arrow-rs/pull/6550) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([Abdullahsab3](https://github.com/Abdullahsab3)) +- Fix string '0' cast to decimal with scale 0 [\#6547](https://github.com/apache/arrow-rs/pull/6547) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([findepi](https://github.com/findepi)) +- Add finish to `AsyncArrowWriter::finish` [\#6543](https://github.com/apache/arrow-rs/pull/6543) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Add append\_nulls to dictionary builders [\#6542](https://github.com/apache/arrow-rs/pull/6542) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([adriangb](https://github.com/adriangb)) +- Improve UnionArray::is\_nullable [\#6540](https://github.com/apache/arrow-rs/pull/6540) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Allow to read parquet binary column as UTF8 type [\#6539](https://github.com/apache/arrow-rs/pull/6539) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([goldmedal](https://github.com/goldmedal)) +- Use HashTable instead of raw\_entry\_mut [\#6537](https://github.com/apache/arrow-rs/pull/6537) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add append\_many to dictionary arrays to allow adding repeated values [\#6534](https://github.com/apache/arrow-rs/pull/6534) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([adriangb](https://github.com/adriangb)) +- Adds documentation and example recommending Vec\<ArrayRef\> over ChunkedArray [\#6527](https://github.com/apache/arrow-rs/pull/6527) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([efredine](https://github.com/efredine)) +- Update proc-macro2 requirement from =1.0.86 to =1.0.87 [\#6526](https://github.com/apache/arrow-rs/pull/6526) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add `ColumnChunkMetadataBuilder` clear APIs [\#6523](https://github.com/apache/arrow-rs/pull/6523) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Update sysinfo requirement from 0.31.2 to 0.32.0 [\#6521](https://github.com/apache/arrow-rs/pull/6521) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Update Tonic to 0.12.3 [\#6517](https://github.com/apache/arrow-rs/pull/6517) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([cisaacson](https://github.com/cisaacson)) +- Detect missing page indexes while reading Parquet metadata [\#6507](https://github.com/apache/arrow-rs/pull/6507) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Use ParquetMetaDataReader to load page indexes in `SerializedFileReader::new_with_options` [\#6506](https://github.com/apache/arrow-rs/pull/6506) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Improve parquet `MetadataFetch` and `AsyncFileReader` docs [\#6505](https://github.com/apache/arrow-rs/pull/6505) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- fix arrow-json encoding with dictionary including nulls [\#6503](https://github.com/apache/arrow-rs/pull/6503) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([samuelcolvin](https://github.com/samuelcolvin)) +- Update brotli requirement from 6.0 to 7.0 [\#6499](https://github.com/apache/arrow-rs/pull/6499) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Benchmark both scenarios, with records skipped and without skipping, for delta-bin-packed primitive arrays with half nulls. 
[\#6489](https://github.com/apache/arrow-rs/pull/6489) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([wiedld](https://github.com/wiedld)) +- Add round trip tests for reading/writing parquet metadata [\#6463](https://github.com/apache/arrow-rs/pull/6463) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) ## [53.1.0](https://github.com/apache/arrow-rs/tree/53.1.0) (2024-10-02) [Full Changelog](https://github.com/apache/arrow-rs/compare/53.0.0...53.1.0) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8fdf9b6dd95c..a7f2a4ff34d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,69 +19,116 @@ # Changelog -## [53.2.0](https://github.com/apache/arrow-rs/tree/53.2.0) (2024-10-21) +## [54.0.0](https://github.com/apache/arrow-rs/tree/54.0.0) (2024-12-18) -[Full Changelog](https://github.com/apache/arrow-rs/compare/53.1.0...53.2.0) +[Full Changelog](https://github.com/apache/arrow-rs/compare/53.3.0...54.0.0) + +**Breaking changes:** + +- avoid redundant parsing of repeated value in RleDecoder [\#6834](https://github.com/apache/arrow-rs/pull/6834) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jp0317](https://github.com/jp0317)) +- Handling nullable DictionaryArray in CSV parser [\#6830](https://github.com/apache/arrow-rs/pull/6830) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([edmondop](https://github.com/edmondop)) +- fix\(flightsql\): remove Any encoding of DoPutUpdateResult [\#6825](https://github.com/apache/arrow-rs/pull/6825) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([davisp](https://github.com/davisp)) +- arrow-ipc: Default to not preserving dict IDs [\#6788](https://github.com/apache/arrow-rs/pull/6788) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- Remove some very old deprecated functions 
[\#6774](https://github.com/apache/arrow-rs/pull/6774) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- update to pyo3 0.23.0 [\#6745](https://github.com/apache/arrow-rs/pull/6745) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Remove APIs deprecated since v 4.4.0 [\#6722](https://github.com/apache/arrow-rs/pull/6722) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([findepi](https://github.com/findepi)) +- Return `None` when Parquet page indexes are not present in file [\#6639](https://github.com/apache/arrow-rs/pull/6639) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Add `ParquetError::NeedMoreData` mark `ParquetError` as `non_exhaustive` [\#6630](https://github.com/apache/arrow-rs/pull/6630) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Remove APIs deprecated since v 2.0.0 [\#6609](https://github.com/apache/arrow-rs/pull/6609) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) **Implemented enhancements:** -- Implement arrow\_json encoder for Decimal128 & Decimal256 DataTypes [\#6605](https://github.com/apache/arrow-rs/issues/6605) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Support DataType::FixedSizeList in make\_builder within struct\_builder.rs [\#6594](https://github.com/apache/arrow-rs/issues/6594) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Support DataType::Dictionary in `make_builder` within struct\_builder.rs [\#6589](https://github.com/apache/arrow-rs/issues/6589) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Interval parsing from string - accept "mon" and "mons" token 
[\#6548](https://github.com/apache/arrow-rs/issues/6548) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- `AsyncArrowWriter` API to get the total size of a written parquet file [\#6530](https://github.com/apache/arrow-rs/issues/6530) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- `append_many` for Dictionary builders [\#6529](https://github.com/apache/arrow-rs/issues/6529) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Missing tonic `GRPC_STATUS` with tonic 0.12.1 [\#6515](https://github.com/apache/arrow-rs/issues/6515) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] -- Add example of how to use parquet metadata reader APIs for a local cache [\#6504](https://github.com/apache/arrow-rs/issues/6504) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Remove reliance on `raw-entry` feature of Hashbrown [\#6498](https://github.com/apache/arrow-rs/issues/6498) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] -- Improve page index metadata loading in `SerializedFileReader::new_with_options` [\#6491](https://github.com/apache/arrow-rs/issues/6491) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Release arrow-rs / parquet minor version `53.1.0` \(October 2024\) [\#6340](https://github.com/apache/arrow-rs/issues/6340) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Parquet schema hint doesn't support integer types upcasting [\#6891](https://github.com/apache/arrow-rs/issues/6891) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Parquet UTF-8 max statistics are overly pessimistic [\#6867](https://github.com/apache/arrow-rs/issues/6867) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add builder support for Int8 keys 
[\#6844](https://github.com/apache/arrow-rs/issues/6844) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Formalize the name of the nested `Field` in a list [\#6784](https://github.com/apache/arrow-rs/issues/6784) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Allow disabling the writing of Parquet Offset Index [\#6778](https://github.com/apache/arrow-rs/issues/6778) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `parquet::record::make_row` is not exposed to users, leaving no option to users to manually create `Row` objects [\#6761](https://github.com/apache/arrow-rs/issues/6761) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Avoid `from_num_days_from_ce_opt` calls in `timestamp_s_to_datetime` if we don't need [\#6746](https://github.com/apache/arrow-rs/issues/6746) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support Temporal -\> Utf8View casting [\#6734](https://github.com/apache/arrow-rs/issues/6734) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add Option To Coerce List Type on Parquet Write [\#6733](https://github.com/apache/arrow-rs/issues/6733) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support Numeric -\> Utf8View casting [\#6714](https://github.com/apache/arrow-rs/issues/6714) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support Utf8View \<=\> boolean casting [\#6713](https://github.com/apache/arrow-rs/issues/6713) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] **Fixed bugs:** -- Compilation fail where `c_char = u8` [\#6571](https://github.com/apache/arrow-rs/issues/6571) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Arrow flight CI test failing on `master` 
[\#6568](https://github.com/apache/arrow-rs/issues/6568) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- `Buffer::bit_slice` loses length with byte-aligned offsets [\#6895](https://github.com/apache/arrow-rs/issues/6895) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- parquet arrow writer doesn't track memory size correctly for fixed sized lists [\#6839](https://github.com/apache/arrow-rs/issues/6839) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Casting Decimal128 to Decimal128 with smaller precision produces incorrect results in some cases [\#6833](https://github.com/apache/arrow-rs/issues/6833) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Should empty nullable dictionary be parsed as null from arrow-csv? [\#6821](https://github.com/apache/arrow-rs/issues/6821) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Array take doesn't make fields nullable [\#6809](https://github.com/apache/arrow-rs/issues/6809) +- Arrow Flight Encodes a Slice's List Offsets If the slice offset is starts with zero [\#6803](https://github.com/apache/arrow-rs/issues/6803) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Parquet readers incorrectly interpret legacy nested lists [\#6756](https://github.com/apache/arrow-rs/issues/6756) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- filter\_bits under-allocates resulting boolean buffer [\#6750](https://github.com/apache/arrow-rs/issues/6750) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Multi-language support issues with Arrow FlightSQL client's execute\_update and execute\_ingest methods [\#6545](https://github.com/apache/arrow-rs/issues/6545) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] **Documentation updates:** -- Minor: Document SIMD rationale and tips 
[\#6554](https://github.com/apache/arrow-rs/pull/6554) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Should we document at what rate deprecated APIs are removed? [\#6851](https://github.com/apache/arrow-rs/issues/6851) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Fix docstring for `Format::with_header` in `arrow-csv` [\#6856](https://github.com/apache/arrow-rs/pull/6856) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kylebarron](https://github.com/kylebarron)) +- Add deprecation / API removal policy [\#6852](https://github.com/apache/arrow-rs/pull/6852) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Minor: add example for creating `SchemaDescriptor` [\#6841](https://github.com/apache/arrow-rs/pull/6841) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- chore: enrich panic context when BooleanBuffer fails to create [\#6810](https://github.com/apache/arrow-rs/pull/6810) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tisonkun](https://github.com/tisonkun)) **Closed issues:** -- Casting to and from unions [\#6247](https://github.com/apache/arrow-rs/issues/6247) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[FlightSQL\] GetCatalogsBuilder does not sort the catalog names [\#6807](https://github.com/apache/arrow-rs/issues/6807) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Add a lint to automatically check for unused dependencies [\#6796](https://github.com/apache/arrow-rs/issues/6796) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] **Merged pull requests:** -- 
Minor: more comments for `RecordBatch.get_array_memory_size()` [\#6607](https://github.com/apache/arrow-rs/pull/6607) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([2010YOUY01](https://github.com/2010YOUY01)) -- Implement arrow\_json encoder for Decimal128 & Decimal256 [\#6606](https://github.com/apache/arrow-rs/pull/6606) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([phillipleblanc](https://github.com/phillipleblanc)) -- Add support for building FixedSizeListBuilder in struct\_builder's mak… [\#6595](https://github.com/apache/arrow-rs/pull/6595) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kszlim](https://github.com/kszlim)) -- Add limited support for dictionary builders in `make_builders` for stru… [\#6593](https://github.com/apache/arrow-rs/pull/6593) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kszlim](https://github.com/kszlim)) -- Fix CI with new valid certificates and add script for future usage [\#6585](https://github.com/apache/arrow-rs/pull/6585) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([itsjunetime](https://github.com/itsjunetime)) -- Update proc-macro2 requirement from =1.0.87 to =1.0.88 [\#6579](https://github.com/apache/arrow-rs/pull/6579) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) -- Fix clippy complaints [\#6573](https://github.com/apache/arrow-rs/pull/6573) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([itsjunetime](https://github.com/itsjunetime)) -- Use c\_char instead of i8 to compile on platforms where c\_char = u8 [\#6572](https://github.com/apache/arrow-rs/pull/6572) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([itsjunetime](https://github.com/itsjunetime)) -- Bump pyspark from 3.3.1 to 3.3.2 in /parquet/pytest [\#6564](https://github.com/apache/arrow-rs/pull/6564) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) -- `unsafe` improvements [\#6551](https://github.com/apache/arrow-rs/pull/6551) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ssbr](https://github.com/ssbr)) -- Update README.md [\#6550](https://github.com/apache/arrow-rs/pull/6550) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([Abdullahsab3](https://github.com/Abdullahsab3)) -- Fix string '0' cast to decimal with scale 0 [\#6547](https://github.com/apache/arrow-rs/pull/6547) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) -- Add finish to `AsyncArrowWriter::finish` [\#6543](https://github.com/apache/arrow-rs/pull/6543) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) -- Add append\_nulls to dictionary builders [\#6542](https://github.com/apache/arrow-rs/pull/6542) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([adriangb](https://github.com/adriangb)) -- Improve UnionArray::is\_nullable [\#6540](https://github.com/apache/arrow-rs/pull/6540) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Allow to read parquet binary column as UTF8 type [\#6539](https://github.com/apache/arrow-rs/pull/6539) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([goldmedal](https://github.com/goldmedal)) -- Use HashTable instead of raw\_entry\_mut [\#6537](https://github.com/apache/arrow-rs/pull/6537) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([tustvold](https://github.com/tustvold)) -- Add append\_many to dictionary arrays to allow adding repeated values [\#6534](https://github.com/apache/arrow-rs/pull/6534) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([adriangb](https://github.com/adriangb)) -- Adds documentation and example recommending Vec\ over ChunkedArray [\#6527](https://github.com/apache/arrow-rs/pull/6527) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([efredine](https://github.com/efredine)) -- Update proc-macro2 requirement from =1.0.86 to =1.0.87 [\#6526](https://github.com/apache/arrow-rs/pull/6526) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) -- Add `ColumnChunkMetadataBuilder` clear APIs [\#6523](https://github.com/apache/arrow-rs/pull/6523) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- Update sysinfo requirement from 0.31.2 to 0.32.0 [\#6521](https://github.com/apache/arrow-rs/pull/6521) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) -- Update Tonic to 0.12.3 [\#6517](https://github.com/apache/arrow-rs/pull/6517) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([cisaacson](https://github.com/cisaacson)) -- Detect missing page indexes while reading Parquet metadata [\#6507](https://github.com/apache/arrow-rs/pull/6507) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) -- Use ParquetMetaDataReader to load page indexes in `SerializedFileReader::new_with_options` [\#6506](https://github.com/apache/arrow-rs/pull/6506) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) -- Improve parquet `MetadataFetch` and 
`AsyncFileReader` docs [\#6505](https://github.com/apache/arrow-rs/pull/6505) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- fix arrow-json encoding with dictionary including nulls [\#6503](https://github.com/apache/arrow-rs/pull/6503) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([samuelcolvin](https://github.com/samuelcolvin)) -- Update brotli requirement from 6.0 to 7.0 [\#6499](https://github.com/apache/arrow-rs/pull/6499) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) -- Benchmark both scenarios, with records skipped and without skipping, for delta-bin-packed primitive arrays with half nulls. [\#6489](https://github.com/apache/arrow-rs/pull/6489) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([wiedld](https://github.com/wiedld)) -- Add round trip tests for reading/writing parquet metadata [\#6463](https://github.com/apache/arrow-rs/pull/6463) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- doc: add comment for timezone string [\#6899](https://github.com/apache/arrow-rs/pull/6899) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xxchan](https://github.com/xxchan)) +- docs: fix typo [\#6890](https://github.com/apache/arrow-rs/pull/6890) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- Minor: Fix deprecation notice for `arrow_to_parquet_schema` [\#6889](https://github.com/apache/arrow-rs/pull/6889) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Add Field::with\_dict\_is\_ordered [\#6885](https://github.com/apache/arrow-rs/pull/6885) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Deprecate "max statistics size" property in `WriterProperties` 
[\#6884](https://github.com/apache/arrow-rs/pull/6884) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Add deprecation warnings for everything related to `dict_id` [\#6873](https://github.com/apache/arrow-rs/pull/6873) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([brancz](https://github.com/brancz)) +- Enable matching temporal as from\_type to Utf8View [\#6872](https://github.com/apache/arrow-rs/pull/6872) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Kev1n8](https://github.com/Kev1n8)) +- Enable string-based column projections from Parquet files [\#6871](https://github.com/apache/arrow-rs/pull/6871) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Improvements to UTF-8 statistics truncation [\#6870](https://github.com/apache/arrow-rs/pull/6870) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- fix: make GetCatalogsBuilder sort catalog names [\#6864](https://github.com/apache/arrow-rs/pull/6864) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([niebayes](https://github.com/niebayes)) +- add buffered data\_pages to parquet column writer total bytes estimation [\#6862](https://github.com/apache/arrow-rs/pull/6862) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([onursatici](https://github.com/onursatici)) +- Update prost-build requirement from =0.13.3 to =0.13.4 [\#6860](https://github.com/apache/arrow-rs/pull/6860) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Minor: add comments explaining bad 
MSRV, output in json [\#6857](https://github.com/apache/arrow-rs/pull/6857) ([alamb](https://github.com/alamb)) +- perf: Use Cow in get\_format\_string in FFI\_ArrowSchema [\#6853](https://github.com/apache/arrow-rs/pull/6853) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([andygrove](https://github.com/andygrove)) +- chore: add cast\_decimal benchmark [\#6850](https://github.com/apache/arrow-rs/pull/6850) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([andygrove](https://github.com/andygrove)) +- arrow-array::builder: support Int8, Int16 and Int64 keys [\#6845](https://github.com/apache/arrow-rs/pull/6845) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ajwerner](https://github.com/ajwerner)) +- Add `ArrowToParquetSchemaConverter`, deprecate `arrow_to_parquet_schema` [\#6840](https://github.com/apache/arrow-rs/pull/6840) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Remove APIs deprecated in 50.0.0 [\#6838](https://github.com/apache/arrow-rs/pull/6838) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- fix: decimal conversion looses value on lower precision [\#6836](https://github.com/apache/arrow-rs/pull/6836) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([himadripal](https://github.com/himadripal)) +- Update sysinfo requirement from 0.32.0 to 0.33.0 [\#6835](https://github.com/apache/arrow-rs/pull/6835) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Optionally coerce names of maps and lists to match Parquet specification [\#6828](https://github.com/apache/arrow-rs/pull/6828) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Remove deprecated unary\_dyn and try\_unary\_dyn [\#6824](https://github.com/apache/arrow-rs/pull/6824) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- Remove deprecated flight\_data\_from\_arrow\_batch [\#6823](https://github.com/apache/arrow-rs/pull/6823) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([findepi](https://github.com/findepi)) +- \[arrow-cast\] Support cast boolean from/to string view [\#6822](https://github.com/apache/arrow-rs/pull/6822) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tlm365](https://github.com/tlm365)) +- Hook up Avro Decoder [\#6820](https://github.com/apache/arrow-rs/pull/6820) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix arrow-avro compilation without default features [\#6819](https://github.com/apache/arrow-rs/pull/6819) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- Support shrink to empty [\#6817](https://github.com/apache/arrow-rs/pull/6817) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- \[arrow-cast\] Support cast numeric to string view \(alternate\) [\#6816](https://github.com/apache/arrow-rs/pull/6816) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Hide implicit optional dependency features in arrow-flight [\#6806](https://github.com/apache/arrow-rs/pull/6806) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([findepi](https://github.com/findepi)) +- fix: Encoding of List offsets was incorrect when slice offsets begin with zero [\#6805](https://github.com/apache/arrow-rs/pull/6805) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HawaiianSpork](https://github.com/HawaiianSpork)) +- Enable unused\_crate\_dependencies Rust lint, remove unused dependencies 
[\#6804](https://github.com/apache/arrow-rs/pull/6804) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([findepi](https://github.com/findepi)) +- Minor: Fix docstrings for `ColumnProperties::statistics_enabled` property [\#6798](https://github.com/apache/arrow-rs/pull/6798) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Add option to disable writing of Parquet offset index [\#6797](https://github.com/apache/arrow-rs/pull/6797) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Remove unused dependencies [\#6792](https://github.com/apache/arrow-rs/pull/6792) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([findepi](https://github.com/findepi)) +- Add `Array::shrink_to_fit(&mut self)` [\#6790](https://github.com/apache/arrow-rs/pull/6790) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([emilk](https://github.com/emilk)) +- Formalize the default nested list field name to `item` [\#6785](https://github.com/apache/arrow-rs/pull/6785) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([gruuya](https://github.com/gruuya)) +- Improve UnionArray logical\_nulls tests [\#6781](https://github.com/apache/arrow-rs/pull/6781) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gstvg](https://github.com/gstvg)) +- Improve list builder usage example in docs [\#6775](https://github.com/apache/arrow-rs/pull/6775) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- Update proc-macro2 requirement from =1.0.89 to =1.0.92 [\#6772](https://github.com/apache/arrow-rs/pull/6772) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Allow NullBuffer construction directly from array [\#6769](https://github.com/apache/arrow-rs/pull/6769) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- Include license and notice files in published crates [\#6767](https://github.com/apache/arrow-rs/pull/6767) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([ankane](https://github.com/ankane)) +- fix: remove redundant `bit_util::ceil` [\#6766](https://github.com/apache/arrow-rs/pull/6766) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([miroim](https://github.com/miroim)) +- Remove 'make\_row', expose a 'Row::new' method instead. 
[\#6763](https://github.com/apache/arrow-rs/pull/6763) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jonded94](https://github.com/jonded94)) +- Read nested Parquet 2-level lists correctly [\#6757](https://github.com/apache/arrow-rs/pull/6757) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Split `timestamp_s_to_datetime` to `date` and `time` to avoid unnecessary computation [\#6755](https://github.com/apache/arrow-rs/pull/6755) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jayzhan211](https://github.com/jayzhan211)) +- More trivial implementation of `Box` and `Box` [\#6748](https://github.com/apache/arrow-rs/pull/6748) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([ethe](https://github.com/ethe)) +- Update cache action to v4 [\#6744](https://github.com/apache/arrow-rs/pull/6744) ([findepi](https://github.com/findepi)) +- Remove redundant implementation of `StringArrayType` [\#6743](https://github.com/apache/arrow-rs/pull/6743) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tlm365](https://github.com/tlm365)) +- Fix Dictionary logical nulls for RunArray/UnionArray Values [\#6740](https://github.com/apache/arrow-rs/pull/6740) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- Allow reading Parquet maps that lack a `values` field [\#6730](https://github.com/apache/arrow-rs/pull/6730) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Improve default implementation of Array::is\_nullable [\#6721](https://github.com/apache/arrow-rs/pull/6721) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- Fix Buffer::bit\_slice losing length with byte-aligned offsets [\#6707](https://github.com/apache/arrow-rs/pull/6707) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
[[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([itsjunetime](https://github.com/itsjunetime)) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2dea0b2cca64..38236ee39125 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -138,7 +138,7 @@ cargo test cargo test -p arrow ``` -For some changes, you may want to run additional tests. You can find up-to-date information on the current CI tests in [.github/workflows](https://github.com/apache/arrow-rs/tree/master/.github/workflows). Here are some examples of additional tests you may want to run: +For some changes, you may want to run additional tests. You can find up-to-date information on the current CI tests in [.github/workflows](https://github.com/apache/arrow-rs/tree/main/.github/workflows). Here are some examples of additional tests you may want to run: ```bash # run tests for the parquet crate @@ -217,13 +217,13 @@ cargo bench -p arrow-cast --bench parse_time To set the baseline for your benchmarks, use the --save-baseline flag: ```bash -git checkout master +git checkout main -cargo bench --bench parse_time -- --save-baseline master +cargo bench --bench parse_time -- --save-baseline main git checkout feature -cargo bench --bench parse_time -- --baseline master +cargo bench --bench parse_time -- --baseline main ``` ## Git Pre-Commit Hook diff --git a/Cargo.toml b/Cargo.toml index f210ae210012..75ba410f12a6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,7 +62,7 @@ exclude = [ ] [workspace.package] -version = "53.2.0" +version = "54.0.0" homepage = "https://github.com/apache/arrow-rs" repository = "https://github.com/apache/arrow-rs" authors = ["Apache Arrow "] @@ -77,20 +77,20 @@ edition = "2021" rust-version = "1.62" [workspace.dependencies] -arrow = { version = "53.2.0", path = "./arrow", default-features = false } -arrow-arith = { version = "53.2.0", path = "./arrow-arith" } -arrow-array = { version = "53.2.0", path = "./arrow-array" } -arrow-buffer = { version = "53.2.0", path = 
"./arrow-buffer" } -arrow-cast = { version = "53.2.0", path = "./arrow-cast" } -arrow-csv = { version = "53.2.0", path = "./arrow-csv" } -arrow-data = { version = "53.2.0", path = "./arrow-data" } -arrow-ipc = { version = "53.2.0", path = "./arrow-ipc" } -arrow-json = { version = "53.2.0", path = "./arrow-json" } -arrow-ord = { version = "53.2.0", path = "./arrow-ord" } -arrow-row = { version = "53.2.0", path = "./arrow-row" } -arrow-schema = { version = "53.2.0", path = "./arrow-schema" } -arrow-select = { version = "53.2.0", path = "./arrow-select" } -arrow-string = { version = "53.2.0", path = "./arrow-string" } -parquet = { version = "53.2.0", path = "./parquet", default-features = false } +arrow = { version = "54.0.0", path = "./arrow", default-features = false } +arrow-arith = { version = "54.0.0", path = "./arrow-arith" } +arrow-array = { version = "54.0.0", path = "./arrow-array" } +arrow-buffer = { version = "54.0.0", path = "./arrow-buffer" } +arrow-cast = { version = "54.0.0", path = "./arrow-cast" } +arrow-csv = { version = "54.0.0", path = "./arrow-csv" } +arrow-data = { version = "54.0.0", path = "./arrow-data" } +arrow-ipc = { version = "54.0.0", path = "./arrow-ipc" } +arrow-json = { version = "54.0.0", path = "./arrow-json" } +arrow-ord = { version = "54.0.0", path = "./arrow-ord" } +arrow-row = { version = "54.0.0", path = "./arrow-row" } +arrow-schema = { version = "54.0.0", path = "./arrow-schema" } +arrow-select = { version = "54.0.0", path = "./arrow-select" } +arrow-string = { version = "54.0.0", path = "./arrow-string" } +parquet = { version = "54.0.0", path = "./parquet", default-features = false } chrono = { version = "0.4.34", default-features = false, features = ["clock"] } diff --git a/README.md b/README.md index 98c0a6615d9d..ed42f630514b 100644 --- a/README.md +++ b/README.md @@ -19,8 +19,6 @@ # Native Rust implementation of Apache Arrow and Apache Parquet -[![Coverage 
Status](https://codecov.io/gh/apache/arrow-rs/rust/branch/master/graph/badge.svg)](https://codecov.io/gh/apache/arrow-rs?branch=master) - Welcome to the [Rust][rust] implementation of [Apache Arrow], the popular in-memory columnar format. This repo contains the following main components: @@ -58,20 +56,21 @@ breaking API changes) at most once a quarter, and release incremental minor versions in the intervening months. See [this ticket] for more details. To keep our maintenance burden down, we do regularly scheduled releases (major -and minor) from the `master` branch. How we handle PRs with breaking API changes +and minor) from the `main` branch. How we handle PRs with breaking API changes is described in the [contributing] guide. [contributing]: CONTRIBUTING.md#breaking-changes Planned Release Schedule -| Approximate Date | Version | Notes | -| ---------------- | -------- | --------------------------------------- | -| Sep 2024 | `53.0.0` | Major, potentially breaking API changes | -| Oct 2024 | `53.1.0` | Minor, NO breaking API changes | -| Oct 2024 | `53.2.0` | Minor, NO breaking API changes | -| Nov 2024 | `53.3.0` | Minor, NO breaking API changes | -| Dec 2024 | `54.0.0` | Major, potentially breaking API changes | +| Approximate Date | Version | Notes | +| ---------------- | -------- | ------------------------------------------ | +| Nov 2024 | `53.3.0` | Minor, NO breaking API changes | +| Dec 2024 | `54.0.0` | Major, potentially breaking API changes | +| Jan 2025 | `53.4.0` | Minor, NO breaking API changes (`53` line) | +| Jan 2025 | `54.1.0` | Minor, NO breaking API changes | +| Feb 2025 | `54.2.0` | Minor, NO breaking API changes | +| Mar 2025 | `55.0.0` | Major, potentially breaking API changes | [this ticket]: https://github.com/apache/arrow-rs/issues/5368 [semantic versioning]: https://semver.org/ @@ -84,6 +83,40 @@ versions approximately every 2 months. 
[`object_store`]: https://crates.io/crates/object_store +Planned Release Schedule + +| Approximate Date | Version | Notes | +| ---------------- | -------- | --------------------------------------- | +| Dec 2024 | `0.11.2` | Minor, NO breaking API changes | +| Feb 2025 | `0.12.0` | Major, potentially breaking API changes | + +### Deprecation Guidelines + +Minor releases may deprecate, but not remove APIs. Deprecating APIs allows +downstream Rust programs to still compile, but generate compiler warnings. This +gives downstream crates time to migrate prior to API removal. + +To deprecate an API: + +- Mark the API as deprecated using `#[deprecated]` and specify the exact arrow-rs version in which it was deprecated +- Concisely describe the preferred API to help the user transition + +The deprecated version is the next version which will be released (please +consult the list above). To mark the API as deprecated, use the +`#[deprecated(since = "...", note = "...")]` attribute. + +For example + +```rust +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +``` + +In general, deprecated APIs will remain in the codebase for at least two major releases after +they were deprecated (typically between 6 - 9 months later). For example, an API +deprecated in `51.3.0` can be removed in `54.0.0` (or later). Deprecated APIs +may be removed earlier or later than these guidelines at the discretion of the +maintainers. 
+ ## Related Projects There are several related crates in different repositories diff --git a/arrow-arith/Cargo.toml b/arrow-arith/Cargo.toml index d2ee0b9e2c72..66696df8aa04 100644 --- a/arrow-arith/Cargo.toml +++ b/arrow-arith/Cargo.toml @@ -39,7 +39,6 @@ arrow-buffer = { workspace = true } arrow-data = { workspace = true } arrow-schema = { workspace = true } chrono = { workspace = true } -half = { version = "2.1", default-features = false } num = { version = "0.4", default-features = false, features = ["std"] } [dev-dependencies] diff --git a/arrow-arith/LICENSE.txt b/arrow-arith/LICENSE.txt new file mode 120000 index 000000000000..4ab43736a839 --- /dev/null +++ b/arrow-arith/LICENSE.txt @@ -0,0 +1 @@ +../LICENSE.txt \ No newline at end of file diff --git a/arrow-arith/NOTICE.txt b/arrow-arith/NOTICE.txt new file mode 120000 index 000000000000..eb9f24e040b5 --- /dev/null +++ b/arrow-arith/NOTICE.txt @@ -0,0 +1 @@ +../NOTICE.txt \ No newline at end of file diff --git a/arrow-arith/src/arity.rs b/arrow-arith/src/arity.rs index bb983e1225ac..9b3272abb617 100644 --- a/arrow-arith/src/arity.rs +++ b/arrow-arith/src/arity.rs @@ -18,14 +18,12 @@ //! Kernels for operating on [`PrimitiveArray`]s use arrow_array::builder::BufferBuilder; -use arrow_array::types::ArrowDictionaryKeyType; use arrow_array::*; use arrow_buffer::buffer::NullBuffer; use arrow_buffer::ArrowNativeType; use arrow_buffer::{Buffer, MutableBuffer}; use arrow_data::ArrayData; use arrow_schema::ArrowError; -use std::sync::Arc; /// See [`PrimitiveArray::unary`] pub fn unary(array: &PrimitiveArray, op: F) -> PrimitiveArray @@ -71,97 +69,6 @@ where array.try_unary_mut(op) } -/// A helper function that applies an infallible unary function to a dictionary array with primitive value type. 
-fn unary_dict(array: &DictionaryArray, op: F) -> Result -where - K: ArrowDictionaryKeyType + ArrowNumericType, - T: ArrowPrimitiveType, - F: Fn(T::Native) -> T::Native, -{ - let dict_values = array.values().as_any().downcast_ref().unwrap(); - let values = unary::(dict_values, op); - Ok(Arc::new(array.with_values(Arc::new(values)))) -} - -/// A helper function that applies a fallible unary function to a dictionary array with primitive value type. -fn try_unary_dict(array: &DictionaryArray, op: F) -> Result -where - K: ArrowDictionaryKeyType + ArrowNumericType, - T: ArrowPrimitiveType, - F: Fn(T::Native) -> Result, -{ - if !PrimitiveArray::::is_compatible(&array.value_type()) { - return Err(ArrowError::CastError(format!( - "Cannot perform the unary operation of type {} on dictionary array of value type {}", - T::DATA_TYPE, - array.value_type() - ))); - } - - let dict_values = array.values().as_any().downcast_ref().unwrap(); - let values = try_unary::(dict_values, op)?; - Ok(Arc::new(array.with_values(Arc::new(values)))) -} - -/// Applies an infallible unary function to an array with primitive values. -#[deprecated(note = "Use arrow_array::AnyDictionaryArray")] -pub fn unary_dyn(array: &dyn Array, op: F) -> Result -where - T: ArrowPrimitiveType, - F: Fn(T::Native) -> T::Native, -{ - downcast_dictionary_array! { - array => unary_dict::<_, F, T>(array, op), - t => { - if PrimitiveArray::::is_compatible(t) { - Ok(Arc::new(unary::( - array.as_any().downcast_ref::>().unwrap(), - op, - ))) - } else { - Err(ArrowError::NotYetImplemented(format!( - "Cannot perform unary operation of type {} on array of type {}", - T::DATA_TYPE, - t - ))) - } - } - } -} - -/// Applies a fallible unary function to an array with primitive values. -#[deprecated(note = "Use arrow_array::AnyDictionaryArray")] -pub fn try_unary_dyn(array: &dyn Array, op: F) -> Result -where - T: ArrowPrimitiveType, - F: Fn(T::Native) -> Result, -{ - downcast_dictionary_array! 
{ - array => if array.values().data_type() == &T::DATA_TYPE { - try_unary_dict::<_, F, T>(array, op) - } else { - Err(ArrowError::NotYetImplemented(format!( - "Cannot perform unary operation on dictionary array of type {}", - array.data_type() - ))) - }, - t => { - if PrimitiveArray::::is_compatible(t) { - Ok(Arc::new(try_unary::( - array.as_any().downcast_ref::>().unwrap(), - op, - )?)) - } else { - Err(ArrowError::NotYetImplemented(format!( - "Cannot perform unary operation of type {} on array of type {}", - T::DATA_TYPE, - t - ))) - } - } - } -} - /// Allies a binary infallable function to two [`PrimitiveArray`]s, /// producing a new [`PrimitiveArray`] /// @@ -510,8 +417,8 @@ where #[cfg(test)] mod tests { use super::*; - use arrow_array::builder::*; use arrow_array::types::*; + use std::sync::Arc; #[test] #[allow(deprecated)] @@ -523,53 +430,6 @@ mod tests { result, Float64Array::from(vec![None, Some(7.0), None, Some(7.0)]) ); - - let result = unary_dyn::<_, Float64Type>(&input_slice, |n| n + 1.0).unwrap(); - - assert_eq!( - result.as_any().downcast_ref::().unwrap(), - &Float64Array::from(vec![None, Some(7.8), None, Some(8.2)]) - ); - } - - #[test] - #[allow(deprecated)] - fn test_unary_dict_and_unary_dyn() { - let mut builder = PrimitiveDictionaryBuilder::::new(); - builder.append(5).unwrap(); - builder.append(6).unwrap(); - builder.append(7).unwrap(); - builder.append(8).unwrap(); - builder.append_null(); - builder.append(9).unwrap(); - let dictionary_array = builder.finish(); - - let mut builder = PrimitiveDictionaryBuilder::::new(); - builder.append(6).unwrap(); - builder.append(7).unwrap(); - builder.append(8).unwrap(); - builder.append(9).unwrap(); - builder.append_null(); - builder.append(10).unwrap(); - let expected = builder.finish(); - - let result = unary_dict::<_, _, Int32Type>(&dictionary_array, |n| n + 1).unwrap(); - assert_eq!( - result - .as_any() - .downcast_ref::>() - .unwrap(), - &expected - ); - - let result = unary_dyn::<_, 
Int32Type>(&dictionary_array, |n| n + 1).unwrap(); - assert_eq!( - result - .as_any() - .downcast_ref::>() - .unwrap(), - &expected - ); } #[test] diff --git a/arrow-arith/src/temporal.rs b/arrow-arith/src/temporal.rs index 09d690d3237c..3458669a6fd1 100644 --- a/arrow-arith/src/temporal.rs +++ b/arrow-arith/src/temporal.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use arrow_array::cast::AsArray; use cast::as_primitive_array; -use chrono::{Datelike, NaiveDateTime, Offset, TimeZone, Timelike, Utc}; +use chrono::{Datelike, TimeZone, Timelike, Utc}; use arrow_array::temporal_conversions::{ date32_to_datetime, date64_to_datetime, timestamp_ms_to_datetime, timestamp_ns_to_datetime, @@ -82,6 +82,7 @@ impl std::fmt::Display for DatePart { /// Returns function to extract relevant [`DatePart`] from types like a /// [`NaiveDateTime`] or [`DateTime`]. /// +/// [`NaiveDateTime`]: chrono::NaiveDateTime /// [`DateTime`]: chrono::DateTime fn get_date_time_part_extract_fn(part: DatePart) -> fn(T) -> i32 where @@ -664,20 +665,6 @@ impl ChronoDateExt for T { } } -/// Parse the given string into a string representing fixed-offset that is correct as of the given -/// UTC NaiveDateTime. -/// -/// Note that the offset is function of time and can vary depending on whether daylight savings is -/// in effect or not. e.g. Australia/Sydney is +10:00 or +11:00 depending on DST. -#[deprecated(note = "Use arrow_array::timezone::Tz instead")] -pub fn using_chrono_tz_and_utc_naive_date_time( - tz: &str, - utc: NaiveDateTime, -) -> Option { - let tz: Tz = tz.parse().ok()?; - Some(tz.offset_from_utc_datetime(&utc).fix()) -} - /// Extracts the hours of a given array as an array of integers within /// the range of [0, 23]. If the given array isn't temporal primitive or dictionary array, /// an `Err` will be returned. 
diff --git a/arrow-array/LICENSE.txt b/arrow-array/LICENSE.txt new file mode 120000 index 000000000000..4ab43736a839 --- /dev/null +++ b/arrow-array/LICENSE.txt @@ -0,0 +1 @@ +../LICENSE.txt \ No newline at end of file diff --git a/arrow-array/NOTICE.txt b/arrow-array/NOTICE.txt new file mode 120000 index 000000000000..eb9f24e040b5 --- /dev/null +++ b/arrow-array/NOTICE.txt @@ -0,0 +1 @@ +../NOTICE.txt \ No newline at end of file diff --git a/arrow-array/benches/fixed_size_list_array.rs b/arrow-array/benches/fixed_size_list_array.rs index 5f001a4f3d3a..5270a4a5def3 100644 --- a/arrow-array/benches/fixed_size_list_array.rs +++ b/arrow-array/benches/fixed_size_list_array.rs @@ -26,7 +26,7 @@ fn gen_fsl(len: usize, value_len: usize) -> FixedSizeListArray { let values = Arc::new(Int32Array::from( (0..len).map(|_| rng.gen::()).collect::>(), )); - let field = Arc::new(Field::new("item", values.data_type().clone(), true)); + let field = Arc::new(Field::new_list_field(values.data_type().clone(), true)); FixedSizeListArray::new(field, value_len as i32, values, None) } diff --git a/arrow-array/src/array/binary_array.rs b/arrow-array/src/array/binary_array.rs index 8f8a39b2093f..0e8a7a7cb618 100644 --- a/arrow-array/src/array/binary_array.rs +++ b/arrow-array/src/array/binary_array.rs @@ -24,12 +24,6 @@ use arrow_schema::DataType; pub type GenericBinaryArray = GenericByteArray>; impl GenericBinaryArray { - /// Get the data type of the array. 
- #[deprecated(note = "please use `Self::DATA_TYPE` instead")] - pub const fn get_data_type() -> DataType { - Self::DATA_TYPE - } - /// Creates a [GenericBinaryArray] from a vector of byte slices /// /// See also [`Self::from_iter_values`] @@ -358,7 +352,7 @@ mod tests { let values = b"helloparquet"; let child_data = ArrayData::builder(DataType::UInt8) .len(12) - .add_buffer(Buffer::from(&values[..])) + .add_buffer(Buffer::from(values)) .build() .unwrap(); let offsets = [0, 5, 5, 12].map(|n| O::from_usize(n).unwrap()); @@ -372,11 +366,9 @@ mod tests { .unwrap(); let binary_array1 = GenericBinaryArray::::from(array_data1); - let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( - "item", - DataType::UInt8, - false, - ))); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new( + Field::new_list_field(DataType::UInt8, false), + )); let array_data2 = ArrayData::builder(data_type) .len(3) @@ -415,17 +407,15 @@ mod tests { let child_data = ArrayData::builder(DataType::UInt8) .len(15) .offset(5) - .add_buffer(Buffer::from(&values[..])) + .add_buffer(Buffer::from(values)) .build() .unwrap(); let offsets = [0, 5, 8, 15].map(|n| O::from_usize(n).unwrap()); let null_buffer = Buffer::from_slice_ref([0b101]); - let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( - "item", - DataType::UInt8, - false, - ))); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new( + Field::new_list_field(DataType::UInt8, false), + )); // [None, Some(b"Parquet")] let array_data = ArrayData::builder(data_type) @@ -460,17 +450,15 @@ mod tests { let values = b"HelloArrow"; let child_data = ArrayData::builder(DataType::UInt8) .len(10) - .add_buffer(Buffer::from(&values[..])) + .add_buffer(Buffer::from(values)) .null_bit_buffer(Some(Buffer::from_slice_ref([0b1010101010]))) .build() .unwrap(); let offsets = [0, 5, 10].map(|n| O::from_usize(n).unwrap()); - let data_type = 
GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( - "item", - DataType::UInt8, - true, - ))); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new( + Field::new_list_field(DataType::UInt8, true), + )); // [None, Some(b"Parquet")] let array_data = ArrayData::builder(data_type) @@ -558,7 +546,7 @@ mod tests { .unwrap(); let offsets: [i32; 4] = [0, 5, 5, 12]; - let data_type = DataType::List(Arc::new(Field::new("item", DataType::UInt32, false))); + let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::UInt32, false))); let array_data = ArrayData::builder(data_type) .len(3) .add_buffer(Buffer::from_slice_ref(offsets)) diff --git a/arrow-array/src/array/boolean_array.rs b/arrow-array/src/array/boolean_array.rs index 0f95adacf10c..9c2d4af8c454 100644 --- a/arrow-array/src/array/boolean_array.rs +++ b/arrow-array/src/array/boolean_array.rs @@ -308,6 +308,13 @@ impl Array for BooleanArray { self.values.is_empty() } + fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(nulls) = &mut self.nulls { + nulls.shrink_to_fit(); + } + } + fn offset(&self) -> usize { self.values.offset() } diff --git a/arrow-array/src/array/byte_array.rs b/arrow-array/src/array/byte_array.rs index bec0caab1045..f2b22507081d 100644 --- a/arrow-array/src/array/byte_array.rs +++ b/arrow-array/src/array/byte_array.rs @@ -453,6 +453,14 @@ impl Array for GenericByteArray { self.value_offsets.len() <= 1 } + fn shrink_to_fit(&mut self) { + self.value_offsets.shrink_to_fit(); + self.value_data.shrink_to_fit(); + if let Some(nulls) = &mut self.nulls { + nulls.shrink_to_fit(); + } + } + fn offset(&self) -> usize { 0 } diff --git a/arrow-array/src/array/byte_view_array.rs b/arrow-array/src/array/byte_view_array.rs index 81bb6a38550b..9d2d396a5266 100644 --- a/arrow-array/src/array/byte_view_array.rs +++ b/arrow-array/src/array/byte_view_array.rs @@ -430,31 +430,31 @@ impl GenericByteViewArray { /// /// Before GC: /// ```text - /// ┌──────┐ - 
/// │......│ - /// │......│ - /// ┌────────────────────┐ ┌ ─ ─ ─ ▶ │Data1 │ Large buffer + /// ┌──────┐ + /// │......│ + /// │......│ + /// ┌────────────────────┐ ┌ ─ ─ ─ ▶ │Data1 │ Large buffer /// │ View 1 │─ ─ ─ ─ │......│ with data that /// ├────────────────────┤ │......│ is not referred /// │ View 2 │─ ─ ─ ─ ─ ─ ─ ─▶ │Data2 │ to by View 1 or - /// └────────────────────┘ │......│ View 2 - /// │......│ - /// 2 views, refer to │......│ - /// small portions of a └──────┘ - /// large buffer + /// └────────────────────┘ │......│ View 2 + /// │......│ + /// 2 views, refer to │......│ + /// small portions of a └──────┘ + /// large buffer /// ``` - /// + /// /// After GC: /// /// ```text /// ┌────────────────────┐ ┌─────┐ After gc, only - /// │ View 1 │─ ─ ─ ─ ─ ─ ─ ─▶ │Data1│ data that is - /// ├────────────────────┤ ┌ ─ ─ ─ ▶ │Data2│ pointed to by - /// │ View 2 │─ ─ ─ ─ └─────┘ the views is - /// └────────────────────┘ left - /// - /// - /// 2 views + /// │ View 1 │─ ─ ─ ─ ─ ─ ─ ─▶ │Data1│ data that is + /// ├────────────────────┤ ┌ ─ ─ ─ ▶ │Data2│ pointed to by + /// │ View 2 │─ ─ ─ ─ └─────┘ the views is + /// └────────────────────┘ left + /// + /// + /// 2 views /// ``` /// This method will compact the data buffers by recreating the view array and only include the data /// that is pointed to by the views. @@ -575,6 +575,15 @@ impl Array for GenericByteViewArray { self.views.is_empty() } + fn shrink_to_fit(&mut self) { + self.views.shrink_to_fit(); + self.buffers.iter_mut().for_each(|b| b.shrink_to_fit()); + self.buffers.shrink_to_fit(); + if let Some(nulls) = &mut self.nulls { + nulls.shrink_to_fit(); + } + } + fn offset(&self) -> usize { 0 } diff --git a/arrow-array/src/array/dictionary_array.rs b/arrow-array/src/array/dictionary_array.rs index 1187e16769a0..f852b57fb65e 100644 --- a/arrow-array/src/array/dictionary_array.rs +++ b/arrow-array/src/array/dictionary_array.rs @@ -249,7 +249,7 @@ pub struct DictionaryArray { /// map to the real values. 
keys: PrimitiveArray, - /// Array of dictionary values (can by any DataType). + /// Array of dictionary values (can be any DataType). values: ArrayRef, /// Values are ordered. @@ -720,6 +720,11 @@ impl Array for DictionaryArray { self.keys.is_empty() } + fn shrink_to_fit(&mut self) { + self.keys.shrink_to_fit(); + self.values.shrink_to_fit(); + } + fn offset(&self) -> usize { self.keys.offset() } @@ -729,7 +734,7 @@ impl Array for DictionaryArray { } fn logical_nulls(&self) -> Option { - match self.values.nulls() { + match self.values.logical_nulls() { None => self.nulls().cloned(), Some(value_nulls) => { let mut builder = BooleanBufferBuilder::new(self.len()); @@ -749,6 +754,26 @@ impl Array for DictionaryArray { } } + fn logical_null_count(&self) -> usize { + match (self.keys.nulls(), self.values.logical_nulls()) { + (None, None) => 0, + (Some(key_nulls), None) => key_nulls.null_count(), + (None, Some(value_nulls)) => self + .keys + .values() + .iter() + .filter(|k| value_nulls.is_null(k.as_usize())) + .count(), + (Some(key_nulls), Some(value_nulls)) => self + .keys + .values() + .iter() + .enumerate() + .filter(|(idx, k)| key_nulls.is_null(*idx) || value_nulls.is_null(k.as_usize())) + .count(), + } + } + fn is_nullable(&self) -> bool { !self.is_empty() && (self.nulls().is_some() || self.values.is_nullable()) } @@ -1020,7 +1045,7 @@ impl AnyDictionaryArray for DictionaryArray { mod tests { use super::*; use crate::cast::as_dictionary_array; - use crate::{Int16Array, Int32Array, Int8Array}; + use crate::{Int16Array, Int32Array, Int8Array, RunArray}; use arrow_buffer::{Buffer, ToByteSlice}; #[test] @@ -1445,6 +1470,54 @@ mod tests { assert_eq!(values, &[Some(50), None, None, Some(2)]) } + #[test] + fn test_logical_nulls() -> Result<(), ArrowError> { + let values = Arc::new(RunArray::try_new( + &Int32Array::from(vec![1, 3, 7]), + &Int32Array::from(vec![Some(1), None, Some(3)]), + )?) 
as ArrayRef; + + // For this test to be meaningful, the values array need to have different nulls and logical nulls + assert_eq!(values.null_count(), 0); + assert_eq!(values.logical_null_count(), 2); + + // Construct a trivial dictionary with 1-1 mapping to underlying array + let dictionary = DictionaryArray::::try_new( + Int8Array::from((0..values.len()).map(|i| i as i8).collect::>()), + Arc::clone(&values), + )?; + + // No keys are null + assert_eq!(dictionary.null_count(), 0); + // Dictionary array values are logically nullable + assert_eq!(dictionary.logical_null_count(), values.logical_null_count()); + assert_eq!(dictionary.logical_nulls(), values.logical_nulls()); + assert!(dictionary.is_nullable()); + + // Construct a trivial dictionary with 1-1 mapping to underlying array except that key 0 is nulled out + let dictionary = DictionaryArray::::try_new( + Int8Array::from( + (0..values.len()) + .map(|i| i as i8) + .map(|i| if i == 0 { None } else { Some(i) }) + .collect::>(), + ), + Arc::clone(&values), + )?; + + // One key is null + assert_eq!(dictionary.null_count(), 1); + + // Dictionary array values are logically nullable + assert_eq!( + dictionary.logical_null_count(), + values.logical_null_count() + 1 + ); + assert!(dictionary.is_nullable()); + + Ok(()) + } + #[test] fn test_normalized_keys() { let values = vec![132, 0, 1].into(); diff --git a/arrow-array/src/array/fixed_size_binary_array.rs b/arrow-array/src/array/fixed_size_binary_array.rs index 8f1489ee4c3c..576b8012491b 100644 --- a/arrow-array/src/array/fixed_size_binary_array.rs +++ b/arrow-array/src/array/fixed_size_binary_array.rs @@ -237,6 +237,7 @@ impl FixedSizeBinaryArray { /// /// Returns error if argument has length zero, or sizes of nested slices don't match. 
#[deprecated( + since = "28.0.0", note = "This function will fail if the iterator produces only None values; prefer `try_from_sparse_iter_with_size`" )] pub fn try_from_sparse_iter(mut iter: T) -> Result @@ -602,6 +603,13 @@ impl Array for FixedSizeBinaryArray { self.len == 0 } + fn shrink_to_fit(&mut self) { + self.value_data.shrink_to_fit(); + if let Some(nulls) = &mut self.nulls { + nulls.shrink_to_fit(); + } + } + fn offset(&self) -> usize { 0 } @@ -662,7 +670,7 @@ mod tests { let array_data = ArrayData::builder(DataType::FixedSizeBinary(5)) .len(3) - .add_buffer(Buffer::from(&values[..])) + .add_buffer(Buffer::from(&values)) .build() .unwrap(); let fixed_size_binary_array = FixedSizeBinaryArray::from(array_data); @@ -691,7 +699,7 @@ mod tests { let array_data = ArrayData::builder(DataType::FixedSizeBinary(5)) .len(2) .offset(1) - .add_buffer(Buffer::from(&values[..])) + .add_buffer(Buffer::from(&values)) .build() .unwrap(); let fixed_size_binary_array = FixedSizeBinaryArray::from(array_data); @@ -721,7 +729,7 @@ mod tests { // [null, [10, 11, 12, 13]] let array_data = unsafe { ArrayData::builder(DataType::FixedSizeList( - Arc::new(Field::new("item", DataType::UInt8, false)), + Arc::new(Field::new_list_field(DataType::UInt8, false)), 4, )) .len(2) @@ -757,7 +765,7 @@ mod tests { let array_data = unsafe { ArrayData::builder(DataType::FixedSizeList( - Arc::new(Field::new("item", DataType::Binary, false)), + Arc::new(Field::new_list_field(DataType::Binary, false)), 4, )) .len(3) @@ -781,7 +789,7 @@ mod tests { let array_data = unsafe { ArrayData::builder(DataType::FixedSizeList( - Arc::new(Field::new("item", DataType::UInt8, false)), + Arc::new(Field::new_list_field(DataType::UInt8, false)), 4, )) .len(3) @@ -798,7 +806,7 @@ mod tests { let array_data = ArrayData::builder(DataType::FixedSizeBinary(5)) .len(3) - .add_buffer(Buffer::from(&values[..])) + .add_buffer(Buffer::from(&values)) .build() .unwrap(); let arr = FixedSizeBinaryArray::from(array_data); diff 
--git a/arrow-array/src/array/fixed_size_list_array.rs b/arrow-array/src/array/fixed_size_list_array.rs index 00a3144a87ad..44be442c9f85 100644 --- a/arrow-array/src/array/fixed_size_list_array.rs +++ b/arrow-array/src/array/fixed_size_list_array.rs @@ -95,7 +95,7 @@ use std::sync::Arc; /// .build() /// .unwrap(); /// let list_data_type = DataType::FixedSizeList( -/// Arc::new(Field::new("item", DataType::Int32, false)), +/// Arc::new(Field::new_list_field(DataType::Int32, false)), /// 3, /// ); /// let list_data = ArrayData::builder(list_data_type.clone()) @@ -401,6 +401,13 @@ impl Array for FixedSizeListArray { self.len == 0 } + fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(nulls) = &mut self.nulls { + nulls.shrink_to_fit(); + } + } + fn offset(&self) -> usize { 0 } @@ -487,7 +494,7 @@ mod tests { // Construct a list array from the above two let list_data_type = - DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 3); + DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int32, false)), 3); let list_data = ArrayData::builder(list_data_type.clone()) .len(3) .add_child_data(value_data.clone()) @@ -540,7 +547,7 @@ mod tests { // Construct a list array from the above two let list_data_type = - DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 3); + DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int32, false)), 3); let list_data = unsafe { ArrayData::builder(list_data_type) .len(3) @@ -569,7 +576,7 @@ mod tests { // Construct a fixed size list array from the above two let list_data_type = - DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 2); + DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int32, false)), 2); let list_data = ArrayData::builder(list_data_type) .len(5) .add_child_data(value_data.clone()) @@ -627,7 +634,7 @@ mod tests { // Construct a fixed size list array from the above two let list_data_type 
= - DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 2); + DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int32, false)), 2); let list_data = ArrayData::builder(list_data_type) .len(5) .add_child_data(value_data) @@ -650,7 +657,7 @@ mod tests { Some(4), ])); - let field = Arc::new(Field::new("item", DataType::Int32, true)); + let field = Arc::new(Field::new_list_field(DataType::Int32, true)); let list = FixedSizeListArray::new(field.clone(), 2, values.clone(), None); assert_eq!(list.len(), 3); @@ -674,7 +681,7 @@ mod tests { let err = FixedSizeListArray::try_new(field, 2, values.clone(), Some(nulls)).unwrap_err(); assert_eq!(err.to_string(), "Invalid argument error: Incorrect length of null buffer for FixedSizeListArray, expected 3 got 2"); - let field = Arc::new(Field::new("item", DataType::Int32, false)); + let field = Arc::new(Field::new_list_field(DataType::Int32, false)); let err = FixedSizeListArray::try_new(field.clone(), 2, values.clone(), None).unwrap_err(); assert_eq!(err.to_string(), "Invalid argument error: Found unmasked nulls for non-nullable FixedSizeListArray field \"item\""); @@ -682,14 +689,14 @@ mod tests { let nulls = NullBuffer::new(BooleanBuffer::new(Buffer::from([0b0000101]), 0, 3)); FixedSizeListArray::new(field, 2, values.clone(), Some(nulls)); - let field = Arc::new(Field::new("item", DataType::Int64, true)); + let field = Arc::new(Field::new_list_field(DataType::Int64, true)); let err = FixedSizeListArray::try_new(field, 2, values, None).unwrap_err(); assert_eq!(err.to_string(), "Invalid argument error: FixedSizeListArray expected data type Int64 got Int32 for \"item\""); } #[test] fn empty_fixed_size_list() { - let field = Arc::new(Field::new("item", DataType::Int32, true)); + let field = Arc::new(Field::new_list_field(DataType::Int32, true)); let nulls = NullBuffer::new_null(2); let values = new_empty_array(&DataType::Int32); let list = FixedSizeListArray::new(field.clone(), 0, values, 
Some(nulls)); diff --git a/arrow-array/src/array/list_array.rs b/arrow-array/src/array/list_array.rs index 1fab0009f2cc..bed0bdf889b2 100644 --- a/arrow-array/src/array/list_array.rs +++ b/arrow-array/src/array/list_array.rs @@ -485,6 +485,14 @@ impl Array for GenericListArray { self.value_offsets.len() <= 1 } + fn shrink_to_fit(&mut self) { + if let Some(nulls) = &mut self.nulls { + nulls.shrink_to_fit(); + } + self.values.shrink_to_fit(); + self.value_offsets.shrink_to_fit(); + } + fn offset(&self) -> usize { 0 } @@ -565,7 +573,7 @@ mod tests { // [[0, 1, 2], [3, 4, 5], [6, 7]] let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7]); let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 3, 6, 8])); - let field = Arc::new(Field::new("item", DataType::Int32, true)); + let field = Arc::new(Field::new_list_field(DataType::Int32, true)); ListArray::new(field, offsets, Arc::new(values), None) } @@ -595,7 +603,8 @@ mod tests { let value_offsets = Buffer::from([]); // Construct a list array from the above two - let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let list_data_type = + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))); let list_data = ArrayData::builder(list_data_type) .len(0) .add_buffer(value_offsets) @@ -621,7 +630,8 @@ mod tests { let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]); // Construct a list array from the above two - let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let list_data_type = + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))); let list_data = ArrayData::builder(list_data_type.clone()) .len(3) .add_buffer(value_offsets.clone()) @@ -766,7 +776,8 @@ mod tests { bit_util::set_bit(&mut null_bits, 8); // Construct a list array from the above two - let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let list_data_type = + 
DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))); let list_data = ArrayData::builder(list_data_type) .len(9) .add_buffer(value_offsets) @@ -917,7 +928,8 @@ mod tests { .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7])) .build_unchecked() }; - let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let list_data_type = + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))); let list_data = unsafe { ArrayData::builder(list_data_type) .len(3) @@ -934,7 +946,8 @@ mod tests { #[cfg(not(feature = "force_validate"))] fn test_list_array_invalid_child_array_len() { let value_offsets = Buffer::from_slice_ref([0, 2, 5, 7]); - let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let list_data_type = + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))); let list_data = unsafe { ArrayData::builder(list_data_type) .len(3) @@ -964,7 +977,8 @@ mod tests { let value_offsets = Buffer::from_slice_ref([2, 2, 5, 7]); - let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let list_data_type = + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))); let list_data = ArrayData::builder(list_data_type) .len(3) .add_buffer(value_offsets) @@ -1010,7 +1024,8 @@ mod tests { .build_unchecked() }; - let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let list_data_type = + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))); let list_data = unsafe { ArrayData::builder(list_data_type) .add_buffer(buf2) diff --git a/arrow-array/src/array/list_view_array.rs b/arrow-array/src/array/list_view_array.rs index 4e949a642701..7e52a6f3e457 100644 --- a/arrow-array/src/array/list_view_array.rs +++ b/arrow-array/src/array/list_view_array.rs @@ -326,6 +326,15 @@ impl Array for GenericListViewArray { self.value_sizes.is_empty() } + fn 
shrink_to_fit(&mut self) { + if let Some(nulls) = &mut self.nulls { + nulls.shrink_to_fit(); + } + self.values.shrink_to_fit(); + self.value_offsets.shrink_to_fit(); + self.value_sizes.shrink_to_fit(); + } + fn offset(&self) -> usize { 0 } @@ -490,7 +499,7 @@ mod tests { fn test_empty_list_view_array() { // Construct an empty value array let vec: Vec = vec![]; - let field = Arc::new(Field::new("item", DataType::Int32, true)); + let field = Arc::new(Field::new_list_field(DataType::Int32, true)); let sizes = ScalarBuffer::from(vec![]); let offsets = ScalarBuffer::from(vec![]); let values = Int32Array::from(vec); @@ -508,7 +517,7 @@ mod tests { .build() .unwrap(); - let field = Arc::new(Field::new("item", DataType::Int32, true)); + let field = Arc::new(Field::new_list_field(DataType::Int32, true)); let sizes = ScalarBuffer::from(vec![3i32, 3, 2]); let offsets = ScalarBuffer::from(vec![0i32, 3, 6]); let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7]); @@ -544,7 +553,7 @@ mod tests { .build() .unwrap(); - let field = Arc::new(Field::new("item", DataType::Int32, true)); + let field = Arc::new(Field::new_list_field(DataType::Int32, true)); let sizes = ScalarBuffer::from(vec![3i64, 3, 2]); let offsets = ScalarBuffer::from(vec![0i64, 3, 6]); let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7]); @@ -590,7 +599,7 @@ mod tests { let buffer = BooleanBuffer::new(Buffer::from(null_bits), 0, 9); let null_buffer = NullBuffer::new(buffer); - let field = Arc::new(Field::new("item", DataType::Int32, true)); + let field = Arc::new(Field::new_list_field(DataType::Int32, true)); let sizes = ScalarBuffer::from(vec![2, 0, 0, 2, 2, 0, 3, 0, 1]); let offsets = ScalarBuffer::from(vec![0, 2, 2, 2, 4, 6, 6, 9, 9]); let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); @@ -656,7 +665,7 @@ mod tests { let null_buffer = NullBuffer::new(buffer); // Construct a large list view array from the above two - let field = Arc::new(Field::new("item", DataType::Int32, true)); + 
let field = Arc::new(Field::new_list_field(DataType::Int32, true)); let sizes = ScalarBuffer::from(vec![2i64, 0, 0, 2, 2, 0, 3, 0, 1]); let offsets = ScalarBuffer::from(vec![0i64, 2, 2, 2, 4, 6, 6, 9, 9]); let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); @@ -718,7 +727,7 @@ mod tests { // Construct a buffer for value offsets, for the nested array: // [[0, 1], null, null, [2, 3], [4, 5], null, [6, 7, 8], null, [9]] // Construct a list array from the above two - let field = Arc::new(Field::new("item", DataType::Int32, true)); + let field = Arc::new(Field::new_list_field(DataType::Int32, true)); let sizes = ScalarBuffer::from(vec![2i32, 0, 0, 2, 2, 0, 3, 0, 1]); let offsets = ScalarBuffer::from(vec![0i32, 2, 2, 2, 4, 6, 6, 9, 9]); let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); @@ -741,7 +750,7 @@ mod tests { .build_unchecked() }; let list_data_type = - DataType::ListView(Arc::new(Field::new("item", DataType::Int32, false))); + DataType::ListView(Arc::new(Field::new_list_field(DataType::Int32, false))); let list_data = unsafe { ArrayData::builder(list_data_type) .len(3) @@ -759,7 +768,7 @@ mod tests { fn test_list_view_array_invalid_child_array_len() { let value_offsets = Buffer::from_slice_ref([0, 2, 5, 7]); let list_data_type = - DataType::ListView(Arc::new(Field::new("item", DataType::Int32, false))); + DataType::ListView(Arc::new(Field::new_list_field(DataType::Int32, false))); let list_data = unsafe { ArrayData::builder(list_data_type) .len(3) @@ -771,7 +780,7 @@ mod tests { #[test] fn test_list_view_array_offsets_need_not_start_at_zero() { - let field = Arc::new(Field::new("item", DataType::Int32, true)); + let field = Arc::new(Field::new_list_field(DataType::Int32, true)); let sizes = ScalarBuffer::from(vec![0i32, 0, 3]); let offsets = ScalarBuffer::from(vec![2i32, 2, 5]); let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7]); @@ -800,7 +809,7 @@ mod tests { }; let list_data_type = - 
DataType::ListView(Arc::new(Field::new("item", DataType::Int32, false))); + DataType::ListView(Arc::new(Field::new_list_field(DataType::Int32, false))); let list_data = unsafe { ArrayData::builder(list_data_type) .add_buffer(offset_buf2) @@ -942,7 +951,7 @@ mod tests { .build_unchecked() }; let list_data_type = - DataType::ListView(Arc::new(Field::new("item", DataType::Int32, false))); + DataType::ListView(Arc::new(Field::new_list_field(DataType::Int32, false))); let list_data = unsafe { ArrayData::builder(list_data_type) .len(2) @@ -976,7 +985,7 @@ mod tests { .build_unchecked() }; let list_data_type = - DataType::ListView(Arc::new(Field::new("item", DataType::Int32, false))); + DataType::ListView(Arc::new(Field::new_list_field(DataType::Int32, false))); let list_data = unsafe { ArrayData::builder(list_data_type) .len(3) @@ -1015,7 +1024,7 @@ mod tests { .build_unchecked() }; let list_data_type = - DataType::ListView(Arc::new(Field::new("item", DataType::Int32, false))); + DataType::ListView(Arc::new(Field::new_list_field(DataType::Int32, false))); let list_data = unsafe { ArrayData::builder(list_data_type) .len(3) diff --git a/arrow-array/src/array/map_array.rs b/arrow-array/src/array/map_array.rs index 254437630a44..18a7c491aa16 100644 --- a/arrow-array/src/array/map_array.rs +++ b/arrow-array/src/array/map_array.rs @@ -372,6 +372,14 @@ impl Array for MapArray { self.value_offsets.len() <= 1 } + fn shrink_to_fit(&mut self) { + if let Some(nulls) = &mut self.nulls { + nulls.shrink_to_fit(); + } + self.entries.shrink_to_fit(); + self.value_offsets.shrink_to_fit(); + } + fn offset(&self) -> usize { 0 } diff --git a/arrow-array/src/array/mod.rs b/arrow-array/src/array/mod.rs index 4a9e54a60789..23b3cb628aaf 100644 --- a/arrow-array/src/array/mod.rs +++ b/arrow-array/src/array/mod.rs @@ -76,6 +76,8 @@ mod list_view_array; pub use list_view_array::*; +use crate::iterator::ArrayIter; + /// An array in the [arrow columnar 
format](https://arrow.apache.org/docs/format/Columnar.html) pub trait Array: std::fmt::Debug + Send + Sync { /// Returns the array as [`Any`] so that it can be @@ -165,6 +167,12 @@ pub trait Array: std::fmt::Debug + Send + Sync { /// ``` fn is_empty(&self) -> bool; + /// Shrinks the capacity of any exclusively owned buffer as much as possible + /// + /// Shared or externally allocated buffers will be ignored, and + /// any buffer offsets will be preserved. + fn shrink_to_fit(&mut self) {} + /// Returns the offset into the underlying data used by this array(-slice). /// Note that the underlying data can be shared by many arrays. /// This defaults to `0`. @@ -315,8 +323,7 @@ pub trait Array: std::fmt::Debug + Send + Sync { /// even if the nulls present in [`DictionaryArray::values`] are not referenced by any key, /// and therefore would not appear in [`Array::logical_nulls`]. fn is_nullable(&self) -> bool { - // TODO this is not necessarily perfect default implementation, since null_count() and logical_null_count() are not always equivalent - self.null_count() != 0 + self.logical_null_count() != 0 } /// Returns the total number of bytes of memory pointed to by this array. @@ -364,6 +371,15 @@ impl Array for ArrayRef { self.as_ref().is_empty() } + /// For shared buffers, this is a no-op. + fn shrink_to_fit(&mut self) { + if let Some(slf) = Arc::get_mut(self) { + slf.shrink_to_fit(); + } else { + // We ignore shared buffers. + } + } + fn offset(&self) -> usize { self.as_ref().offset() } @@ -570,6 +586,40 @@ pub trait ArrayAccessor: Array { unsafe fn value_unchecked(&self, index: usize) -> Self::Item; } +/// A trait for Arrow String Arrays, currently three types are supported: +/// - `StringArray` +/// - `LargeStringArray` +/// - `StringViewArray` +/// +/// This trait helps to abstract over the different types of string arrays +/// so that we don't need to duplicate the implementation for each type. 
+pub trait StringArrayType<'a>: ArrayAccessor + Sized { + /// Returns true if all data within this string array is ASCII + fn is_ascii(&self) -> bool; + + /// Constructs a new iterator + fn iter(&self) -> ArrayIter; +} + +impl<'a, O: OffsetSizeTrait> StringArrayType<'a> for &'a GenericStringArray { + fn is_ascii(&self) -> bool { + GenericStringArray::::is_ascii(self) + } + + fn iter(&self) -> ArrayIter { + GenericStringArray::::iter(self) + } +} +impl<'a> StringArrayType<'a> for &'a StringViewArray { + fn is_ascii(&self) -> bool { + StringViewArray::is_ascii(self) + } + + fn iter(&self) -> ArrayIter { + StringViewArray::iter(self) + } +} + impl PartialEq for dyn Array + '_ { fn eq(&self, other: &Self) -> bool { self.to_data().eq(&other.to_data()) @@ -876,7 +926,7 @@ mod tests { #[test] fn test_empty_list_primitive() { - let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))); let array = new_empty_array(&data_type); let a = array.as_any().downcast_ref::().unwrap(); assert_eq!(a.len(), 0); @@ -934,7 +984,7 @@ mod tests { #[test] fn test_null_list_primitive() { - let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); let array = new_null_array(&data_type, 9); let a = array.as_any().downcast_ref::().unwrap(); assert_eq!(a.len(), 9); diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs index 7b0d6c5ca1b6..57aa23bf9040 100644 --- a/arrow-array/src/array/primitive_array.rs +++ b/arrow-array/src/array/primitive_array.rs @@ -1152,6 +1152,13 @@ impl Array for PrimitiveArray { self.values.is_empty() } + fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(nulls) = &mut self.nulls { + nulls.shrink_to_fit(); + } + } + fn offset(&self) -> usize { 0 } @@ -1480,24 +1487,6 @@ 
def_numeric_from_vec!(TimestampMicrosecondType); def_numeric_from_vec!(TimestampNanosecondType); impl PrimitiveArray { - /// Construct a timestamp array from a vec of i64 values and an optional timezone - #[deprecated(note = "Use with_timezone_opt instead")] - pub fn from_vec(data: Vec, timezone: Option) -> Self - where - Self: From>, - { - Self::from(data).with_timezone_opt(timezone) - } - - /// Construct a timestamp array from a vec of `Option` values and an optional timezone - #[deprecated(note = "Use with_timezone_opt instead")] - pub fn from_opt_vec(data: Vec>, timezone: Option) -> Self - where - Self: From>>, - { - Self::from(data).with_timezone_opt(timezone) - } - /// Returns the timezone of this array if any pub fn timezone(&self) -> Option<&str> { match self.data_type() { @@ -2296,7 +2285,7 @@ mod tests { ]; let array_data = ArrayData::builder(DataType::Decimal128(38, 6)) .len(2) - .add_buffer(Buffer::from(&values[..])) + .add_buffer(Buffer::from(&values)) .build() .unwrap(); let decimal_array = Decimal128Array::from(array_data); diff --git a/arrow-array/src/array/run_array.rs b/arrow-array/src/array/run_array.rs index dc4e6c96d9da..b340bf9a9065 100644 --- a/arrow-array/src/array/run_array.rs +++ b/arrow-array/src/array/run_array.rs @@ -330,6 +330,11 @@ impl Array for RunArray { self.run_ends.is_empty() } + fn shrink_to_fit(&mut self) { + self.run_ends.shrink_to_fit(); + self.values.shrink_to_fit(); + } + fn offset(&self) -> usize { self.run_ends.offset() } diff --git a/arrow-array/src/array/string_array.rs b/arrow-array/src/array/string_array.rs index 25581cfaa49d..ed70e5744fff 100644 --- a/arrow-array/src/array/string_array.rs +++ b/arrow-array/src/array/string_array.rs @@ -17,18 +17,12 @@ use crate::types::GenericStringType; use crate::{GenericBinaryArray, GenericByteArray, GenericListArray, OffsetSizeTrait}; -use arrow_schema::{ArrowError, DataType}; +use arrow_schema::ArrowError; /// A [`GenericByteArray`] for storing `str` pub type GenericStringArray 
= GenericByteArray>; impl GenericStringArray { - /// Get the data type of the array. - #[deprecated(note = "please use `Self::DATA_TYPE` instead")] - pub const fn get_data_type() -> DataType { - Self::DATA_TYPE - } - /// Returns the number of `Unicode Scalar Value` in the string at index `i`. /// # Performance /// This function has `O(n)` time complexity where `n` is the string length. @@ -167,7 +161,7 @@ mod tests { use crate::Array; use arrow_buffer::Buffer; use arrow_data::ArrayData; - use arrow_schema::Field; + use arrow_schema::{DataType, Field}; use std::sync::Arc; #[test] @@ -382,17 +376,15 @@ mod tests { let child_data = ArrayData::builder(DataType::UInt8) .len(15) .offset(5) - .add_buffer(Buffer::from(&values[..])) + .add_buffer(Buffer::from(values)) .build() .unwrap(); let offsets = [0, 5, 8, 15].map(|n| O::from_usize(n).unwrap()); let null_buffer = Buffer::from_slice_ref([0b101]); - let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( - "item", - DataType::UInt8, - false, - ))); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new( + Field::new_list_field(DataType::UInt8, false), + )); // [None, Some("Parquet")] let array_data = ArrayData::builder(data_type) @@ -427,7 +419,7 @@ mod tests { let values = b"HelloArrow"; let child_data = ArrayData::builder(DataType::UInt8) .len(10) - .add_buffer(Buffer::from(&values[..])) + .add_buffer(Buffer::from(values)) .null_bit_buffer(Some(Buffer::from_slice_ref([0b1010101010]))) .build() .unwrap(); @@ -436,11 +428,9 @@ mod tests { // It is possible to create a null struct containing a non-nullable child // see https://github.com/apache/arrow-rs/pull/3244 for details - let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( - "item", - DataType::UInt8, - true, - ))); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new( + Field::new_list_field(DataType::UInt8, true), + )); // [None, Some(b"Parquet")] let array_data = 
ArrayData::builder(data_type) @@ -469,16 +459,14 @@ mod tests { let values = b"HelloArrow"; let child_data = ArrayData::builder(DataType::UInt16) .len(5) - .add_buffer(Buffer::from(&values[..])) + .add_buffer(Buffer::from(values)) .build() .unwrap(); let offsets = [0, 2, 3].map(|n| O::from_usize(n).unwrap()); - let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( - "item", - DataType::UInt16, - false, - ))); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new( + Field::new_list_field(DataType::UInt16, false), + )); let array_data = ArrayData::builder(data_type) .len(2) diff --git a/arrow-array/src/array/struct_array.rs b/arrow-array/src/array/struct_array.rs index 41eb8235e540..de6d9c699d22 100644 --- a/arrow-array/src/array/struct_array.rs +++ b/arrow-array/src/array/struct_array.rs @@ -239,12 +239,6 @@ impl StructArray { &self.fields } - /// Returns child array refs of the struct array - #[deprecated(note = "Use columns().to_vec()")] - pub fn columns_ref(&self) -> Vec { - self.columns().to_vec() - } - /// Return field names in this struct array pub fn column_names(&self) -> Vec<&str> { match self.data_type() { @@ -370,6 +364,13 @@ impl Array for StructArray { self.len == 0 } + fn shrink_to_fit(&mut self) { + if let Some(nulls) = &mut self.nulls { + nulls.shrink_to_fit(); + } + self.fields.iter_mut().for_each(|n| n.shrink_to_fit()); + } + fn offset(&self) -> usize { 0 } diff --git a/arrow-array/src/array/union_array.rs b/arrow-array/src/array/union_array.rs index 3c6da5a7b5c0..b442395b4978 100644 --- a/arrow-array/src/array/union_array.rs +++ b/arrow-array/src/array/union_array.rs @@ -653,6 +653,17 @@ impl UnionArray { } } } + + /// Returns a vector of tuples containing each field's type_id and its logical null buffer. + /// Only fields with non-zero null counts are included. 
+ fn fields_logical_nulls(&self) -> Vec<(i8, NullBuffer)> { + self.fields + .iter() + .enumerate() + .filter_map(|(type_id, field)| Some((type_id as i8, field.as_ref()?.logical_nulls()?))) + .filter(|(_, nulls)| nulls.null_count() > 0) + .collect() + } } impl From for UnionArray { @@ -744,6 +755,17 @@ impl Array for UnionArray { self.type_ids.is_empty() } + fn shrink_to_fit(&mut self) { + self.type_ids.shrink_to_fit(); + if let Some(offsets) = &mut self.offsets { + offsets.shrink_to_fit(); + } + for array in self.fields.iter_mut().flatten() { + array.shrink_to_fit(); + } + self.fields.shrink_to_fit(); + } + fn offset(&self) -> usize { 0 } @@ -768,11 +790,7 @@ impl Array for UnionArray { .flatten(); } - let logical_nulls = fields - .iter() - .filter_map(|(type_id, _)| Some((type_id, self.child(type_id).logical_nulls()?))) - .filter(|(_, nulls)| nulls.null_count() > 0) - .collect::>(); + let logical_nulls = self.fields_logical_nulls(); if logical_nulls.is_empty() { return None; @@ -1941,15 +1959,14 @@ mod tests { let array = UnionArray::try_new(union_fields(), type_ids, Some(offsets), children).unwrap(); - let result = array.logical_nulls(); + let expected = BooleanBuffer::from(vec![true, true, true, false, false, false]); - let expected = NullBuffer::from(vec![true, true, true, false, false, false]); - assert_eq!(Some(expected), result); + assert_eq!(expected, array.logical_nulls().unwrap().into_inner()); + assert_eq!(expected, array.gather_nulls(array.fields_logical_nulls())); } #[test] fn test_sparse_union_logical_nulls_mask_all_nulls_skip_one() { - // If we used union_fields() (3 fields with nulls), the choosen strategy would be Gather on x86 without any specified target feature e.g CI runtime let fields: UnionFields = [ (1, Arc::new(Field::new("A", DataType::Int32, true))), (3, Arc::new(Field::new("B", DataType::Float64, true))), @@ -1966,10 +1983,13 @@ mod tests { let array = UnionArray::try_new(fields.clone(), type_ids, None, children).unwrap(); - let result = 
array.logical_nulls(); + let expected = BooleanBuffer::from(vec![false, false, true, false]); - let expected = NullBuffer::from(vec![false, false, true, false]); - assert_eq!(Some(expected), result); + assert_eq!(expected, array.logical_nulls().unwrap().into_inner()); + assert_eq!( + expected, + array.mask_sparse_all_with_nulls_skip_one(array.fields_logical_nulls()) + ); //like above, but repeated to genereate two exact bitmasks and a non empty remainder let len = 2 * 64 + 32; @@ -1986,12 +2006,15 @@ mod tests { ) .unwrap(); - let result = array.logical_nulls(); - let expected = - NullBuffer::from_iter([false, false, true, false].into_iter().cycle().take(len)); + BooleanBuffer::from_iter([false, false, true, false].into_iter().cycle().take(len)); + assert_eq!(array.len(), len); - assert_eq!(Some(expected), result); + assert_eq!(expected, array.logical_nulls().unwrap().into_inner()); + assert_eq!( + expected, + array.mask_sparse_all_with_nulls_skip_one(array.fields_logical_nulls()) + ); } #[test] @@ -2010,10 +2033,13 @@ mod tests { let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap(); - let result = array.logical_nulls(); + let expected = BooleanBuffer::from(vec![true, true, true, true, false, false]); - let expected = NullBuffer::from(vec![true, true, true, true, false, false]); - assert_eq!(Some(expected), result); + assert_eq!(expected, array.logical_nulls().unwrap().into_inner()); + assert_eq!( + expected, + array.mask_sparse_skip_without_nulls(array.fields_logical_nulls()) + ); //like above, but repeated to genereate two exact bitmasks and a non empty remainder let len = 2 * 64 + 32; @@ -2031,16 +2057,19 @@ mod tests { let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap(); - let result = array.logical_nulls(); - - let expected = NullBuffer::from_iter( + let expected = BooleanBuffer::from_iter( [true, true, true, true, false, true] .into_iter() .cycle() .take(len), ); + assert_eq!(array.len(), len); - 
assert_eq!(Some(expected), result); + assert_eq!(expected, array.logical_nulls().unwrap().into_inner()); + assert_eq!( + expected, + array.mask_sparse_skip_without_nulls(array.fields_logical_nulls()) + ); } #[test] @@ -2059,10 +2088,13 @@ mod tests { let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap(); - let result = array.logical_nulls(); + let expected = BooleanBuffer::from(vec![false, false, true, true, false, false]); - let expected = NullBuffer::from(vec![false, false, true, true, false, false]); - assert_eq!(Some(expected), result); + assert_eq!(expected, array.logical_nulls().unwrap().into_inner()); + assert_eq!( + expected, + array.mask_sparse_skip_fully_null(array.fields_logical_nulls()) + ); //like above, but repeated to genereate two exact bitmasks and a non empty remainder let len = 2 * 64 + 32; @@ -2080,16 +2112,19 @@ mod tests { let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap(); - let result = array.logical_nulls(); - - let expected = NullBuffer::from_iter( + let expected = BooleanBuffer::from_iter( [false, false, true, true, false, false] .into_iter() .cycle() .take(len), ); + assert_eq!(array.len(), len); - assert_eq!(Some(expected), result); + assert_eq!(expected, array.logical_nulls().unwrap().into_inner()); + assert_eq!( + expected, + array.mask_sparse_skip_fully_null(array.fields_logical_nulls()) + ); } #[test] @@ -2125,11 +2160,10 @@ mod tests { ) .unwrap(); - let result = array.logical_nulls(); - - let expected = NullBuffer::from(vec![true, false, true, false]); + let expected = BooleanBuffer::from(vec![true, false, true, false]); - assert_eq!(Some(expected), result); + assert_eq!(expected, array.logical_nulls().unwrap().into_inner()); + assert_eq!(expected, array.gather_nulls(array.fields_logical_nulls())); } fn union_fields() -> UnionFields { diff --git a/arrow-array/src/builder/fixed_size_list_builder.rs b/arrow-array/src/builder/fixed_size_list_builder.rs index 
5dff67650687..5c142b277d14 100644 --- a/arrow-array/src/builder/fixed_size_list_builder.rs +++ b/arrow-array/src/builder/fixed_size_list_builder.rs @@ -182,7 +182,7 @@ where let field = self .field .clone() - .unwrap_or_else(|| Arc::new(Field::new("item", values.data_type().clone(), true))); + .unwrap_or_else(|| Arc::new(Field::new_list_field(values.data_type().clone(), true))); FixedSizeListArray::new(field, self.list_len, values, nulls) } @@ -204,7 +204,7 @@ where let field = self .field .clone() - .unwrap_or_else(|| Arc::new(Field::new("item", values.data_type().clone(), true))); + .unwrap_or_else(|| Arc::new(Field::new_list_field(values.data_type().clone(), true))); FixedSizeListArray::new(field, self.list_len, values, nulls) } diff --git a/arrow-array/src/builder/generic_bytes_dictionary_builder.rs b/arrow-array/src/builder/generic_bytes_dictionary_builder.rs index bb0fb5e91be2..ead151d5ceea 100644 --- a/arrow-array/src/builder/generic_bytes_dictionary_builder.rs +++ b/arrow-array/src/builder/generic_bytes_dictionary_builder.rs @@ -17,7 +17,7 @@ use crate::builder::{ArrayBuilder, GenericByteBuilder, PrimitiveBuilder}; use crate::types::{ArrowDictionaryKeyType, ByteArrayType, GenericBinaryType, GenericStringType}; -use crate::{Array, ArrayRef, DictionaryArray, GenericByteArray}; +use crate::{Array, ArrayRef, DictionaryArray, GenericByteArray, TypedDictionaryArray}; use arrow_buffer::ArrowNativeType; use arrow_schema::{ArrowError, DataType}; use hashbrown::HashTable; @@ -305,6 +305,63 @@ where }; } + /// Extends builder with an existing dictionary array. 
+ /// + /// This is the same as [`Self::extend`] but is faster as it translates + /// the dictionary values once rather than doing a lookup for each item in the iterator + /// + /// when dictionary values are null (the actual mapped values) the keys are null + /// + pub fn extend_dictionary( + &mut self, + dictionary: &TypedDictionaryArray>, + ) -> Result<(), ArrowError> { + let values = dictionary.values(); + + let v_len = values.len(); + let k_len = dictionary.keys().len(); + if v_len == 0 && k_len == 0 { + return Ok(()); + } + + // All nulls + if v_len == 0 { + self.append_nulls(k_len); + return Ok(()); + } + + if k_len == 0 { + return Err(ArrowError::InvalidArgumentError( + "Dictionary keys should not be empty when values are not empty".to_string(), + )); + } + + // Orphan values will be carried over to the new dictionary + let mapped_values = values + .iter() + // Dictionary values can technically be null, so we need to handle that + .map(|dict_value| { + dict_value + .map(|dict_value| self.get_or_insert_key(dict_value)) + .transpose() + }) + .collect::, _>>()?; + + // Just insert the keys without additional lookups + dictionary.keys().iter().for_each(|key| match key { + None => self.append_null(), + Some(original_dict_index) => { + let index = original_dict_index.as_usize().min(v_len - 1); + match mapped_values[index] { + None => self.append_null(), + Some(mapped_value) => self.keys_builder.append_value(mapped_value), + } + } + }); + + Ok(()) + } + /// Builds the `DictionaryArray` and reset this builder. 
pub fn finish(&mut self) -> DictionaryArray { self.dedup.clear(); @@ -445,8 +502,9 @@ mod tests { use super::*; use crate::array::Int8Array; + use crate::cast::AsArray; use crate::types::{Int16Type, Int32Type, Int8Type, Utf8Type}; - use crate::{BinaryArray, StringArray}; + use crate::{ArrowPrimitiveType, BinaryArray, StringArray}; fn test_bytes_dictionary_builder(values: Vec<&T::Native>) where @@ -664,4 +722,129 @@ mod tests { assert_eq!(dict.keys().values(), &[0, 1, 2, 0, 1, 2, 2, 3, 0]); assert_eq!(dict.values().len(), 4); } + + #[test] + fn test_extend_dictionary() { + let some_dict = { + let mut builder = GenericByteDictionaryBuilder::::new(); + builder.extend(["a", "b", "c", "a", "b", "c"].into_iter().map(Some)); + builder.extend([None::<&str>]); + builder.extend(["c", "d", "a"].into_iter().map(Some)); + builder.append_null(); + builder.finish() + }; + + let mut builder = GenericByteDictionaryBuilder::::new(); + builder.extend(["e", "e", "f", "e", "d"].into_iter().map(Some)); + builder + .extend_dictionary(&some_dict.downcast_dict().unwrap()) + .unwrap(); + let dict = builder.finish(); + + assert_eq!(dict.values().len(), 6); + + let values = dict + .downcast_dict::>() + .unwrap() + .into_iter() + .collect::>(); + + assert_eq!( + values, + [ + Some("e"), + Some("e"), + Some("f"), + Some("e"), + Some("d"), + Some("a"), + Some("b"), + Some("c"), + Some("a"), + Some("b"), + Some("c"), + None, + Some("c"), + Some("d"), + Some("a"), + None + ] + ); + } + #[test] + fn test_extend_dictionary_with_null_in_mapped_value() { + let some_dict = { + let mut values_builder = GenericByteBuilder::::new(); + let mut keys_builder = PrimitiveBuilder::::new(); + + // Manually build a dictionary values that the mapped values have null + values_builder.append_null(); + keys_builder.append_value(0); + values_builder.append_value("I like worm hugs"); + keys_builder.append_value(1); + + let values = values_builder.finish(); + let keys = keys_builder.finish(); + + let data_type = 
DataType::Dictionary( + Box::new(Int32Type::DATA_TYPE), + Box::new(Utf8Type::DATA_TYPE), + ); + + let builder = keys + .into_data() + .into_builder() + .data_type(data_type) + .child_data(vec![values.into_data()]); + + DictionaryArray::from(unsafe { builder.build_unchecked() }) + }; + + let some_dict_values = some_dict.values().as_string::(); + assert_eq!( + some_dict_values.into_iter().collect::>(), + &[None, Some("I like worm hugs")] + ); + + let mut builder = GenericByteDictionaryBuilder::::new(); + builder + .extend_dictionary(&some_dict.downcast_dict().unwrap()) + .unwrap(); + let dict = builder.finish(); + + assert_eq!(dict.values().len(), 1); + + let values = dict + .downcast_dict::>() + .unwrap() + .into_iter() + .collect::>(); + + assert_eq!(values, [None, Some("I like worm hugs")]); + } + + #[test] + fn test_extend_all_null_dictionary() { + let some_dict = { + let mut builder = GenericByteDictionaryBuilder::::new(); + builder.append_nulls(2); + builder.finish() + }; + + let mut builder = GenericByteDictionaryBuilder::::new(); + builder + .extend_dictionary(&some_dict.downcast_dict().unwrap()) + .unwrap(); + let dict = builder.finish(); + + assert_eq!(dict.values().len(), 0); + + let values = dict + .downcast_dict::>() + .unwrap() + .into_iter() + .collect::>(); + + assert_eq!(values, [None, None]); + } } diff --git a/arrow-array/src/builder/generic_bytes_view_builder.rs b/arrow-array/src/builder/generic_bytes_view_builder.rs index d12c2b7db468..7268e751b149 100644 --- a/arrow-array/src/builder/generic_bytes_view_builder.rs +++ b/arrow-array/src/builder/generic_bytes_view_builder.rs @@ -136,7 +136,7 @@ impl GenericByteViewBuilder { /// Override the size of buffers to allocate for holding string data /// Use `with_fixed_block_size` instead. 
- #[deprecated(note = "Use `with_fixed_block_size` instead")] + #[deprecated(since = "53.0.0", note = "Use `with_fixed_block_size` instead")] pub fn with_block_size(self, block_size: u32) -> Self { self.with_fixed_block_size(block_size) } diff --git a/arrow-array/src/builder/generic_list_builder.rs b/arrow-array/src/builder/generic_list_builder.rs index a7d16f45f53b..a9c88ec6c586 100644 --- a/arrow-array/src/builder/generic_list_builder.rs +++ b/arrow-array/src/builder/generic_list_builder.rs @@ -49,7 +49,6 @@ use std::sync::Arc; /// builder.append(true); /// /// // Null -/// builder.values().append_value("?"); // irrelevant /// builder.append(false); /// /// // [D] @@ -70,15 +69,14 @@ use std::sync::Arc; /// array.values().as_ref(), /// &StringArray::from(vec![ /// Some("A"), Some("B"), Some("C"), -/// Some("?"), Some("D"), None, -/// Some("F") +/// Some("D"), None, Some("F") /// ]) /// ); /// /// // Offsets are indexes into the values array /// assert_eq!( /// array.value_offsets(), -/// &[0, 3, 3, 4, 5, 7] +/// &[0, 3, 3, 3, 4, 6] /// ); /// ``` /// @@ -299,7 +297,7 @@ where let field = match &self.field { Some(f) => f.clone(), - None => Arc::new(Field::new("item", values.data_type().clone(), true)), + None => Arc::new(Field::new_list_field(values.data_type().clone(), true)), }; GenericListArray::new(field, offsets, values, nulls) @@ -316,7 +314,7 @@ where let field = match &self.field { Some(f) => f.clone(), - None => Arc::new(Field::new("item", values.data_type().clone(), true)), + None => Arc::new(Field::new_list_field(values.data_type().clone(), true)), }; GenericListArray::new(field, offsets, values, nulls) @@ -586,7 +584,7 @@ mod tests { fn test_boxed_list_list_array_builder() { // This test is same as `test_list_list_array_builder` but uses boxed builders. 
let values_builder = make_builder( - &DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + &DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))), 10, ); test_boxed_generic_list_generic_list_array_builder::(values_builder); @@ -596,7 +594,7 @@ mod tests { fn test_boxed_large_list_large_list_array_builder() { // This test is same as `test_list_list_array_builder` but uses boxed builders. let values_builder = make_builder( - &DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, true))), + &DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, true))), 10, ); test_boxed_generic_list_generic_list_array_builder::(values_builder); @@ -791,7 +789,7 @@ mod tests { #[test] #[should_panic(expected = "Non-nullable field of ListArray \\\"item\\\" cannot contain nulls")] fn test_checks_nullability() { - let field = Arc::new(Field::new("item", DataType::Int32, false)); + let field = Arc::new(Field::new_list_field(DataType::Int32, false)); let mut builder = ListBuilder::new(Int32Builder::new()).with_field(field.clone()); builder.append_value([Some(1), None]); builder.finish(); @@ -800,7 +798,7 @@ mod tests { #[test] #[should_panic(expected = "ListArray expected data type Int64 got Int32")] fn test_checks_data_type() { - let field = Arc::new(Field::new("item", DataType::Int64, false)); + let field = Arc::new(Field::new_list_field(DataType::Int64, false)); let mut builder = ListBuilder::new(Int32Builder::new()).with_field(field.clone()); builder.append_value([Some(1)]); builder.finish(); diff --git a/arrow-array/src/builder/generic_list_view_builder.rs b/arrow-array/src/builder/generic_list_view_builder.rs new file mode 100644 index 000000000000..5aaf9efefe24 --- /dev/null +++ b/arrow-array/src/builder/generic_list_view_builder.rs @@ -0,0 +1,707 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::builder::ArrayBuilder; +use crate::{ArrayRef, GenericListViewArray, OffsetSizeTrait}; +use arrow_buffer::{Buffer, BufferBuilder, NullBufferBuilder, ScalarBuffer}; +use arrow_schema::{Field, FieldRef}; +use std::any::Any; +use std::sync::Arc; + +/// Builder for [`GenericListViewArray`] +#[derive(Debug)] +pub struct GenericListViewBuilder { + offsets_builder: BufferBuilder, + sizes_builder: BufferBuilder, + null_buffer_builder: NullBufferBuilder, + values_builder: T, + field: Option, + current_offset: OffsetSize, +} + +impl Default for GenericListViewBuilder { + fn default() -> Self { + Self::new(T::default()) + } +} + +impl ArrayBuilder + for GenericListViewBuilder +{ + /// Returns the builder as a non-mutable `Any` reference. + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns the builder as a mutable `Any` reference. + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + /// Returns the boxed builder as a box of `Any`. + fn into_box_any(self: Box) -> Box { + self + } + + /// Returns the number of array slots in the builder + fn len(&self) -> usize { + self.null_buffer_builder.len() + } + + /// Builds the array and reset this builder. 
+ fn finish(&mut self) -> ArrayRef { + Arc::new(self.finish()) + } + + /// Builds the array without resetting the builder. + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } +} + +impl GenericListViewBuilder { + /// Creates a new [`GenericListViewBuilder`] from a given values array builder + pub fn new(values_builder: T) -> Self { + let capacity = values_builder.len(); + Self::with_capacity(values_builder, capacity) + } + + /// Creates a new [`GenericListViewBuilder`] from a given values array builder + /// `capacity` is the number of items to pre-allocate space for in this builder + pub fn with_capacity(values_builder: T, capacity: usize) -> Self { + let offsets_builder = BufferBuilder::::new(capacity); + let sizes_builder = BufferBuilder::::new(capacity); + Self { + offsets_builder, + null_buffer_builder: NullBufferBuilder::new(capacity), + values_builder, + sizes_builder, + field: None, + current_offset: OffsetSize::zero(), + } + } + + /// + /// By default a nullable field is created with the name `item` + /// + /// Note: [`Self::finish`] and [`Self::finish_cloned`] will panic if the + /// field's data type does not match that of `T` + pub fn with_field(self, field: impl Into) -> Self { + Self { + field: Some(field.into()), + ..self + } + } +} + +impl GenericListViewBuilder +where + T: 'static, +{ + /// Returns the child array builder as a mutable reference. + /// + /// This mutable reference can be used to append values into the child array builder, + /// but you must call [`append`](#method.append) to delimit each distinct list value. 
+ pub fn values(&mut self) -> &mut T { + &mut self.values_builder + } + + /// Returns the child array builder as an immutable reference + pub fn values_ref(&self) -> &T { + &self.values_builder + } + + /// Finish the current variable-length list array slot + /// + /// # Panics + /// + /// Panics if the length of [`Self::values`] exceeds `OffsetSize::MAX` + #[inline] + pub fn append(&mut self, is_valid: bool) { + self.offsets_builder.append(self.current_offset); + self.sizes_builder.append( + OffsetSize::from_usize( + self.values_builder.len() - self.current_offset.to_usize().unwrap(), + ) + .unwrap(), + ); + self.null_buffer_builder.append(is_valid); + self.current_offset = OffsetSize::from_usize(self.values_builder.len()).unwrap(); + } + + /// Append value into this [`GenericListViewBuilder`] + #[inline] + pub fn append_value(&mut self, i: I) + where + T: Extend>, + I: IntoIterator>, + { + self.extend(std::iter::once(Some(i))) + } + + /// Append a null to this [`GenericListViewBuilder`] + /// + /// See [`Self::append_value`] for an example use. + #[inline] + pub fn append_null(&mut self) { + self.offsets_builder.append(self.current_offset); + self.sizes_builder + .append(OffsetSize::from_usize(0).unwrap()); + self.null_buffer_builder.append_null(); + } + + /// Appends an optional value into this [`GenericListViewBuilder`] + /// + /// If `Some` calls [`Self::append_value`] otherwise calls [`Self::append_null`] + #[inline] + pub fn append_option(&mut self, i: Option) + where + T: Extend>, + I: IntoIterator>, + { + match i { + Some(i) => self.append_value(i), + None => self.append_null(), + } + } + + /// Builds the [`GenericListViewArray`] and reset this builder. 
+ pub fn finish(&mut self) -> GenericListViewArray { + let values = self.values_builder.finish(); + let nulls = self.null_buffer_builder.finish(); + let offsets = self.offsets_builder.finish(); + self.current_offset = OffsetSize::zero(); + + // Safety: Safe by construction + let offsets = ScalarBuffer::from(offsets); + let sizes = self.sizes_builder.finish(); + let sizes = ScalarBuffer::from(sizes); + let field = match &self.field { + Some(f) => f.clone(), + None => Arc::new(Field::new("item", values.data_type().clone(), true)), + }; + GenericListViewArray::new(field, offsets, sizes, values, nulls) + } + + /// Builds the [`GenericListViewArray`] without resetting the builder. + pub fn finish_cloned(&self) -> GenericListViewArray { + let values = self.values_builder.finish_cloned(); + let nulls = self.null_buffer_builder.finish_cloned(); + + let offsets = Buffer::from_slice_ref(self.offsets_builder.as_slice()); + // Safety: safe by construction + let offsets = ScalarBuffer::from(offsets); + + let sizes = Buffer::from_slice_ref(self.sizes_builder.as_slice()); + let sizes = ScalarBuffer::from(sizes); + + let field = match &self.field { + Some(f) => f.clone(), + None => Arc::new(Field::new("item", values.data_type().clone(), true)), + }; + + GenericListViewArray::new(field, offsets, sizes, values, nulls) + } + + /// Returns the current offsets buffer as a slice + pub fn offsets_slice(&self) -> &[OffsetSize] { + self.offsets_builder.as_slice() + } +} + +impl Extend> for GenericListViewBuilder +where + O: OffsetSizeTrait, + B: ArrayBuilder + Extend, + V: IntoIterator, +{ + #[inline] + fn extend>>(&mut self, iter: T) { + for v in iter { + match v { + Some(elements) => { + self.values_builder.extend(elements); + self.append(true); + } + None => self.append(false), + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::builder::{make_builder, Int32Builder, ListViewBuilder}; + use crate::cast::AsArray; + use crate::types::Int32Type; + use crate::{Array, 
Int32Array}; + use arrow_schema::DataType; + + fn test_generic_list_view_array_builder_impl() { + let values_builder = Int32Builder::with_capacity(10); + let mut builder = GenericListViewBuilder::::new(values_builder); + + // [[0, 1, 2], [3, 4, 5], [6, 7]] + builder.values().append_value(0); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true); + builder.values().append_value(3); + builder.values().append_value(4); + builder.values().append_value(5); + builder.append(true); + builder.values().append_value(6); + builder.values().append_value(7); + builder.append(true); + let list_array = builder.finish(); + + let list_values = list_array.values().as_primitive::(); + assert_eq!(list_values.values(), &[0, 1, 2, 3, 4, 5, 6, 7]); + assert_eq!(list_array.value_offsets(), [0, 3, 6].map(O::usize_as)); + assert_eq!(list_array.value_sizes(), [3, 3, 2].map(O::usize_as)); + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(3, list_array.len()); + assert_eq!(0, list_array.null_count()); + assert_eq!(O::from_usize(6).unwrap(), list_array.value_offsets()[2]); + assert_eq!(O::from_usize(2).unwrap(), list_array.value_sizes()[2]); + for i in 0..2 { + assert!(list_array.is_valid(i)); + assert!(!list_array.is_null(i)); + } + } + + #[test] + fn test_list_view_array_builder() { + test_generic_list_view_array_builder_impl::() + } + + #[test] + fn test_large_list_view_array_builder() { + test_generic_list_view_array_builder_impl::() + } + + fn test_generic_list_view_array_builder_nulls_impl() { + let values_builder = Int32Builder::with_capacity(10); + let mut builder = GenericListViewBuilder::::new(values_builder); + + // [[0, 1, 2], null, [3, null, 5], [6, 7]] + builder.values().append_value(0); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true); + builder.append(false); + builder.values().append_value(3); + builder.values().append_null(); + builder.values().append_value(5); + 
builder.append(true); + builder.values().append_value(6); + builder.values().append_value(7); + builder.append(true); + + let list_array = builder.finish(); + + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(4, list_array.len()); + assert_eq!(1, list_array.null_count()); + assert_eq!(O::from_usize(3).unwrap(), list_array.value_offsets()[2]); + assert_eq!(O::from_usize(3).unwrap(), list_array.value_sizes()[2]); + } + + #[test] + fn test_list_view_array_builder_nulls() { + test_generic_list_view_array_builder_nulls_impl::() + } + + #[test] + fn test_large_list_view_array_builder_nulls() { + test_generic_list_view_array_builder_nulls_impl::() + } + + #[test] + fn test_list_view_array_builder_finish() { + let values_builder = Int32Array::builder(5); + let mut builder = ListViewBuilder::new(values_builder); + + builder.values().append_slice(&[1, 2, 3]); + builder.append(true); + builder.values().append_slice(&[4, 5, 6]); + builder.append(true); + + let mut arr = builder.finish(); + assert_eq!(2, arr.len()); + assert!(builder.is_empty()); + + builder.values().append_slice(&[7, 8, 9]); + builder.append(true); + arr = builder.finish(); + assert_eq!(1, arr.len()); + assert!(builder.is_empty()); + } + + #[test] + fn test_list_view_array_builder_finish_cloned() { + let values_builder = Int32Array::builder(5); + let mut builder = ListViewBuilder::new(values_builder); + + builder.values().append_slice(&[1, 2, 3]); + builder.append(true); + builder.values().append_slice(&[4, 5, 6]); + builder.append(true); + + let mut arr = builder.finish_cloned(); + assert_eq!(2, arr.len()); + assert!(!builder.is_empty()); + + builder.values().append_slice(&[7, 8, 9]); + builder.append(true); + arr = builder.finish(); + assert_eq!(3, arr.len()); + assert!(builder.is_empty()); + } + + #[test] + fn test_list_view_list_view_array_builder() { + let primitive_builder = Int32Builder::with_capacity(10); + let values_builder = ListViewBuilder::new(primitive_builder); + let mut 
builder = ListViewBuilder::new(values_builder); + + // [[[1, 2], [3, 4]], [[5, 6, 7], null, [8]], null, [[9, 10]]] + builder.values().values().append_value(1); + builder.values().values().append_value(2); + builder.values().append(true); + builder.values().values().append_value(3); + builder.values().values().append_value(4); + builder.values().append(true); + builder.append(true); + + builder.values().values().append_value(5); + builder.values().values().append_value(6); + builder.values().values().append_value(7); + builder.values().append(true); + builder.values().append(false); + builder.values().values().append_value(8); + builder.values().append(true); + builder.append(true); + + builder.append(false); + + builder.values().values().append_value(9); + builder.values().values().append_value(10); + builder.values().append(true); + builder.append(true); + + let l1 = builder.finish(); + + assert_eq!(4, l1.len()); + assert_eq!(1, l1.null_count()); + + assert_eq!(l1.value_offsets(), &[0, 2, 5, 5]); + assert_eq!(l1.value_sizes(), &[2, 3, 0, 1]); + + let l2 = l1.values().as_list_view::(); + + assert_eq!(6, l2.len()); + assert_eq!(1, l2.null_count()); + assert_eq!(l2.value_offsets(), &[0, 2, 4, 7, 7, 8]); + assert_eq!(l2.value_sizes(), &[2, 2, 3, 0, 1, 2]); + + let i1 = l2.values().as_primitive::(); + assert_eq!(10, i1.len()); + assert_eq!(0, i1.null_count()); + assert_eq!(i1.values(), &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + } + + #[test] + fn test_extend() { + let mut builder = ListViewBuilder::new(Int32Builder::new()); + builder.extend([ + Some(vec![Some(1), Some(2), Some(7), None]), + Some(vec![]), + Some(vec![Some(4), Some(5)]), + None, + ]); + + let array = builder.finish(); + assert_eq!(array.value_offsets(), [0, 4, 4, 6]); + assert_eq!(array.value_sizes(), [4, 0, 2, 0]); + assert_eq!(array.null_count(), 1); + assert!(array.is_null(3)); + let elements = array.values().as_primitive::(); + assert_eq!(elements.values(), &[1, 2, 7, 0, 4, 5]); + 
assert_eq!(elements.null_count(), 1); + assert!(elements.is_null(3)); + } + + #[test] + fn test_boxed_primitive_array_builder() { + let values_builder = make_builder(&DataType::Int32, 5); + let mut builder = ListViewBuilder::new(values_builder); + + builder + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_slice(&[1, 2, 3]); + builder.append(true); + + builder + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_slice(&[4, 5, 6]); + builder.append(true); + + let arr = builder.finish(); + assert_eq!(2, arr.len()); + + let elements = arr.values().as_primitive::(); + assert_eq!(elements.values(), &[1, 2, 3, 4, 5, 6]); + } + + #[test] + fn test_boxed_list_view_list_view_array_builder() { + // This test is same as `test_list_list_array_builder` but uses boxed builders. + let values_builder = make_builder( + &DataType::ListView(Arc::new(Field::new("item", DataType::Int32, true))), + 10, + ); + test_boxed_generic_list_view_generic_list_view_array_builder::(values_builder); + } + + #[test] + fn test_boxed_large_list_view_large_list_view_array_builder() { + // This test is same as `test_list_list_array_builder` but uses boxed builders. 
+ let values_builder = make_builder( + &DataType::LargeListView(Arc::new(Field::new("item", DataType::Int32, true))), + 10, + ); + test_boxed_generic_list_view_generic_list_view_array_builder::(values_builder); + } + + fn test_boxed_generic_list_view_generic_list_view_array_builder( + values_builder: Box, + ) where + O: OffsetSizeTrait + PartialEq, + { + let mut builder: GenericListViewBuilder> = + GenericListViewBuilder::>::new(values_builder); + + // [[[1, 2], [3, 4]], [[5, 6, 7], null, [8]], null, [[9, 10]]] + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListViewBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(1); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListViewBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(2); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListViewBuilder") + .append(true); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListViewBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(3); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListViewBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(4); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListViewBuilder") + .append(true); + builder.append(true); + + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListViewBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(5); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListViewBuilder") + .values() + 
.as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(6); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListViewBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an (Large)ListViewBuilder") + .append_value(7); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListViewBuilder") + .append(true); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListViewBuilder") + .append(false); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListViewBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(8); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListViewBuilder") + .append(true); + builder.append(true); + + builder.append(false); + + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListViewBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(9); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListViewBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(10); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListViewBuilder") + .append(true); + builder.append(true); + + let l1 = builder.finish(); + assert_eq!(4, l1.len()); + assert_eq!(1, l1.null_count()); + assert_eq!(l1.value_offsets(), &[0, 2, 5, 5].map(O::usize_as)); + assert_eq!(l1.value_sizes(), &[2, 3, 0, 1].map(O::usize_as)); + + let l2 = l1.values().as_list_view::(); + assert_eq!(6, l2.len()); + assert_eq!(1, l2.null_count()); + assert_eq!(l2.value_offsets(), &[0, 2, 4, 7, 7, 8].map(O::usize_as)); + assert_eq!(l2.value_sizes(), &[2, 2, 3, 0, 1, 
2].map(O::usize_as)); + + let i1 = l2.values().as_primitive::(); + assert_eq!(10, i1.len()); + assert_eq!(0, i1.null_count()); + assert_eq!(i1.values(), &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + } + + #[test] + fn test_with_field() { + let field = Arc::new(Field::new("bar", DataType::Int32, false)); + let mut builder = ListViewBuilder::new(Int32Builder::new()).with_field(field.clone()); + builder.append_value([Some(1), Some(2), Some(3)]); + builder.append_null(); // This is fine as nullability refers to nullability of values + builder.append_value([Some(4)]); + let array = builder.finish(); + assert_eq!(array.len(), 3); + assert_eq!(array.data_type(), &DataType::ListView(field.clone())); + + builder.append_value([Some(4), Some(5)]); + let array = builder.finish(); + assert_eq!(array.data_type(), &DataType::ListView(field)); + assert_eq!(array.len(), 1); + } + + #[test] + #[should_panic( + expected = r#"Non-nullable field of ListViewArray \"item\" cannot contain nulls"# + )] + // If a non-nullable type is declared but a null value is used, it will be intercepted by the null check. + fn test_checks_nullability() { + let field = Arc::new(Field::new("item", DataType::Int32, false)); + let mut builder = ListViewBuilder::new(Int32Builder::new()).with_field(field.clone()); + builder.append_value([Some(1), None]); + builder.finish(); + } + + #[test] + #[should_panic(expected = "ListViewArray expected data type Int64 got Int32")] + // If the declared type does not match the actual appended type, it will be intercepted by type checking in the finish function. 
+ fn test_checks_data_type() { + let field = Arc::new(Field::new("item", DataType::Int64, false)); + let mut builder = ListViewBuilder::new(Int32Builder::new()).with_field(field.clone()); + builder.append_value([Some(1)]); + builder.finish(); + } +} diff --git a/arrow-array/src/builder/mod.rs b/arrow-array/src/builder/mod.rs index dd1a5c3ae722..29d75024ea72 100644 --- a/arrow-array/src/builder/mod.rs +++ b/arrow-array/src/builder/mod.rs @@ -78,6 +78,73 @@ //! )) //! ``` //! +//! # Using the [`Extend`] trait to append values from an iterable: +//! +//! ``` +//! # use arrow_array::{Array}; +//! # use arrow_array::builder::{ArrayBuilder, StringBuilder}; +//! +//! let mut builder = StringBuilder::new(); +//! builder.extend(vec![Some("🍐"), Some("🍎"), None]); +//! assert_eq!(builder.finish().len(), 3); +//! ``` +//! +//! # Using the [`Extend`] trait to write generic functions: +//! +//! ``` +//! # use arrow_array::{Array, ArrayRef, StringArray}; +//! # use arrow_array::builder::{ArrayBuilder, Int32Builder, ListBuilder, StringBuilder}; +//! +//! // For generic methods that fill a list of values for an [`ArrayBuilder`], use the [`Extend`] trait. +//! fn filter_and_fill>(builder: &mut impl Extend, values: I, filter: V) +//! where V: PartialEq +//! { +//! builder.extend(values.into_iter().filter(|v| *v == filter)); +//! } +//! let mut string_builder = StringBuilder::new(); +//! filter_and_fill( +//! &mut string_builder, +//! vec![Some("🍐"), Some("🍎"), None], +//! Some("🍎"), +//! ); +//! assert_eq!(string_builder.finish().len(), 1); +//! +//! let mut int_builder = Int32Builder::new(); +//! filter_and_fill( +//! &mut int_builder, +//! vec![Some(11), Some(42), None], +//! Some(42), +//! ); +//! assert_eq!(int_builder.finish().len(), 1); +//! +//! // For generic methods that fill lists-of-lists for an [`ArrayBuilder`], use the [`Extend`] trait. +//! fn filter_and_fill_if_contains>>( +//! list_builder: &mut impl Extend>, +//! values: I, +//! filter: Option, +//! ) where +//! 
T: PartialEq, +//! for<'a> &'a V: IntoIterator>, +//! { +//! list_builder.extend(values.into_iter().filter(|string: &Option| { +//! string +//! .as_ref() +//! .map(|str: &V| str.into_iter().any(|ch: &Option| ch == &filter)) +//! .unwrap_or(false) +//! })); +//! } +//! let builder = StringBuilder::new(); +//! let mut list_builder = ListBuilder::new(builder); +//! let pear_pear = vec![Some("🍐"),Some("🍐")]; +//! let pear_app = vec![Some("🍐"),Some("🍎")]; +//! filter_and_fill_if_contains( +//! &mut list_builder, +//! vec![Some(pear_pear), Some(pear_app), None], +//! Some("🍎"), +//! ); +//! assert_eq!(list_builder.finish().len(), 1); +//! ``` +//! //! # Custom Builders //! //! It is common to have a collection of statically defined Rust types that @@ -123,7 +190,7 @@ //! let string_field = Arc::new(Field::new("i32", DataType::Utf8, false)); //! //! let i32_list = Arc::new(self.i32_list.finish()) as ArrayRef; -//! let value_field = Arc::new(Field::new("item", DataType::Int32, true)); +//! let value_field = Arc::new(Field::new_list_field(DataType::Int32, true)); //! let i32_list_field = Arc::new(Field::new("i32_list", DataType::List(value_field), true)); //! //! StructArray::from(vec![ @@ -134,6 +201,8 @@ //! } //! } //! +//! /// For building arrays in generic code, use Extend instead of the append_* methods +//! /// e.g. append_value, append_option, append_null //! impl<'a> Extend<&'a MyRow> for MyRowBuilder { //! fn extend>(&mut self, iter: T) { //! 
iter.into_iter().for_each(|row| self.append(row)); @@ -180,6 +249,8 @@ mod generic_byte_run_builder; pub use generic_byte_run_builder::*; mod generic_bytes_view_builder; pub use generic_bytes_view_builder::*; +mod generic_list_view_builder; +pub use generic_list_view_builder::*; mod union_builder; pub use union_builder::*; @@ -304,6 +375,12 @@ pub type ListBuilder = GenericListBuilder; /// Builder for [`LargeListArray`](crate::array::LargeListArray) pub type LargeListBuilder = GenericListBuilder; +/// Builder for [`ListViewArray`](crate::array::ListViewArray) +pub type ListViewBuilder = GenericListViewBuilder; + +/// Builder for [`LargeListViewArray`](crate::array::LargeListViewArray) +pub type LargeListViewBuilder = GenericListViewBuilder; + /// Builder for [`BinaryArray`](crate::array::BinaryArray) /// /// See examples on [`GenericBinaryBuilder`] diff --git a/arrow-array/src/builder/primitive_dictionary_builder.rs b/arrow-array/src/builder/primitive_dictionary_builder.rs index ac40f8a469d3..282f0ae9d5b1 100644 --- a/arrow-array/src/builder/primitive_dictionary_builder.rs +++ b/arrow-array/src/builder/primitive_dictionary_builder.rs @@ -17,7 +17,9 @@ use crate::builder::{ArrayBuilder, PrimitiveBuilder}; use crate::types::ArrowDictionaryKeyType; -use crate::{Array, ArrayRef, ArrowPrimitiveType, DictionaryArray}; +use crate::{ + Array, ArrayRef, ArrowPrimitiveType, DictionaryArray, PrimitiveArray, TypedDictionaryArray, +}; use arrow_buffer::{ArrowNativeType, ToByteSlice}; use arrow_schema::{ArrowError, DataType}; use std::any::Any; @@ -44,7 +46,7 @@ impl PartialEq for Value { impl Eq for Value {} -/// Builder for [`DictionaryArray`] of [`PrimitiveArray`](crate::array::PrimitiveArray) +/// Builder for [`DictionaryArray`] of [`PrimitiveArray`] /// /// # Example: /// @@ -303,6 +305,63 @@ where }; } + /// Extends builder with dictionary + /// + /// This is the same as [`Self::extend`] but is faster as it translates + /// the dictionary values once rather than doing a 
lookup for each item in the iterator + /// + /// when dictionary values are null (the actual mapped values) the keys are null + /// + pub fn extend_dictionary( + &mut self, + dictionary: &TypedDictionaryArray>, + ) -> Result<(), ArrowError> { + let values = dictionary.values(); + + let v_len = values.len(); + let k_len = dictionary.keys().len(); + if v_len == 0 && k_len == 0 { + return Ok(()); + } + + // All nulls + if v_len == 0 { + self.append_nulls(k_len); + return Ok(()); + } + + if k_len == 0 { + return Err(ArrowError::InvalidArgumentError( + "Dictionary keys should not be empty when values are not empty".to_string(), + )); + } + + // Orphan values will be carried over to the new dictionary + let mapped_values = values + .iter() + // Dictionary values can technically be null, so we need to handle that + .map(|dict_value| { + dict_value + .map(|dict_value| self.get_or_insert_key(dict_value)) + .transpose() + }) + .collect::, _>>()?; + + // Just insert the keys without additional lookups + dictionary.keys().iter().for_each(|key| match key { + None => self.append_null(), + Some(original_dict_index) => { + let index = original_dict_index.as_usize().min(v_len - 1); + match mapped_values[index] { + None => self.append_null(), + Some(mapped_value) => self.keys_builder.append_value(mapped_value), + } + } + }); + + Ok(()) + } + /// Builds the `DictionaryArray` and reset this builder. 
pub fn finish(&mut self) -> DictionaryArray { self.map.clear(); @@ -368,9 +427,9 @@ impl Extend> mod tests { use super::*; - use crate::array::UInt32Array; - use crate::array::UInt8Array; + use crate::array::{Int32Array, UInt32Array, UInt8Array}; use crate::builder::Decimal128Builder; + use crate::cast::AsArray; use crate::types::{Decimal128Type, Int32Type, UInt32Type, UInt8Type}; #[test] @@ -443,4 +502,135 @@ mod tests { ) ); } + + #[test] + fn test_extend_dictionary() { + let some_dict = { + let mut builder = PrimitiveDictionaryBuilder::::new(); + builder.extend([1, 2, 3, 1, 2, 3, 1, 2, 3].into_iter().map(Some)); + builder.extend([None::]); + builder.extend([4, 5, 1, 3, 1].into_iter().map(Some)); + builder.append_null(); + builder.finish() + }; + + let mut builder = PrimitiveDictionaryBuilder::::new(); + builder.extend([6, 6, 7, 6, 5].into_iter().map(Some)); + builder + .extend_dictionary(&some_dict.downcast_dict().unwrap()) + .unwrap(); + let dict = builder.finish(); + + assert_eq!(dict.values().len(), 7); + + let values = dict + .downcast_dict::() + .unwrap() + .into_iter() + .collect::>(); + + assert_eq!( + values, + [ + Some(6), + Some(6), + Some(7), + Some(6), + Some(5), + Some(1), + Some(2), + Some(3), + Some(1), + Some(2), + Some(3), + Some(1), + Some(2), + Some(3), + None, + Some(4), + Some(5), + Some(1), + Some(3), + Some(1), + None + ] + ); + } + + #[test] + fn test_extend_dictionary_with_null_in_mapped_value() { + let some_dict = { + let mut values_builder = PrimitiveBuilder::::new(); + let mut keys_builder = PrimitiveBuilder::::new(); + + // Manually build a dictionary values that the mapped values have null + values_builder.append_null(); + keys_builder.append_value(0); + values_builder.append_value(42); + keys_builder.append_value(1); + + let values = values_builder.finish(); + let keys = keys_builder.finish(); + + let data_type = DataType::Dictionary( + Box::new(Int32Type::DATA_TYPE), + Box::new(values.data_type().clone()), + ); + + let builder = 
keys + .into_data() + .into_builder() + .data_type(data_type) + .child_data(vec![values.into_data()]); + + DictionaryArray::from(unsafe { builder.build_unchecked() }) + }; + + let some_dict_values = some_dict.values().as_primitive::(); + assert_eq!( + some_dict_values.into_iter().collect::>(), + &[None, Some(42)] + ); + + let mut builder = PrimitiveDictionaryBuilder::::new(); + builder + .extend_dictionary(&some_dict.downcast_dict().unwrap()) + .unwrap(); + let dict = builder.finish(); + + assert_eq!(dict.values().len(), 1); + + let values = dict + .downcast_dict::() + .unwrap() + .into_iter() + .collect::>(); + + assert_eq!(values, [None, Some(42)]); + } + + #[test] + fn test_extend_all_null_dictionary() { + let some_dict = { + let mut builder = PrimitiveDictionaryBuilder::::new(); + builder.append_nulls(2); + builder.finish() + }; + + let mut builder = PrimitiveDictionaryBuilder::::new(); + builder + .extend_dictionary(&some_dict.downcast_dict().unwrap()) + .unwrap(); + let dict = builder.finish(); + + assert_eq!(dict.values().len(), 0); + + let values = dict + .downcast_dict::() + .unwrap() + .into_iter() + .collect::>(); + + assert_eq!(values, [None, None]); + } } diff --git a/arrow-array/src/builder/struct_builder.rs b/arrow-array/src/builder/struct_builder.rs index 396ab2fed851..4a40c2201746 100644 --- a/arrow-array/src/builder/struct_builder.rs +++ b/arrow-array/src/builder/struct_builder.rs @@ -15,9 +15,11 @@ // specific language governing permissions and limitations // under the License. 
-use crate::builder::*; -use crate::types::Int32Type; use crate::StructArray; +use crate::{ + builder::*, + types::{Int16Type, Int32Type, Int64Type, Int8Type}, +}; use arrow_buffer::NullBufferBuilder; use arrow_schema::{DataType, Fields, IntervalUnit, SchemaBuilder, TimeUnit}; use std::sync::Arc; @@ -46,8 +48,7 @@ use std::sync::Arc; /// let mut example_col = ListBuilder::new(StructBuilder::from_fields( /// vec![Field::new( /// "value_list", -/// DataType::List(Arc::new(Field::new( -/// "item", +/// DataType::List(Arc::new(Field::new_list_field( /// DataType::Struct(Fields::from(vec![ /// Field::new("key", DataType::Utf8, true), /// Field::new("value", DataType::Utf8, true), @@ -269,6 +270,16 @@ pub fn make_builder(datatype: &DataType, capacity: usize) -> Box { + let builder = make_builder(field.data_type(), capacity); + Box::new(ListViewBuilder::with_capacity(builder, capacity).with_field(field.clone())) + } + DataType::LargeListView(field) => { + let builder = make_builder(field.data_type(), capacity); + Box::new( + LargeListViewBuilder::with_capacity(builder, capacity).with_field(field.clone()), + ) + } DataType::Map(field, _) => match field.data_type() { DataType::Struct(fields) => { let map_field_names = MapFieldNames { @@ -291,29 +302,42 @@ pub fn make_builder(datatype: &DataType, capacity: usize) -> Box panic!("The field of Map data type {t:?} should has a child Struct field"), }, DataType::Struct(fields) => Box::new(StructBuilder::from_fields(fields.clone(), capacity)), - DataType::Dictionary(key_type, value_type) if **key_type == DataType::Int32 => { - match &**value_type { - DataType::Utf8 => { - let dict_builder: StringDictionaryBuilder = - StringDictionaryBuilder::with_capacity(capacity, 256, 1024); - Box::new(dict_builder) - } - DataType::LargeUtf8 => { - let dict_builder: LargeStringDictionaryBuilder = - LargeStringDictionaryBuilder::with_capacity(capacity, 256, 1024); - Box::new(dict_builder) - } - DataType::Binary => { - let dict_builder: 
BinaryDictionaryBuilder = - BinaryDictionaryBuilder::with_capacity(capacity, 256, 1024); - Box::new(dict_builder) - } - DataType::LargeBinary => { - let dict_builder: LargeBinaryDictionaryBuilder = - LargeBinaryDictionaryBuilder::with_capacity(capacity, 256, 1024); - Box::new(dict_builder) + t @ DataType::Dictionary(key_type, value_type) => { + macro_rules! dict_builder { + ($key_type:ty) => { + match &**value_type { + DataType::Utf8 => { + let dict_builder: StringDictionaryBuilder<$key_type> = + StringDictionaryBuilder::with_capacity(capacity, 256, 1024); + Box::new(dict_builder) + } + DataType::LargeUtf8 => { + let dict_builder: LargeStringDictionaryBuilder<$key_type> = + LargeStringDictionaryBuilder::with_capacity(capacity, 256, 1024); + Box::new(dict_builder) + } + DataType::Binary => { + let dict_builder: BinaryDictionaryBuilder<$key_type> = + BinaryDictionaryBuilder::with_capacity(capacity, 256, 1024); + Box::new(dict_builder) + } + DataType::LargeBinary => { + let dict_builder: LargeBinaryDictionaryBuilder<$key_type> = + LargeBinaryDictionaryBuilder::with_capacity(capacity, 256, 1024); + Box::new(dict_builder) + } + t => panic!("Dictionary value type {t:?} is not currently supported"), + } + }; + } + match &**key_type { + DataType::Int8 => dict_builder!(Int8Type), + DataType::Int16 => dict_builder!(Int16Type), + DataType::Int32 => dict_builder!(Int32Type), + DataType::Int64 => dict_builder!(Int64Type), + _ => { + panic!("Data type {t:?} with key type {key_type:?} is not currently supported") } - t => panic!("Unsupported dictionary value type {t:?} is not currently supported"), } } t => panic!("Data type {t:?} is not currently supported"), @@ -431,12 +455,14 @@ impl StructBuilder { #[cfg(test)] mod tests { + use std::any::type_name; + use super::*; use arrow_buffer::Buffer; use arrow_data::ArrayData; use arrow_schema::Field; - use crate::array::Array; + use crate::{array::Array, types::ArrowDictionaryKeyType}; #[test] fn test_struct_array_builder() { @@ 
-691,10 +717,31 @@ mod tests { } #[test] - fn test_struct_array_builder_from_dictionary_type() { + fn test_struct_array_builder_from_dictionary_type_int8_key() { + test_struct_array_builder_from_dictionary_type_inner::(DataType::Int8); + } + + #[test] + fn test_struct_array_builder_from_dictionary_type_int16_key() { + test_struct_array_builder_from_dictionary_type_inner::(DataType::Int16); + } + + #[test] + fn test_struct_array_builder_from_dictionary_type_int32_key() { + test_struct_array_builder_from_dictionary_type_inner::(DataType::Int32); + } + + #[test] + fn test_struct_array_builder_from_dictionary_type_int64_key() { + test_struct_array_builder_from_dictionary_type_inner::(DataType::Int64); + } + + fn test_struct_array_builder_from_dictionary_type_inner( + key_type: DataType, + ) { let dict_field = Field::new( "f1", - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + DataType::Dictionary(Box::new(key_type), Box::new(DataType::Utf8)), false, ); let fields = vec![dict_field.clone()]; @@ -702,10 +749,14 @@ mod tests { let cloned_dict_field = dict_field.clone(); let expected_child_dtype = dict_field.data_type(); let mut struct_builder = StructBuilder::from_fields(vec![cloned_dict_field], 5); - struct_builder - .field_builder::>(0) - .expect("Builder should be StringDictionaryBuilder") - .append_value("dict string"); + let Some(dict_builder) = struct_builder.field_builder::>(0) + else { + panic!( + "Builder should be StringDictionaryBuilder<{}>", + type_name::() + ) + }; + dict_builder.append_value("dict string"); struct_builder.append(true); let array = struct_builder.finish(); @@ -715,13 +766,15 @@ mod tests { } #[test] - #[should_panic(expected = "Data type Dictionary(Int16, Utf8) is not currently supported")] + #[should_panic( + expected = "Data type Dictionary(UInt64, Utf8) with key type UInt64 is not currently supported" + )] fn test_struct_array_builder_from_schema_unsupported_type() { let fields = vec![ - Field::new("f1", 
DataType::Int16, false), + Field::new("f1", DataType::UInt64, false), Field::new( "f2", - DataType::Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8)), + DataType::Dictionary(Box::new(DataType::UInt64), Box::new(DataType::Utf8)), false, ), ]; @@ -730,7 +783,7 @@ mod tests { } #[test] - #[should_panic(expected = "Unsupported dictionary value type Int32 is not currently supported")] + #[should_panic(expected = "Dictionary value type Int32 is not currently supported")] fn test_struct_array_builder_from_dict_with_unsupported_value_type() { let fields = vec![Field::new( "f1", diff --git a/arrow-array/src/cast.rs b/arrow-array/src/cast.rs index 232b29560cbf..d871431593b6 100644 --- a/arrow-array/src/cast.rs +++ b/arrow-array/src/cast.rs @@ -689,12 +689,6 @@ array_downcast_fn!(as_struct_array, StructArray); array_downcast_fn!(as_union_array, UnionArray); array_downcast_fn!(as_map_array, MapArray); -/// Force downcast of an Array, such as an ArrayRef to Decimal128Array, panic’ing on failure. 
-#[deprecated(note = "please use `as_primitive_array::` instead")] -pub fn as_decimal_array(arr: &dyn Array) -> &PrimitiveArray { - as_primitive_array::(arr) -} - /// Downcasts a `dyn Array` to a concrete type /// /// ``` @@ -838,6 +832,14 @@ pub trait AsArray: private::Sealed { self.as_list_opt().expect("list array") } + /// Downcast this to a [`GenericListViewArray`] returning `None` if not possible + fn as_list_view_opt(&self) -> Option<&GenericListViewArray>; + + /// Downcast this to a [`GenericListViewArray`] panicking if not possible + fn as_list_view(&self) -> &GenericListViewArray { + self.as_list_view_opt().expect("list view array") + } + /// Downcast this to a [`FixedSizeBinaryArray`] returning `None` if not possible fn as_fixed_size_binary_opt(&self) -> Option<&FixedSizeBinaryArray>; @@ -911,6 +913,10 @@ impl AsArray for dyn Array + '_ { self.as_any().downcast_ref() } + fn as_list_view_opt(&self) -> Option<&GenericListViewArray> { + self.as_any().downcast_ref() + } + fn as_fixed_size_binary_opt(&self) -> Option<&FixedSizeBinaryArray> { self.as_any().downcast_ref() } @@ -966,6 +972,10 @@ impl AsArray for ArrayRef { self.as_ref().as_list_opt() } + fn as_list_view_opt(&self) -> Option<&GenericListViewArray> { + self.as_ref().as_list_view_opt() + } + fn as_fixed_size_binary_opt(&self) -> Option<&FixedSizeBinaryArray> { self.as_ref().as_fixed_size_binary_opt() } diff --git a/arrow-array/src/ffi.rs b/arrow-array/src/ffi.rs index 4426e0986409..144f2a21afec 100644 --- a/arrow-array/src/ffi.rs +++ b/arrow-array/src/ffi.rs @@ -121,7 +121,10 @@ type Result = std::result::Result; /// This function copies the content of two FFI structs [arrow_data::ffi::FFI_ArrowArray] and /// [arrow_schema::ffi::FFI_ArrowSchema] in the array to the location pointed by the raw pointers. /// Usually the raw pointers are provided by the array data consumer. 
-#[deprecated(note = "Use FFI_ArrowArray::new and FFI_ArrowSchema::try_from")] +#[deprecated( + since = "52.0.0", + note = "Use FFI_ArrowArray::new and FFI_ArrowSchema::try_from" +)] pub unsafe fn export_array_into_raw( src: ArrayRef, out_array: *mut FFI_ArrowArray, @@ -719,7 +722,7 @@ mod tests_to_then_from_ffi { // Construct a list array from the above two let list_data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new( - Field::new("item", DataType::Int32, false), + Field::new_list_field(DataType::Int32, false), )); let list_data = ArrayData::builder(list_data_type) @@ -1478,7 +1481,7 @@ mod tests_from_ffi { let offsets: Vec = vec![0, 2, 4, 6, 8, 10, 12, 14, 16]; let value_offsets = Buffer::from_slice_ref(offsets); let inner_list_data_type = - DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))); let inner_list_data = ArrayData::builder(inner_list_data_type.clone()) .len(8) .add_buffer(value_offsets) diff --git a/arrow-array/src/ffi_stream.rs b/arrow-array/src/ffi_stream.rs index 34f0cd7cfc74..3d4e89e80b89 100644 --- a/arrow-array/src/ffi_stream.rs +++ b/arrow-array/src/ffi_stream.rs @@ -379,21 +379,6 @@ impl RecordBatchReader for ArrowArrayStreamReader { } } -/// Exports a record batch reader to raw pointer of the C Stream Interface provided by the consumer. -/// -/// # Safety -/// Assumes that the pointer represents valid C Stream Interfaces, both in memory -/// representation and lifetime via the `release` mechanism. 
-#[deprecated(note = "Use FFI_ArrowArrayStream::new")] -pub unsafe fn export_reader_into_raw( - reader: Box, - out_stream: *mut FFI_ArrowArrayStream, -) { - let stream = FFI_ArrowArrayStream::new(reader); - - std::ptr::write_unaligned(out_stream, stream); -} - #[cfg(test)] mod tests { use super::*; diff --git a/arrow-array/src/record_batch.rs b/arrow-array/src/record_batch.rs index 78108d441b05..8958ca6fae62 100644 --- a/arrow-array/src/record_batch.rs +++ b/arrow-array/src/record_batch.rs @@ -32,15 +32,6 @@ pub trait RecordBatchReader: Iterator> { /// Implementation of this trait should guarantee that all `RecordBatch`'s returned by this /// reader should have the same schema as returned from this method. fn schema(&self) -> SchemaRef; - - /// Reads the next `RecordBatch`. - #[deprecated( - since = "2.0.0", - note = "This method is deprecated in favour of `next` from the trait Iterator." - )] - fn next_batch(&mut self) -> Result, ArrowError> { - self.next().transpose() - } } impl RecordBatchReader for Box { @@ -58,6 +49,129 @@ pub trait RecordBatchWriter { fn close(self) -> Result<(), ArrowError>; } +/// Creates an array from a literal slice of values, +/// suitable for rapid testing and development. +/// +/// Example: +/// +/// ```rust +/// +/// use arrow_array::create_array; +/// +/// let array = create_array!(Int32, [1, 2, 3, 4, 5]); +/// let array = create_array!(Utf8, [Some("a"), Some("b"), None, Some("e")]); +/// ``` +/// Support for limited data types is available. The macro will return a compile error if an unsupported data type is used. 
+/// Presently supported data types are: +/// - `Boolean`, `Null` +/// - `Decimal128`, `Decimal256` +/// - `Float16`, `Float32`, `Float64` +/// - `Int8`, `Int16`, `Int32`, `Int64` +/// - `UInt8`, `UInt16`, `UInt32`, `UInt64` +/// - `IntervalDayTime`, `IntervalYearMonth` +/// - `Second`, `Millisecond`, `Microsecond`, `Nanosecond` +/// - `Second32`, `Millisecond32`, `Microsecond64`, `Nanosecond64` +/// - `DurationSecond`, `DurationMillisecond`, `DurationMicrosecond`, `DurationNanosecond` +/// - `TimestampSecond`, `TimestampMillisecond`, `TimestampMicrosecond`, `TimestampNanosecond` +/// - `Utf8`, `Utf8View`, `LargeUtf8`, `Binary`, `LargeBinary` +#[macro_export] +macro_rules! create_array { + // `@from` is used for those types that have a common method `::from` + (@from Boolean) => { $crate::BooleanArray }; + (@from Int8) => { $crate::Int8Array }; + (@from Int16) => { $crate::Int16Array }; + (@from Int32) => { $crate::Int32Array }; + (@from Int64) => { $crate::Int64Array }; + (@from UInt8) => { $crate::UInt8Array }; + (@from UInt16) => { $crate::UInt16Array }; + (@from UInt32) => { $crate::UInt32Array }; + (@from UInt64) => { $crate::UInt64Array }; + (@from Float16) => { $crate::Float16Array }; + (@from Float32) => { $crate::Float32Array }; + (@from Float64) => { $crate::Float64Array }; + (@from Utf8) => { $crate::StringArray }; + (@from Utf8View) => { $crate::StringViewArray }; + (@from LargeUtf8) => { $crate::LargeStringArray }; + (@from IntervalDayTime) => { $crate::IntervalDayTimeArray }; + (@from IntervalYearMonth) => { $crate::IntervalYearMonthArray }; + (@from Second) => { $crate::TimestampSecondArray }; + (@from Millisecond) => { $crate::TimestampMillisecondArray }; + (@from Microsecond) => { $crate::TimestampMicrosecondArray }; + (@from Nanosecond) => { $crate::TimestampNanosecondArray }; + (@from Second32) => { $crate::Time32SecondArray }; + (@from Millisecond32) => { $crate::Time32MillisecondArray }; + (@from Microsecond64) => { 
$crate::Time64MicrosecondArray }; + (@from Nanosecond64) => { $crate::Time64Nanosecond64Array }; + (@from DurationSecond) => { $crate::DurationSecondArray }; + (@from DurationMillisecond) => { $crate::DurationMillisecondArray }; + (@from DurationMicrosecond) => { $crate::DurationMicrosecondArray }; + (@from DurationNanosecond) => { $crate::DurationNanosecondArray }; + (@from Decimal128) => { $crate::Decimal128Array }; + (@from Decimal256) => { $crate::Decimal256Array }; + (@from TimestampSecond) => { $crate::TimestampSecondArray }; + (@from TimestampMillisecond) => { $crate::TimestampMillisecondArray }; + (@from TimestampMicrosecond) => { $crate::TimestampMicrosecondArray }; + (@from TimestampNanosecond) => { $crate::TimestampNanosecondArray }; + + (@from $ty: ident) => { + compile_error!(concat!("Unsupported data type: ", stringify!($ty))) + }; + + (Null, $size: expr) => { + std::sync::Arc::new($crate::NullArray::new($size)) + }; + + (Binary, [$($values: expr),*]) => { + std::sync::Arc::new($crate::BinaryArray::from_vec(vec![$($values),*])) + }; + + (LargeBinary, [$($values: expr),*]) => { + std::sync::Arc::new($crate::LargeBinaryArray::from_vec(vec![$($values),*])) + }; + + ($ty: tt, [$($values: expr),*]) => { + std::sync::Arc::new(<$crate::create_array!(@from $ty)>::from(vec![$($values),*])) + }; +} + +/// Creates a record batch from literal slice of values, suitable for rapid +/// testing and development. +/// +/// Example: +/// +/// ```rust +/// use arrow_array::record_batch; +/// use arrow_schema; +/// +/// let batch = record_batch!( +/// ("a", Int32, [1, 2, 3]), +/// ("b", Float64, [Some(4.0), None, Some(5.0)]), +/// ("c", Utf8, ["alpha", "beta", "gamma"]) +/// ); +/// ``` +/// Due to limitation of [`create_array!`] macro, support for limited data types is available. +#[macro_export] +macro_rules! 
record_batch { + ($(($name: expr, $type: ident, [$($values: expr),*])),*) => { + { + let schema = std::sync::Arc::new(arrow_schema::Schema::new(vec![ + $( + arrow_schema::Field::new($name, arrow_schema::DataType::$type, true), + )* + ])); + + let batch = $crate::RecordBatch::try_new( + schema, + vec![$( + $crate::create_array!($type, [$($values),*]), + )*] + ); + + batch + } + } +} + /// A two-dimensional batch of column-oriented data with a defined /// [schema](arrow_schema::Schema). /// @@ -68,6 +182,19 @@ pub trait RecordBatchWriter { /// /// Record batches are a convenient unit of work for various /// serialization and computation functions, possibly incremental. +/// +/// Use the [`record_batch!`] macro to create a [`RecordBatch`] from +/// literal slice of values, useful for rapid prototyping and testing. +/// +/// Example: +/// ```rust +/// use arrow_array::record_batch; +/// let batch = record_batch!( +/// ("a", Int32, [1, 2, 3]), +/// ("b", Float64, [Some(4.0), None, Some(5.0)]), +/// ("c", Utf8, ["alpha", "beta", "gamma"]) +/// ); +/// ``` #[derive(Clone, Debug, PartialEq)] pub struct RecordBatch { schema: SchemaRef, @@ -411,6 +538,19 @@ impl RecordBatch { /// ("b", b), /// ]); /// ``` + /// Another way to quickly create a [`RecordBatch`] is to use the [`record_batch!`] macro, + /// which is particularly helpful for rapid prototyping and testing. 
+ /// + /// Example: + /// + /// ```rust + /// use arrow_array::record_batch; + /// let batch = record_batch!( + /// ("a", Int32, [1, 2, 3]), + /// ("b", Float64, [Some(4.0), None, Some(5.0)]), + /// ("c", Utf8, ["alpha", "beta", "gamma"]) + /// ); + /// ``` pub fn try_from_iter(value: I) -> Result where I: IntoIterator, @@ -806,7 +946,7 @@ mod tests { fn create_record_batch_field_name_mismatch() { let fields = vec![ Field::new("a1", DataType::Int32, false), - Field::new_list("a2", Field::new("item", DataType::Int8, false), false), + Field::new_list("a2", Field::new_list_field(DataType::Int8, false), false), ]; let schema = Arc::new(Schema::new(vec![Field::new_struct("a", fields, true)])); diff --git a/arrow-array/src/temporal_conversions.rs b/arrow-array/src/temporal_conversions.rs index 8d238b3a196c..23f950d55048 100644 --- a/arrow-array/src/temporal_conversions.rs +++ b/arrow-array/src/temporal_conversions.rs @@ -37,8 +37,18 @@ pub const MILLISECONDS_IN_DAY: i64 = SECONDS_IN_DAY * MILLISECONDS; pub const MICROSECONDS_IN_DAY: i64 = SECONDS_IN_DAY * MICROSECONDS; /// Number of nanoseconds in a day pub const NANOSECONDS_IN_DAY: i64 = SECONDS_IN_DAY * NANOSECONDS; -/// Number of days between 0001-01-01 and 1970-01-01 -pub const EPOCH_DAYS_FROM_CE: i32 = 719_163; + +/// Constant from chrono crate +/// +/// Number of days between Januari 1, 1970 and December 31, 1 BCE which we define to be day 0. 
+/// 4 full leap year cycles until December 31, 1600 4 * 146097 = 584388 +/// 1 day until January 1, 1601 1 +/// 369 years until Januari 1, 1970 369 * 365 = 134685 +/// of which floor(369 / 4) are leap years floor(369 / 4) = 92 +/// except for 1700, 1800 and 1900 -3 + +/// -------- +/// 719163 +pub const UNIX_EPOCH_DAY: i64 = 719_163; /// converts a `i32` representing a `date32` to [`NaiveDateTime`] #[inline] @@ -134,6 +144,31 @@ pub fn timestamp_s_to_datetime(v: i64) -> Option { Some(DateTime::from_timestamp(v, 0)?.naive_utc()) } +/// Similar to timestamp_s_to_datetime but only compute `date` +#[inline] +pub fn timestamp_s_to_date(secs: i64) -> Option { + let days = secs.div_euclid(86_400) + UNIX_EPOCH_DAY; + if days < i32::MIN as i64 || days > i32::MAX as i64 { + return None; + } + let date = NaiveDate::from_num_days_from_ce_opt(days as i32)?; + Some(date.and_time(NaiveTime::default()).and_utc().naive_utc()) +} + +/// Similar to timestamp_s_to_datetime but only compute `time` +#[inline] +pub fn timestamp_s_to_time(secs: i64) -> Option { + let secs = secs.rem_euclid(86_400); + let time = NaiveTime::from_num_seconds_from_midnight_opt(secs as u32, 0)?; + Some( + DateTime::::from_naive_utc_and_offset( + NaiveDateTime::new(NaiveDate::default(), time), + Utc, + ) + .naive_utc(), + ) +} + /// converts a `i64` representing a `timestamp(ms)` to [`NaiveDateTime`] #[inline] pub fn timestamp_ms_to_datetime(v: i64) -> Option { @@ -274,10 +309,28 @@ pub fn as_duration(v: i64) -> Option { mod tests { use crate::temporal_conversions::{ date64_to_datetime, split_second, timestamp_ms_to_datetime, timestamp_ns_to_datetime, + timestamp_s_to_date, timestamp_s_to_datetime, timestamp_s_to_time, timestamp_us_to_datetime, NANOSECONDS, }; use chrono::DateTime; + #[test] + fn test_timestamp_func() { + let timestamp = 1234; + let datetime = timestamp_s_to_datetime(timestamp).unwrap(); + let expected_date = datetime.date(); + let expected_time = datetime.time(); + + assert_eq!( + 
timestamp_s_to_date(timestamp).unwrap().date(), + expected_date + ); + assert_eq!( + timestamp_s_to_time(timestamp).unwrap().time(), + expected_time + ); + } + #[test] fn negative_input_timestamp_ns_to_datetime() { assert_eq!( diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs index 92262fc04a57..3d8cfcdb112b 100644 --- a/arrow-array/src/types.rs +++ b/arrow-array/src/types.rs @@ -69,7 +69,7 @@ pub trait ArrowPrimitiveType: primitive::PrimitiveTypeSealed + 'static { const DATA_TYPE: DataType; /// Returns the byte width of this primitive type. - #[deprecated(note = "Use ArrowNativeType::get_byte_width")] + #[deprecated(since = "52.0.0", note = "Use ArrowNativeType::get_byte_width")] fn get_byte_width() -> usize { std::mem::size_of::() } @@ -324,12 +324,6 @@ pub trait ArrowTimestampType: ArrowTemporalType { /// The [`TimeUnit`] of this timestamp. const UNIT: TimeUnit; - /// Returns the `TimeUnit` of this timestamp. - #[deprecated(note = "Use Self::UNIT")] - fn get_time_unit() -> TimeUnit { - Self::UNIT - } - /// Creates a ArrowTimestampType::Native from the provided [`NaiveDateTime`] /// /// See [`DataType::Timestamp`] for more information on timezone handling diff --git a/arrow-avro/Cargo.toml b/arrow-avro/Cargo.toml index d2436f0c15de..c103c2ecc0f3 100644 --- a/arrow-avro/Cargo.toml +++ b/arrow-avro/Cargo.toml @@ -39,11 +39,9 @@ deflate = ["flate2"] snappy = ["snap", "crc"] [dependencies] -arrow-array = { workspace = true } -arrow-buffer = { workspace = true } -arrow-cast = { workspace = true } -arrow-data = { workspace = true } -arrow-schema = { workspace = true } +arrow-schema = { workspace = true } +arrow-buffer = { workspace = true } +arrow-array = { workspace = true } serde_json = { version = "1.0", default-features = false, features = ["std"] } serde = { version = "1.0.188", features = ["derive"] } flate2 = { version = "1.0", default-features = false, features = ["rust_backend"], optional = true } @@ -53,4 +51,5 @@ crc = { version = "3.0", 
optional = true } [dev-dependencies] +rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } diff --git a/arrow-avro/LICENSE.txt b/arrow-avro/LICENSE.txt new file mode 120000 index 000000000000..4ab43736a839 --- /dev/null +++ b/arrow-avro/LICENSE.txt @@ -0,0 +1 @@ +../LICENSE.txt \ No newline at end of file diff --git a/arrow-avro/NOTICE.txt b/arrow-avro/NOTICE.txt new file mode 120000 index 000000000000..eb9f24e040b5 --- /dev/null +++ b/arrow-avro/NOTICE.txt @@ -0,0 +1 @@ +../NOTICE.txt \ No newline at end of file diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 1e2acd99d828..2ac1ad038bd7 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -29,7 +29,7 @@ use std::sync::Arc; /// To accommodate this we special case two-variant unions where one of the /// variants is the null type, and use this to derive arrow's notion of nullability #[derive(Debug, Copy, Clone)] -enum Nulls { +pub enum Nullability { /// The nulls are encoded as the first union variant NullFirst, /// The nulls are encoded as the second union variant @@ -39,7 +39,7 @@ enum Nulls { /// An Avro datatype mapped to the arrow data model #[derive(Debug, Clone)] pub struct AvroDataType { - nulls: Option, + nullability: Option, metadata: HashMap, codec: Codec, } @@ -48,7 +48,15 @@ impl AvroDataType { /// Returns an arrow [`Field`] with the given name pub fn field_with_name(&self, name: &str) -> Field { let d = self.codec.data_type(); - Field::new(name, d, self.nulls.is_some()).with_metadata(self.metadata.clone()) + Field::new(name, d, self.nullability.is_some()).with_metadata(self.metadata.clone()) + } + + pub fn codec(&self) -> &Codec { + &self.codec + } + + pub fn nullability(&self) -> Option { + self.nullability } } @@ -65,9 +73,13 @@ impl AvroField { self.data_type.field_with_name(&self.name) } - /// Returns the [`Codec`] - pub fn codec(&self) -> &Codec { - &self.data_type.codec + /// Returns the [`AvroDataType`] + pub fn data_type(&self) 
-> &AvroDataType { + &self.data_type + } + + pub fn name(&self) -> &str { + &self.name } } @@ -114,7 +126,7 @@ pub enum Codec { Fixed(i32), List(Arc), Struct(Arc<[AvroField]>), - Duration, + Interval, } impl Codec { @@ -137,9 +149,11 @@ impl Codec { Self::TimestampMicros(is_utc) => { DataType::Timestamp(TimeUnit::Microsecond, is_utc.then(|| "+00:00".into())) } - Self::Duration => DataType::Interval(IntervalUnit::MonthDayNano), + Self::Interval => DataType::Interval(IntervalUnit::MonthDayNano), Self::Fixed(size) => DataType::FixedSizeBinary(*size), - Self::List(f) => DataType::List(Arc::new(f.field_with_name("item"))), + Self::List(f) => { + DataType::List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME))) + } Self::Struct(f) => DataType::Struct(f.iter().map(|x| x.field()).collect()), } } @@ -198,7 +212,7 @@ fn make_data_type<'a>( ) -> Result { match schema { Schema::TypeName(TypeName::Primitive(p)) => Ok(AvroDataType { - nulls: None, + nullability: None, metadata: Default::default(), codec: (*p).into(), }), @@ -211,12 +225,12 @@ fn make_data_type<'a>( match (f.len() == 2, null) { (true, Some(0)) => { let mut field = make_data_type(&f[1], namespace, resolver)?; - field.nulls = Some(Nulls::NullFirst); + field.nullability = Some(Nullability::NullFirst); Ok(field) } (true, Some(1)) => { let mut field = make_data_type(&f[0], namespace, resolver)?; - field.nulls = Some(Nulls::NullSecond); + field.nullability = Some(Nullability::NullSecond); Ok(field) } _ => Err(ArrowError::NotYetImplemented(format!( @@ -239,7 +253,7 @@ fn make_data_type<'a>( .collect::>()?; let field = AvroDataType { - nulls: None, + nullability: None, codec: Codec::Struct(fields), metadata: r.attributes.field_metadata(), }; @@ -249,7 +263,7 @@ fn make_data_type<'a>( ComplexType::Array(a) => { let mut field = make_data_type(a.items.as_ref(), namespace, resolver)?; Ok(AvroDataType { - nulls: None, + nullability: None, metadata: a.attributes.field_metadata(), codec: Codec::List(Arc::new(field)), 
}) @@ -260,7 +274,7 @@ fn make_data_type<'a>( })?; let field = AvroDataType { - nulls: None, + nullability: None, metadata: f.attributes.field_metadata(), codec: Codec::Fixed(size), }; @@ -296,7 +310,7 @@ fn make_data_type<'a>( (Some("local-timestamp-micros"), c @ Codec::Int64) => { *c = Codec::TimestampMicros(false) } - (Some("duration"), c @ Codec::Fixed(12)) => *c = Codec::Duration, + (Some("duration"), c @ Codec::Fixed(12)) => *c = Codec::Interval, (Some(logical), _) => { // Insert unrecognized logical type into metadata map field.metadata.insert("logicalType".into(), logical.into()); diff --git a/arrow-avro/src/compression.rs b/arrow-avro/src/compression.rs index c5c7a6dabc33..f29b8dd07606 100644 --- a/arrow-avro/src/compression.rs +++ b/arrow-avro/src/compression.rs @@ -16,7 +16,6 @@ // under the License. use arrow_schema::ArrowError; -use flate2::read; use std::io; use std::io::Read; @@ -35,7 +34,7 @@ impl CompressionCodec { match self { #[cfg(feature = "deflate")] CompressionCodec::Deflate => { - let mut decoder = read::DeflateDecoder::new(block); + let mut decoder = flate2::read::DeflateDecoder::new(block); let mut out = Vec::new(); decoder.read_to_end(&mut out)?; Ok(out) diff --git a/arrow-avro/src/reader/cursor.rs b/arrow-avro/src/reader/cursor.rs new file mode 100644 index 000000000000..4b6a5a4d65db --- /dev/null +++ b/arrow-avro/src/reader/cursor.rs @@ -0,0 +1,121 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::reader::vlq::read_varint; +use arrow_schema::ArrowError; + +/// A wrapper around a byte slice, providing low-level decoding for Avro +/// +/// +#[derive(Debug)] +pub(crate) struct AvroCursor<'a> { + buf: &'a [u8], + start_len: usize, +} + +impl<'a> AvroCursor<'a> { + pub(crate) fn new(buf: &'a [u8]) -> Self { + Self { + buf, + start_len: buf.len(), + } + } + + /// Returns the current cursor position + #[inline] + pub(crate) fn position(&self) -> usize { + self.start_len - self.buf.len() + } + + /// Read a single `u8` + #[inline] + pub(crate) fn get_u8(&mut self) -> Result { + match self.buf.first().copied() { + Some(x) => { + self.buf = &self.buf[1..]; + Ok(x) + } + None => Err(ArrowError::ParseError("Unexpected EOF".to_string())), + } + } + + #[inline] + pub(crate) fn get_bool(&mut self) -> Result { + Ok(self.get_u8()? 
!= 0) + } + + pub(crate) fn read_vlq(&mut self) -> Result { + let (val, offset) = read_varint(self.buf) + .ok_or_else(|| ArrowError::ParseError("bad varint".to_string()))?; + self.buf = &self.buf[offset..]; + Ok(val) + } + + #[inline] + pub(crate) fn get_int(&mut self) -> Result { + let varint = self.read_vlq()?; + let val: u32 = varint + .try_into() + .map_err(|_| ArrowError::ParseError("varint overflow".to_string()))?; + Ok((val >> 1) as i32 ^ -((val & 1) as i32)) + } + + #[inline] + pub(crate) fn get_long(&mut self) -> Result { + let val = self.read_vlq()?; + Ok((val >> 1) as i64 ^ -((val & 1) as i64)) + } + + pub(crate) fn get_bytes(&mut self) -> Result<&'a [u8], ArrowError> { + let len: usize = self.get_long()?.try_into().map_err(|_| { + ArrowError::ParseError("offset overflow reading avro bytes".to_string()) + })?; + + if (self.buf.len() < len) { + return Err(ArrowError::ParseError( + "Unexpected EOF reading bytes".to_string(), + )); + } + let ret = &self.buf[..len]; + self.buf = &self.buf[len..]; + Ok(ret) + } + + #[inline] + pub(crate) fn get_float(&mut self) -> Result { + if (self.buf.len() < 4) { + return Err(ArrowError::ParseError( + "Unexpected EOF reading float".to_string(), + )); + } + let ret = f32::from_le_bytes(self.buf[..4].try_into().unwrap()); + self.buf = &self.buf[4..]; + Ok(ret) + } + + #[inline] + pub(crate) fn get_double(&mut self) -> Result { + if (self.buf.len() < 8) { + return Err(ArrowError::ParseError( + "Unexpected EOF reading float".to_string(), + )); + } + let ret = f64::from_le_bytes(self.buf[..8].try_into().unwrap()); + self.buf = &self.buf[8..]; + Ok(ret) + } +} diff --git a/arrow-avro/src/reader/header.rs b/arrow-avro/src/reader/header.rs index 19d48d1f89a1..98c285171bf3 100644 --- a/arrow-avro/src/reader/header.rs +++ b/arrow-avro/src/reader/header.rs @@ -19,7 +19,7 @@ use crate::compression::{CompressionCodec, CODEC_METADATA_KEY}; use crate::reader::vlq::VLQDecoder; -use crate::schema::Schema; +use crate::schema::{Schema, 
SCHEMA_METADATA_KEY}; use arrow_schema::ArrowError; #[derive(Debug)] @@ -89,6 +89,17 @@ impl Header { ))), } } + + /// Returns the [`Schema`] if any + pub fn schema(&self) -> Result>, ArrowError> { + self.get(SCHEMA_METADATA_KEY) + .map(|x| { + serde_json::from_slice(x).map_err(|e| { + ArrowError::ParseError(format!("Failed to parse Avro schema JSON: {e}")) + }) + }) + .transpose() + } } /// A decoder for [`Header`] diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs index 0151db7f855a..12fa67d9c8e3 100644 --- a/arrow-avro/src/reader/mod.rs +++ b/arrow-avro/src/reader/mod.rs @@ -26,6 +26,8 @@ mod header; mod block; +mod cursor; +mod record; mod vlq; /// Read a [`Header`] from the provided [`BufRead`] @@ -73,35 +75,144 @@ fn read_blocks(mut reader: R) -> impl Iterator RecordBatch { + let file = File::open(file).unwrap(); + let mut reader = BufReader::new(file); + let header = read_header(&mut reader).unwrap(); + let compression = header.compression().unwrap(); + let schema = header.schema().unwrap().unwrap(); + let root = AvroField::try_from(&schema).unwrap(); + let mut decoder = RecordDecoder::try_new(root.data_type()).unwrap(); + + for result in read_blocks(reader) { + let block = result.unwrap(); + assert_eq!(block.sync, header.sync()); + if let Some(c) = compression { + let decompressed = c.decompress(&block.data).unwrap(); + + let mut offset = 0; + let mut remaining = block.count; + while remaining > 0 { + let to_read = remaining.max(batch_size); + offset += decoder + .decode(&decompressed[offset..], block.count) + .unwrap(); + + remaining -= to_read; + } + assert_eq!(offset, decompressed.len()); + } + } + decoder.flush().unwrap() + } #[test] - fn test_mux() { + fn test_alltypes() { let files = [ "avro/alltypes_plain.avro", "avro/alltypes_plain.snappy.avro", "avro/alltypes_plain.zstandard.avro", - "avro/alltypes_nulls_plain.avro", ]; + let expected = RecordBatch::try_from_iter_with_nullable([ + ( + "id", + Arc::new(Int32Array::from(vec![4, 
5, 6, 7, 2, 3, 0, 1])) as _, + true, + ), + ( + "bool_col", + Arc::new(BooleanArray::from_iter((0..8).map(|x| Some(x % 2 == 0)))) as _, + true, + ), + ( + "tinyint_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "smallint_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "int_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "bigint_col", + Arc::new(Int64Array::from_iter_values((0..8).map(|x| (x % 2) * 10))) as _, + true, + ), + ( + "float_col", + Arc::new(Float32Array::from_iter_values( + (0..8).map(|x| (x % 2) as f32 * 1.1), + )) as _, + true, + ), + ( + "double_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| (x % 2) as f64 * 10.1), + )) as _, + true, + ), + ( + "date_string_col", + Arc::new(BinaryArray::from_iter_values([ + [48, 51, 47, 48, 49, 47, 48, 57], + [48, 51, 47, 48, 49, 47, 48, 57], + [48, 52, 47, 48, 49, 47, 48, 57], + [48, 52, 47, 48, 49, 47, 48, 57], + [48, 50, 47, 48, 49, 47, 48, 57], + [48, 50, 47, 48, 49, 47, 48, 57], + [48, 49, 47, 48, 49, 47, 48, 57], + [48, 49, 47, 48, 49, 47, 48, 57], + ])) as _, + true, + ), + ( + "string_col", + Arc::new(BinaryArray::from_iter_values((0..8).map(|x| [48 + x % 2]))) as _, + true, + ), + ( + "timestamp_col", + Arc::new( + TimestampMicrosecondArray::from_iter_values([ + 1235865600000000, // 2009-03-01T00:00:00.000 + 1235865660000000, // 2009-03-01T00:01:00.000 + 1238544000000000, // 2009-04-01T00:00:00.000 + 1238544060000000, // 2009-04-01T00:01:00.000 + 1233446400000000, // 2009-02-01T00:00:00.000 + 1233446460000000, // 2009-02-01T00:01:00.000 + 1230768000000000, // 2009-01-01T00:00:00.000 + 1230768060000000, // 2009-01-01T00:01:00.000 + ]) + .with_timezone("+00:00"), + ) as _, + true, + ), + ]) + .unwrap(); + for file in files { - println!("file: {file}"); - let file = File::open(arrow_test_data(file)).unwrap(); - let mut reader = BufReader::new(file); - 
let header = read_header(&mut reader).unwrap(); - let compression = header.compression().unwrap(); - println!("compression: {compression:?}"); - for result in read_blocks(reader) { - let block = result.unwrap(); - assert_eq!(block.sync, header.sync()); - if let Some(c) = compression { - c.decompress(&block.data).unwrap(); - } - } + let file = arrow_test_data(file); + + assert_eq!(read_file(&file, 8), expected); + assert_eq!(read_file(&file, 3), expected); } } } diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs new file mode 100644 index 000000000000..52a58cf63303 --- /dev/null +++ b/arrow-avro/src/reader/record.rs @@ -0,0 +1,292 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::codec::{AvroDataType, Codec, Nullability}; +use crate::reader::block::{Block, BlockDecoder}; +use crate::reader::cursor::AvroCursor; +use crate::reader::header::Header; +use crate::schema::*; +use arrow_array::types::*; +use arrow_array::*; +use arrow_buffer::*; +use arrow_schema::{ + ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef, +}; +use std::collections::HashMap; +use std::io::Read; +use std::sync::Arc; + +/// Decodes avro encoded data into [`RecordBatch`] +pub struct RecordDecoder { + schema: SchemaRef, + fields: Vec, +} + +impl RecordDecoder { + pub fn try_new(data_type: &AvroDataType) -> Result { + match Decoder::try_new(data_type)? { + Decoder::Record(fields, encodings) => Ok(Self { + schema: Arc::new(ArrowSchema::new(fields)), + fields: encodings, + }), + encoding => Err(ArrowError::ParseError(format!( + "Expected record got {encoding:?}" + ))), + } + } + + pub fn schema(&self) -> &SchemaRef { + &self.schema + } + + /// Decode `count` records from `buf` + pub fn decode(&mut self, buf: &[u8], count: usize) -> Result { + let mut cursor = AvroCursor::new(buf); + for _ in 0..count { + for field in &mut self.fields { + field.decode(&mut cursor)?; + } + } + Ok(cursor.position()) + } + + /// Flush the decoded records into a [`RecordBatch`] + pub fn flush(&mut self) -> Result { + let arrays = self + .fields + .iter_mut() + .map(|x| x.flush(None)) + .collect::, _>>()?; + + RecordBatch::try_new(self.schema.clone(), arrays) + } +} + +#[derive(Debug)] +enum Decoder { + Null(usize), + Boolean(BooleanBufferBuilder), + Int32(Vec), + Int64(Vec), + Float32(Vec), + Float64(Vec), + Date32(Vec), + TimeMillis(Vec), + TimeMicros(Vec), + TimestampMillis(bool, Vec), + TimestampMicros(bool, Vec), + Binary(OffsetBufferBuilder, Vec), + String(OffsetBufferBuilder, Vec), + List(FieldRef, OffsetBufferBuilder, Box), + Record(Fields, Vec), + Nullable(Nullability, NullBufferBuilder, Box), +} + +impl Decoder { + fn 
try_new(data_type: &AvroDataType) -> Result { + let nyi = |s: &str| Err(ArrowError::NotYetImplemented(s.to_string())); + + let decoder = match data_type.codec() { + Codec::Null => Self::Null(0), + Codec::Boolean => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)), + Codec::Int32 => Self::Int32(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Int64 => Self::Int64(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Float32 => Self::Float32(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Float64 => Self::Float64(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Binary => Self::Binary( + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + ), + Codec::Utf8 => Self::String( + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + ), + Codec::Date32 => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::TimeMillis => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::TimeMicros => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::TimestampMillis(is_utc) => { + Self::TimestampMillis(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) + } + Codec::TimestampMicros(is_utc) => { + Self::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) + } + Codec::Fixed(_) => return nyi("decoding fixed"), + Codec::Interval => return nyi("decoding interval"), + Codec::List(item) => { + let decoder = Self::try_new(item)?; + Self::List( + Arc::new(item.field_with_name("item")), + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(decoder), + ) + } + Codec::Struct(fields) => { + let mut arrow_fields = Vec::with_capacity(fields.len()); + let mut encodings = Vec::with_capacity(fields.len()); + for avro_field in fields.iter() { + let encoding = Self::try_new(avro_field.data_type())?; + arrow_fields.push(avro_field.field()); + encodings.push(encoding); + } + Self::Record(arrow_fields.into(), encodings) + } + }; + + Ok(match data_type.nullability() { + Some(nullability) => 
Self::Nullable( + nullability, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(decoder), + ), + None => decoder, + }) + } + + /// Append a null record + fn append_null(&mut self) { + match self { + Self::Null(count) => *count += 1, + Self::Boolean(b) => b.append(false), + Self::Int32(v) | Self::Date32(v) | Self::TimeMillis(v) => v.push(0), + Self::Int64(v) + | Self::TimeMicros(v) + | Self::TimestampMillis(_, v) + | Self::TimestampMicros(_, v) => v.push(0), + Self::Float32(v) => v.push(0.), + Self::Float64(v) => v.push(0.), + Self::Binary(offsets, _) | Self::String(offsets, _) => offsets.push_length(0), + Self::List(_, offsets, e) => { + offsets.push_length(0); + e.append_null(); + } + Self::Record(_, e) => e.iter_mut().for_each(|e| e.append_null()), + Self::Nullable(_, _, _) => unreachable!("Nulls cannot be nested"), + } + } + + /// Decode a single record from `buf` + fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> { + match self { + Self::Null(x) => *x += 1, + Self::Boolean(values) => values.append(buf.get_bool()?), + Self::Int32(values) | Self::Date32(values) | Self::TimeMillis(values) => { + values.push(buf.get_int()?) + } + Self::Int64(values) + | Self::TimeMicros(values) + | Self::TimestampMillis(_, values) + | Self::TimestampMicros(_, values) => values.push(buf.get_long()?), + Self::Float32(values) => values.push(buf.get_float()?), + Self::Float64(values) => values.push(buf.get_double()?), + Self::Binary(offsets, values) | Self::String(offsets, values) => { + let data = buf.get_bytes()?; + offsets.push_length(data.len()); + values.extend_from_slice(data); + } + Self::List(_, _, _) => { + return Err(ArrowError::NotYetImplemented( + "Decoding ListArray".to_string(), + )) + } + Self::Record(_, encodings) => { + for encoding in encodings { + encoding.decode(buf)?; + } + } + Self::Nullable(nullability, nulls, e) => { + let is_valid = buf.get_bool()? 
== matches!(nullability, Nullability::NullFirst); + nulls.append(is_valid); + match is_valid { + true => e.decode(buf)?, + false => e.append_null(), + } + } + } + Ok(()) + } + + /// Flush decoded records to an [`ArrayRef`] + fn flush(&mut self, nulls: Option) -> Result { + Ok(match self { + Self::Nullable(_, n, e) => e.flush(n.finish())?, + Self::Null(size) => Arc::new(NullArray::new(std::mem::replace(size, 0))), + Self::Boolean(b) => Arc::new(BooleanArray::new(b.finish(), nulls)), + Self::Int32(values) => Arc::new(flush_primitive::(values, nulls)), + Self::Date32(values) => Arc::new(flush_primitive::(values, nulls)), + Self::Int64(values) => Arc::new(flush_primitive::(values, nulls)), + Self::TimeMillis(values) => { + Arc::new(flush_primitive::(values, nulls)) + } + Self::TimeMicros(values) => { + Arc::new(flush_primitive::(values, nulls)) + } + Self::TimestampMillis(is_utc, values) => Arc::new( + flush_primitive::(values, nulls) + .with_timezone_opt(is_utc.then(|| "+00:00")), + ), + Self::TimestampMicros(is_utc, values) => Arc::new( + flush_primitive::(values, nulls) + .with_timezone_opt(is_utc.then(|| "+00:00")), + ), + Self::Float32(values) => Arc::new(flush_primitive::(values, nulls)), + Self::Float64(values) => Arc::new(flush_primitive::(values, nulls)), + + Self::Binary(offsets, values) => { + let offsets = flush_offsets(offsets); + let values = flush_values(values).into(); + Arc::new(BinaryArray::new(offsets, values, nulls)) + } + Self::String(offsets, values) => { + let offsets = flush_offsets(offsets); + let values = flush_values(values).into(); + Arc::new(StringArray::new(offsets, values, nulls)) + } + Self::List(field, offsets, values) => { + let values = values.flush(None)?; + let offsets = flush_offsets(offsets); + Arc::new(ListArray::new(field.clone(), offsets, values, nulls)) + } + Self::Record(fields, encodings) => { + let arrays = encodings + .iter_mut() + .map(|x| x.flush(None)) + .collect::, _>>()?; + Arc::new(StructArray::new(fields.clone(), 
arrays, nulls)) + } + }) + } +} + +#[inline] +fn flush_values(values: &mut Vec) -> Vec { + std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY)) +} + +#[inline] +fn flush_offsets(offsets: &mut OffsetBufferBuilder) -> OffsetBuffer { + std::mem::replace(offsets, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish() +} + +#[inline] +fn flush_primitive( + values: &mut Vec, + nulls: Option, +) -> PrimitiveArray { + PrimitiveArray::new(flush_values(values).into(), nulls) +} + +const DEFAULT_CAPACITY: usize = 1024; diff --git a/arrow-avro/src/reader/vlq.rs b/arrow-avro/src/reader/vlq.rs index 80f1c60eec7d..b198a0d66f24 100644 --- a/arrow-avro/src/reader/vlq.rs +++ b/arrow-avro/src/reader/vlq.rs @@ -44,3 +44,91 @@ impl VLQDecoder { None } } + +/// Read a varint from `buf` returning the decoded `u64` and the number of bytes read +#[inline] +pub(crate) fn read_varint(buf: &[u8]) -> Option<(u64, usize)> { + let first = *buf.first()?; + if first < 0x80 { + return Some((first as u64, 1)); + } + + if let Some(array) = buf.get(..10) { + return read_varint_array(array.try_into().unwrap()); + } + + read_varint_slow(buf) +} + +/// Based on +/// - +/// - +/// - +#[inline] +fn read_varint_array(buf: [u8; 10]) -> Option<(u64, usize)> { + let mut in_progress = 0_u64; + for (idx, b) in buf.into_iter().take(9).enumerate() { + in_progress += (b as u64) << (7 * idx); + if b < 0x80 { + return Some((in_progress, idx + 1)); + } + in_progress -= 0x80 << (7 * idx); + } + + let b = buf[9] as u64; + in_progress += b << (7 * 9); + (b < 0x02).then_some((in_progress, 10)) +} + +#[inline(never)] +#[cold] +fn read_varint_slow(buf: &[u8]) -> Option<(u64, usize)> { + let mut value = 0; + for (count, byte) in buf.iter().take(10).enumerate() { + let byte = buf[count]; + value |= u64::from(byte & 0x7F) << (count * 7); + if byte <= 0x7F { + // Check for u64::MAX overflow. See [`ConsumeVarint`][1] for details. 
+ // [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358 + return (count != 9 || byte < 2).then_some((value, count + 1)); + } + } + + None +} + +#[cfg(test)] +mod tests { + use super::*; + + fn encode_var(mut n: u64, dst: &mut [u8]) -> usize { + let mut i = 0; + + while n >= 0x80 { + dst[i] = 0x80 | (n as u8); + i += 1; + n >>= 7; + } + + dst[i] = n as u8; + i + 1 + } + + fn varint_test(a: u64) { + let mut buf = [0_u8; 10]; + let len = encode_var(a, &mut buf); + assert_eq!(read_varint(&buf[..len]).unwrap(), (a, len)); + assert_eq!(read_varint(&buf).unwrap(), (a, len)); + } + + #[test] + fn test_varint() { + varint_test(0); + varint_test(4395932); + varint_test(u64::MAX); + + for _ in 0..1000 { + varint_test(rand::random()); + } + } +} diff --git a/arrow-buffer/LICENSE.txt b/arrow-buffer/LICENSE.txt new file mode 120000 index 000000000000..4ab43736a839 --- /dev/null +++ b/arrow-buffer/LICENSE.txt @@ -0,0 +1 @@ +../LICENSE.txt \ No newline at end of file diff --git a/arrow-buffer/NOTICE.txt b/arrow-buffer/NOTICE.txt new file mode 120000 index 000000000000..eb9f24e040b5 --- /dev/null +++ b/arrow-buffer/NOTICE.txt @@ -0,0 +1 @@ +../NOTICE.txt \ No newline at end of file diff --git a/arrow-buffer/src/buffer/boolean.rs b/arrow-buffer/src/buffer/boolean.rs index 49a75b468dbe..aaa86832f692 100644 --- a/arrow-buffer/src/buffer/boolean.rs +++ b/arrow-buffer/src/buffer/boolean.rs @@ -52,8 +52,12 @@ impl BooleanBuffer { /// This method will panic if `buffer` is not large enough pub fn new(buffer: Buffer, offset: usize, len: usize) -> Self { let total_len = offset.saturating_add(len); - let bit_len = buffer.len().saturating_mul(8); - assert!(total_len <= bit_len); + let buffer_len = buffer.len(); + let bit_len = buffer_len.saturating_mul(8); + assert!( + total_len <= bit_len, + "buffer not large enough (offset: {offset}, len: {len}, buffer_len: {buffer_len})" + ); Self { buffer, offset, @@ -96,17 +100,6 @@ impl BooleanBuffer { 
BitChunks::new(self.values(), self.offset, self.len) } - /// Returns `true` if the bit at index `i` is set - /// - /// # Panics - /// - /// Panics if `i >= self.len()` - #[inline] - #[deprecated(note = "use BooleanBuffer::value")] - pub fn is_set(&self, i: usize) -> bool { - self.value(i) - } - /// Returns the offset of this [`BooleanBuffer`] in bits #[inline] pub fn offset(&self) -> usize { @@ -125,6 +118,12 @@ impl BooleanBuffer { self.len == 0 } + /// Free up unused memory. + pub fn shrink_to_fit(&mut self) { + // TODO(emilk): we could shrink even more in the case where we are a small sub-slice of the full buffer + self.buffer.shrink_to_fit(); + } + /// Returns the boolean value at index `i`. /// /// # Panics diff --git a/arrow-buffer/src/buffer/immutable.rs b/arrow-buffer/src/buffer/immutable.rs index 8d1a46583fca..fd145ce2306e 100644 --- a/arrow-buffer/src/buffer/immutable.rs +++ b/arrow-buffer/src/buffer/immutable.rs @@ -20,16 +20,51 @@ use std::fmt::Debug; use std::ptr::NonNull; use std::sync::Arc; -use crate::alloc::{Allocation, Deallocation, ALIGNMENT}; +use crate::alloc::{Allocation, Deallocation}; use crate::util::bit_chunk_iterator::{BitChunks, UnalignedBitChunk}; use crate::BufferBuilder; -use crate::{bytes::Bytes, native::ArrowNativeType}; +use crate::{bit_util, bytes::Bytes, native::ArrowNativeType}; use super::ops::bitwise_unary_op_helper; use super::{MutableBuffer, ScalarBuffer}; -/// Buffer represents a contiguous memory region that can be shared with other buffers and across -/// thread boundaries. +/// A contiguous memory region that can be shared with other buffers and across +/// thread boundaries that stores Arrow data. +/// +/// `Buffer`s can be sliced and cloned without copying the underlying data and can +/// be created from memory allocated by non-Rust sources such as C/C++. 
+/// +/// # Example: Create a `Buffer` from a `Vec` (without copying) +/// ``` +/// # use arrow_buffer::Buffer; +/// let vec: Vec = vec![1, 2, 3]; +/// let buffer = Buffer::from(vec); +/// ``` +/// +/// # Example: Convert a `Buffer` to a `Vec` (without copying) +/// +/// Use [`Self::into_vec`] to convert a `Buffer` back into a `Vec` if there are +/// no other references and the types are aligned correctly. +/// ``` +/// # use arrow_buffer::Buffer; +/// # let vec: Vec = vec![1, 2, 3]; +/// # let buffer = Buffer::from(vec); +/// // convert the buffer back into a Vec of u32 +/// // note this will fail if the buffer is shared or not aligned correctly +/// let vec: Vec = buffer.into_vec().unwrap(); +/// ``` +/// +/// # Example: Create a `Buffer` from a [`bytes::Bytes`] (without copying) +/// +/// [`bytes::Bytes`] is a common type in the Rust ecosystem for shared memory +/// regions. You can create a buffer from a `Bytes` instance using the `From` +/// implementation, also without copying. +/// +/// ``` +/// # use arrow_buffer::Buffer; +/// let bytes = bytes::Bytes::from("hello"); +/// let buffer = Buffer::from(bytes); +///``` #[derive(Clone, Debug)] pub struct Buffer { /// the internal byte buffer. @@ -59,16 +94,15 @@ unsafe impl Send for Buffer where Bytes: Send {} unsafe impl Sync for Buffer where Bytes: Sync {} impl Buffer { - /// Auxiliary method to create a new Buffer - #[inline] + /// Create a new Buffer from a (internal) `Bytes` + /// + /// NOTE despite the same name, `Bytes` is an internal struct in arrow-rs + /// and is different than [`bytes::Bytes`]. + /// + /// See examples on [`Buffer`] for ways to create a buffer from a [`bytes::Bytes`]. 
+ #[deprecated(since = "54.1.0", note = "Use Buffer::from instead")] pub fn from_bytes(bytes: Bytes) -> Self { - let length = bytes.len(); - let ptr = bytes.as_ptr(); - Buffer { - data: Arc::new(bytes), - ptr, - length, - } + Self::from(bytes) } /// Returns the offset, in bytes, of `Self::ptr` to `Self::data` @@ -99,28 +133,11 @@ impl Buffer { buffer.into() } - /// Creates a buffer from an existing aligned memory region (must already be byte-aligned), this - /// `Buffer` will free this piece of memory when dropped. + /// Creates a buffer from an existing memory region. /// - /// # Arguments - /// - /// * `ptr` - Pointer to raw parts - /// * `len` - Length of raw parts in **bytes** - /// * `capacity` - Total allocated memory for the pointer `ptr`, in **bytes** - /// - /// # Safety - /// - /// This function is unsafe as there is no guarantee that the given pointer is valid for `len` - /// bytes. If the `ptr` and `capacity` come from a `Buffer`, then this is guaranteed. - #[deprecated(note = "Use Buffer::from_vec")] - pub unsafe fn from_raw_parts(ptr: NonNull, len: usize, capacity: usize) -> Self { - assert!(len <= capacity); - let layout = Layout::from_size_align(capacity, ALIGNMENT).unwrap(); - Buffer::build_with_arguments(ptr, len, Deallocation::Standard(layout)) - } - - /// Creates a buffer from an existing memory region. Ownership of the memory is tracked via reference counting - /// and the memory will be freed using the `drop` method of [crate::alloc::Allocation] when the reference count reaches zero. + /// Ownership of the memory is tracked via reference counting + /// and the memory will be freed using the `drop` method of + /// [crate::alloc::Allocation] when the reference count reaches zero. /// /// # Arguments /// @@ -167,7 +184,42 @@ impl Buffer { self.data.capacity() } - /// Returns whether the buffer is empty. + /// Tries to shrink the capacity of the buffer as much as possible, freeing unused memory. 
+ /// + /// If the buffer is shared, this is a no-op. + /// + /// If the memory was allocated with a custom allocator, this is a no-op. + /// + /// If the capacity is already less than or equal to the desired capacity, this is a no-op. + /// + /// The memory region will be reallocated using `std::alloc::realloc`. + pub fn shrink_to_fit(&mut self) { + let offset = self.ptr_offset(); + let is_empty = self.is_empty(); + let desired_capacity = if is_empty { + 0 + } else { + // For realloc to work, we cannot free the elements before the offset + offset + self.len() + }; + if desired_capacity < self.capacity() { + if let Some(bytes) = Arc::get_mut(&mut self.data) { + if bytes.try_realloc(desired_capacity).is_ok() { + // Realloc complete - update our pointer into `bytes`: + self.ptr = if is_empty { + bytes.as_ptr() + } else { + // SAFETY: we kept all elements leading up to the offset + unsafe { bytes.as_ptr().add(offset) } + } + } else { + // Failure to reallocate is fine; we just failed to free up memory. + } + } + } + } + + /// Returns true if the buffer is empty. #[inline] pub fn is_empty(&self) -> bool { self.length == 0 @@ -183,7 +235,9 @@ impl Buffer { } /// Returns a new [Buffer] that is a slice of this buffer starting at `offset`. - /// Doing so allows the same memory region to be shared between buffers. + /// + /// This function is `O(1)` and does not copy any data, allowing the + /// same memory region to be shared between buffers. /// /// # Panics /// @@ -217,7 +271,10 @@ impl Buffer { /// Returns a new [Buffer] that is a slice of this buffer starting at `offset`, /// with `length` bytes. - /// Doing so allows the same memory region to be shared between buffers. + /// + /// This function is `O(1)` and does not copy any data, allowing the same + /// memory region to be shared between buffers. + /// /// # Panics /// Panics iff `(offset + length)` is larger than the existing length. 
pub fn slice_with_length(&self, offset: usize, length: usize) -> Self { @@ -265,7 +322,7 @@ impl Buffer { /// otherwise a new buffer is allocated and filled with a copy of the bits in the range. pub fn bit_slice(&self, offset: usize, len: usize) -> Self { if offset % 8 == 0 { - return self.slice(offset / 8); + return self.slice_with_length(offset / 8, bit_util::ceil(len, 8)); } bitwise_unary_op_helper(self, offset, len, |a| a) @@ -278,14 +335,6 @@ impl Buffer { BitChunks::new(self.as_slice(), offset, len) } - /// Returns the number of 1-bits in this buffer. - #[deprecated(note = "use count_set_bits_offset instead")] - pub fn count_set_bits(&self) -> usize { - let len_in_bits = self.len() * 8; - // self.offset is already taken into consideration by the bit_chunks implementation - self.count_set_bits_offset(0, len_in_bits) - } - /// Returns the number of 1-bits in this buffer, starting from `offset` with `length` bits /// inspected. Note that both `offset` and `length` are measured in bits. pub fn count_set_bits_offset(&self, offset: usize, len: usize) -> usize { @@ -295,6 +344,8 @@ impl Buffer { /// Returns `MutableBuffer` for mutating the buffer if this buffer is not shared. /// Returns `Err` if this is shared or its allocation is from an external source or /// it is not allocated with alignment [`ALIGNMENT`] + /// + /// [`ALIGNMENT`]: crate::alloc::ALIGNMENT pub fn into_mutable(self) -> Result { let ptr = self.ptr; let length = self.length; @@ -311,10 +362,16 @@ impl Buffer { }) } - /// Returns `Vec` for mutating the buffer + /// Converts self into a `Vec`, if possible. /// - /// Returns `Err(self)` if this buffer does not have the same [`Layout`] as - /// the destination Vec or contains a non-zero offset + /// This can be used to reuse / mutate the underlying data. + /// + /// # Errors + /// + /// Returns `Err(self)` if + /// 1. this buffer does not have the same [`Layout`] as the destination Vec + /// 2. contains a non-zero offset + /// 3. 
The buffer is shared pub fn into_vec(self) -> Result, Self> { let layout = match self.data.deallocation() { Deallocation::Standard(l) => l, @@ -397,7 +454,29 @@ impl From> for Buffer { } } -/// Creating a `Buffer` instance by storing the boolean values into the buffer +/// Convert from internal `Bytes` (not [`bytes::Bytes`]) to `Buffer` +impl From for Buffer { + #[inline] + fn from(bytes: Bytes) -> Self { + let length = bytes.len(); + let ptr = bytes.as_ptr(); + Self { + data: Arc::new(bytes), + ptr, + length, + } + } +} + +/// Convert from [`bytes::Bytes`], not internal `Bytes` to `Buffer` +impl From for Buffer { + fn from(bytes: bytes::Bytes) -> Self { + let bytes: Bytes = bytes.into(); + Self::from(bytes) + } +} + +/// Create a `Buffer` instance by storing the boolean values into the buffer impl FromIterator for Buffer { fn from_iter(iter: I) -> Self where @@ -430,7 +509,9 @@ impl From> for Buffer { impl Buffer { /// Creates a [`Buffer`] from an [`Iterator`] with a trusted (upper) length. + /// /// Prefer this to `collect` whenever possible, as it is ~60% faster. 
+ /// /// # Example /// ``` /// # use arrow_buffer::buffer::Buffer; @@ -562,6 +643,34 @@ mod tests { assert_eq!(buf2.slice_with_length(2, 1).as_slice(), &[10]); } + #[test] + fn test_shrink_to_fit() { + let original = Buffer::from(&[0, 1, 2, 3, 4, 5, 6, 7]); + assert_eq!(original.as_slice(), &[0, 1, 2, 3, 4, 5, 6, 7]); + assert_eq!(original.capacity(), 64); + + let slice = original.slice_with_length(2, 3); + drop(original); // Make sure the buffer isn't shared (or shrink_to_fit won't work) + assert_eq!(slice.as_slice(), &[2, 3, 4]); + assert_eq!(slice.capacity(), 64); + + let mut shrunk = slice; + shrunk.shrink_to_fit(); + assert_eq!(shrunk.as_slice(), &[2, 3, 4]); + assert_eq!(shrunk.capacity(), 5); // shrink_to_fit is allowed to keep the elements before the offset + + // Test that we can handle empty slices: + let empty_slice = shrunk.slice_with_length(1, 0); + drop(shrunk); // Make sure the buffer isn't shared (or shrink_to_fit won't work) + assert_eq!(empty_slice.as_slice(), &[]); + assert_eq!(empty_slice.capacity(), 5); + + let mut shrunk_empty = empty_slice; + shrunk_empty.shrink_to_fit(); + assert_eq!(shrunk_empty.as_slice(), &[]); + assert_eq!(shrunk_empty.capacity(), 0); + } + #[test] #[should_panic(expected = "the offset of the new Buffer cannot exceed the existing length")] fn test_slice_offset_out_of_bound() { @@ -860,4 +969,37 @@ mod tests { let iter_len = usize::MAX / std::mem::size_of::() + 1; let _ = Buffer::from_iter(std::iter::repeat(0_u64).take(iter_len)); } + + #[test] + fn bit_slice_length_preserved() { + // Create a boring buffer + let buf = Buffer::from_iter(std::iter::repeat(true).take(64)); + + let assert_preserved = |offset: usize, len: usize| { + let new_buf = buf.bit_slice(offset, len); + assert_eq!(new_buf.len(), bit_util::ceil(len, 8)); + + // if the offset is not byte-aligned, we have to create a deep copy to a new buffer + // (since the `offset` value inside a Buffer is byte-granular, not bit-granular), so + // checking the offset 
should always return 0 if so. If the offset IS byte-aligned, we + // want to make sure it doesn't unnecessarily create a deep copy. + if offset % 8 == 0 { + assert_eq!(new_buf.ptr_offset(), offset / 8); + } else { + assert_eq!(new_buf.ptr_offset(), 0); + } + }; + + // go through every available value for offset + for o in 0..=64 { + // and go through every length that could accompany that offset - we can't have a + // situation where offset + len > 64, because that would go past the end of the buffer, + // so we use the map to ensure it's in range. + for l in (o..=64).map(|l| l - o) { + // and we just want to make sure every one of these keeps its offset and length + // when neeeded + assert_preserved(o, l); + } + } + } } diff --git a/arrow-buffer/src/buffer/mutable.rs b/arrow-buffer/src/buffer/mutable.rs index 7fcbd89dd262..5ad55e306e2a 100644 --- a/arrow-buffer/src/buffer/mutable.rs +++ b/arrow-buffer/src/buffer/mutable.rs @@ -118,13 +118,6 @@ impl MutableBuffer { Self { data, len, layout } } - /// Create a [`MutableBuffer`] from the provided [`Vec`] without copying - #[inline] - #[deprecated(note = "Use From>")] - pub fn from_vec(vec: Vec) -> Self { - Self::from(vec) - } - /// Allocates a new [MutableBuffer] from given `Bytes`. pub(crate) fn from_bytes(bytes: Bytes) -> Result { let layout = match bytes.deallocation() { @@ -331,20 +324,11 @@ impl MutableBuffer { self.data.as_ptr() } - #[deprecated( - since = "2.0.0", - note = "This method is deprecated in favour of `into` from the trait `Into`." - )] - /// Freezes this buffer and return an immutable version of it. - pub fn freeze(self) -> Buffer { - self.into_buffer() - } - #[inline] pub(super) fn into_buffer(self) -> Buffer { let bytes = unsafe { Bytes::new(self.data, self.len, Deallocation::Standard(self.layout)) }; std::mem::forget(self); - Buffer::from_bytes(bytes) + Buffer::from(bytes) } /// View this buffer as a mutable slice of a specific type. 
@@ -483,10 +467,13 @@ impl MutableBuffer { } } +/// Creates a non-null pointer with alignment of [`ALIGNMENT`] +/// +/// This is similar to [`NonNull::dangling`] #[inline] -fn dangling_ptr() -> NonNull { - // SAFETY: ALIGNMENT is a non-zero usize which is then casted - // to a *mut T. Therefore, `ptr` is not null and the conditions for +pub(crate) fn dangling_ptr() -> NonNull { + // SAFETY: ALIGNMENT is a non-zero usize which is then cast + // to a *mut u8. Therefore, `ptr` is not null and the conditions for // calling new_unchecked() are respected. #[cfg(miri)] { diff --git a/arrow-buffer/src/buffer/null.rs b/arrow-buffer/src/buffer/null.rs index c79aef398059..ec12b885eb5a 100644 --- a/arrow-buffer/src/buffer/null.rs +++ b/arrow-buffer/src/buffer/null.rs @@ -130,6 +130,11 @@ impl NullBuffer { self.buffer.is_empty() } + /// Free up unused memory. + pub fn shrink_to_fit(&mut self) { + self.buffer.shrink_to_fit(); + } + /// Returns the null count for this [`NullBuffer`] #[inline] pub fn null_count(&self) -> usize { @@ -235,6 +240,12 @@ impl From<&[bool]> for NullBuffer { } } +impl From<&[bool; N]> for NullBuffer { + fn from(value: &[bool; N]) -> Self { + value[..].into() + } +} + impl From> for NullBuffer { fn from(value: Vec) -> Self { BooleanBuffer::from(value).into() diff --git a/arrow-buffer/src/buffer/offset.rs b/arrow-buffer/src/buffer/offset.rs index e9087d30098c..164af6f01d0e 100644 --- a/arrow-buffer/src/buffer/offset.rs +++ b/arrow-buffer/src/buffer/offset.rs @@ -133,6 +133,43 @@ impl OffsetBuffer { Self(out.into()) } + /// Get an Iterator over the lengths of this [`OffsetBuffer`] + /// + /// ``` + /// # use arrow_buffer::{OffsetBuffer, ScalarBuffer}; + /// let offsets = OffsetBuffer::<_>::new(ScalarBuffer::::from(vec![0, 1, 4, 9])); + /// assert_eq!(offsets.lengths().collect::>(), vec![1, 3, 5]); + /// ``` + /// + /// Empty [`OffsetBuffer`] will return an empty iterator + /// ``` + /// # use arrow_buffer::OffsetBuffer; + /// let offsets = 
OffsetBuffer::::new_empty(); + /// assert_eq!(offsets.lengths().count(), 0); + /// ``` + /// + /// This can be used to merge multiple [`OffsetBuffer`]s to one + /// ``` + /// # use arrow_buffer::{OffsetBuffer, ScalarBuffer}; + /// + /// let buffer1 = OffsetBuffer::::from_lengths([2, 6, 3, 7, 2]); + /// let buffer2 = OffsetBuffer::::from_lengths([1, 3, 5, 7, 9]); + /// + /// let merged = OffsetBuffer::::from_lengths( + /// vec![buffer1, buffer2].iter().flat_map(|x| x.lengths()) + /// ); + /// + /// assert_eq!(merged.lengths().collect::>(), &[2, 6, 3, 7, 2, 1, 3, 5, 7, 9]); + /// ``` + pub fn lengths(&self) -> impl ExactSizeIterator + '_ { + self.0.windows(2).map(|x| x[1].as_usize() - x[0].as_usize()) + } + + /// Free up unused memory. + pub fn shrink_to_fit(&mut self) { + self.0.shrink_to_fit(); + } + /// Returns the inner [`ScalarBuffer`] pub fn inner(&self) -> &ScalarBuffer { &self.0 @@ -239,4 +276,24 @@ mod tests { fn from_lengths_usize_overflow() { OffsetBuffer::::from_lengths([usize::MAX, 1]); } + + #[test] + fn get_lengths() { + let offsets = OffsetBuffer::::new(ScalarBuffer::::from(vec![0, 1, 4, 9])); + assert_eq!(offsets.lengths().collect::>(), vec![1, 3, 5]); + } + + #[test] + fn get_lengths_should_be_with_fixed_size() { + let offsets = OffsetBuffer::::new(ScalarBuffer::::from(vec![0, 1, 4, 9])); + let iter = offsets.lengths(); + assert_eq!(iter.size_hint(), (3, Some(3))); + assert_eq!(iter.len(), 3); + } + + #[test] + fn get_lengths_from_empty_offset_buffer_should_be_empty_iterator() { + let offsets = OffsetBuffer::::new_empty(); + assert_eq!(offsets.lengths().collect::>(), vec![]); + } } diff --git a/arrow-buffer/src/buffer/run.rs b/arrow-buffer/src/buffer/run.rs index 3dbbe344a025..cc6d19044feb 100644 --- a/arrow-buffer/src/buffer/run.rs +++ b/arrow-buffer/src/buffer/run.rs @@ -136,6 +136,12 @@ where self.len == 0 } + /// Free up unused memory. 
+ pub fn shrink_to_fit(&mut self) { + // TODO(emilk): we could shrink even more in the case where we are a small sub-slice of the full buffer + self.run_ends.shrink_to_fit(); + } + /// Returns the values of this [`RunEndBuffer`] not including any offset #[inline] pub fn values(&self) -> &[E] { diff --git a/arrow-buffer/src/buffer/scalar.rs b/arrow-buffer/src/buffer/scalar.rs index 343b8549e93d..ab6c87168e5c 100644 --- a/arrow-buffer/src/buffer/scalar.rs +++ b/arrow-buffer/src/buffer/scalar.rs @@ -72,6 +72,11 @@ impl ScalarBuffer { buffer.slice_with_length(byte_offset, byte_len).into() } + /// Free up unused memory. + pub fn shrink_to_fit(&mut self) { + self.buffer.shrink_to_fit(); + } + /// Returns a zero-copy slice of this buffer with length `len` and starting at `offset` pub fn slice(&self, offset: usize, len: usize) -> Self { Self::new(self.buffer.clone(), offset, len) diff --git a/arrow-buffer/src/bytes.rs b/arrow-buffer/src/bytes.rs index ba61342d8e39..b811bd2c6b40 100644 --- a/arrow-buffer/src/bytes.rs +++ b/arrow-buffer/src/bytes.rs @@ -24,17 +24,22 @@ use std::ptr::NonNull; use std::{fmt::Debug, fmt::Formatter}; use crate::alloc::Deallocation; +use crate::buffer::dangling_ptr; /// A continuous, fixed-size, immutable memory region that knows how to de-allocate itself. /// -/// This structs' API is inspired by the `bytes::Bytes`, but it is not limited to using rust's -/// global allocator nor u8 alignment. +/// Note that this structure is an internal implementation detail of the +/// arrow-rs crate. While it has the same name and similar API as +/// [`bytes::Bytes`] it is not limited to rust's global allocator nor u8 +/// alignment. It is possible to create a `Bytes` from `bytes::Bytes` using the +/// `From` implementation. 
/// /// In the most common case, this buffer is allocated using [`alloc`](std::alloc::alloc) /// with an alignment of [`ALIGNMENT`](crate::alloc::ALIGNMENT) /// /// When the region is allocated by a different allocator, [Deallocation::Custom], this calls the /// custom deallocator to deallocate the region when it is no longer needed. +/// pub struct Bytes { /// The raw pointer to be beginning of the region ptr: NonNull, @@ -96,6 +101,48 @@ impl Bytes { } } + /// Try to reallocate the underlying memory region to a new size (smaller or larger). + /// + /// Only works for bytes allocated with the standard allocator. + /// Returns `Err` if the memory was allocated with a custom allocator, + /// or the call to `realloc` failed, for whatever reason. + /// In case of `Err`, the [`Bytes`] will remain as it was (i.e. have the old size). + pub fn try_realloc(&mut self, new_len: usize) -> Result<(), ()> { + if let Deallocation::Standard(old_layout) = self.deallocation { + if old_layout.size() == new_len { + return Ok(()); // Nothing to do + } + + if let Ok(new_layout) = std::alloc::Layout::from_size_align(new_len, old_layout.align()) + { + let old_ptr = self.ptr.as_ptr(); + + let new_ptr = match new_layout.size() { + 0 => { + // SAFETY: Verified that old_layout.size != new_len (0) + unsafe { std::alloc::dealloc(self.ptr.as_ptr(), old_layout) }; + Some(dangling_ptr()) + } + // SAFETY: the call to `realloc` is safe if all the following hold (from https://doc.rust-lang.org/stable/std/alloc/trait.GlobalAlloc.html#method.realloc): + // * `old_ptr` must be currently allocated via this allocator (guaranteed by the invariant/contract of `Bytes`) + // * `old_layout` must be the same layout that was used to allocate that block of memory (same) + // * `new_len` must be greater than zero + // * `new_len`, when rounded up to the nearest multiple of `layout.align()`, must not overflow `isize` (guaranteed by the success of `Layout::from_size_align`) + _ => NonNull::new(unsafe { 
std::alloc::realloc(old_ptr, old_layout, new_len) }), + }; + + if let Some(ptr) = new_ptr { + self.ptr = ptr; + self.len = new_len; + self.deallocation = Deallocation::Standard(new_layout); + return Ok(()); + } + } + } + + Err(()) + } + #[inline] pub(crate) fn deallocation(&self) -> &Deallocation { &self.deallocation diff --git a/arrow-buffer/src/native.rs b/arrow-buffer/src/native.rs index c563f73cf5b9..eb8e067db0be 100644 --- a/arrow-buffer/src/native.rs +++ b/arrow-buffer/src/native.rs @@ -88,30 +88,6 @@ pub trait ArrowNativeType: /// Returns `None` if [`Self`] is not an integer or conversion would result /// in truncation/overflow fn to_i64(self) -> Option; - - /// Convert native type from i32. - /// - /// Returns `None` if [`Self`] is not `i32` - #[deprecated(note = "please use `Option::Some` instead")] - fn from_i32(_: i32) -> Option { - None - } - - /// Convert native type from i64. - /// - /// Returns `None` if [`Self`] is not `i64` - #[deprecated(note = "please use `Option::Some` instead")] - fn from_i64(_: i64) -> Option { - None - } - - /// Convert native type from i128. - /// - /// Returns `None` if [`Self`] is not `i128` - #[deprecated(note = "please use `Option::Some` instead")] - fn from_i128(_: i128) -> Option { - None - } } macro_rules! native_integer { @@ -147,23 +123,15 @@ macro_rules! 
native_integer { fn usize_as(i: usize) -> Self { i as _ } - - - $( - #[inline] - fn $from(v: $t) -> Option { - Some(v) - } - )* } }; } native_integer!(i8); native_integer!(i16); -native_integer!(i32, from_i32); -native_integer!(i64, from_i64); -native_integer!(i128, from_i128); +native_integer!(i32); +native_integer!(i64); +native_integer!(i128); native_integer!(u8); native_integer!(u16); native_integer!(u32); diff --git a/arrow-cast/LICENSE.txt b/arrow-cast/LICENSE.txt new file mode 120000 index 000000000000..4ab43736a839 --- /dev/null +++ b/arrow-cast/LICENSE.txt @@ -0,0 +1 @@ +../LICENSE.txt \ No newline at end of file diff --git a/arrow-cast/NOTICE.txt b/arrow-cast/NOTICE.txt new file mode 120000 index 000000000000..eb9f24e040b5 --- /dev/null +++ b/arrow-cast/NOTICE.txt @@ -0,0 +1 @@ +../NOTICE.txt \ No newline at end of file diff --git a/arrow-cast/src/cast/decimal.rs b/arrow-cast/src/cast/decimal.rs index d6b2f884f753..ba82ca9040c7 100644 --- a/arrow-cast/src/cast/decimal.rs +++ b/arrow-cast/src/cast/decimal.rs @@ -111,9 +111,13 @@ where O::Native::from_decimal(adjusted) }; - Ok(match cast_options.safe { - true => array.unary_opt(f), - false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?, + Ok(if cast_options.safe { + array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision))) + } else { + array.try_unary(|x| { + f(x).ok_or_else(|| error(x)) + .and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v)) + })? }) } @@ -137,15 +141,20 @@ where let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok()); - Ok(match cast_options.safe { - true => array.unary_opt(f), - false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?, + Ok(if cast_options.safe { + array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision))) + } else { + array.try_unary(|x| { + f(x).ok_or_else(|| error(x)) + .and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v)) + })? 
}) } // Only support one type of decimal cast operations pub(crate) fn cast_decimal_to_decimal_same_type( array: &PrimitiveArray, + input_precision: u8, input_scale: i8, output_precision: u8, output_scale: i8, @@ -155,20 +164,11 @@ where T: DecimalType, T::Native: DecimalCast + ArrowNativeTypeOp, { - let array: PrimitiveArray = match input_scale.cmp(&output_scale) { - Ordering::Equal => { - // the scale doesn't change, the native value don't need to be changed + let array: PrimitiveArray = + if input_scale == output_scale && input_precision <= output_precision { array.clone() - } - Ordering::Greater => convert_to_smaller_scale_decimal::( - array, - input_scale, - output_precision, - output_scale, - cast_options, - )?, - Ordering::Less => { - // input_scale < output_scale + } else if input_scale < output_scale { + // the scale doesn't change, but precision may change and cause overflow convert_to_bigger_or_equal_scale_decimal::( array, input_scale, @@ -176,8 +176,15 @@ where output_scale, cast_options, )? - } - }; + } else { + convert_to_smaller_scale_decimal::( + array, + input_scale, + output_precision, + output_scale, + cast_options, + )? 
+ }; Ok(Arc::new(array.with_precision_and_scale( output_precision, @@ -323,8 +330,8 @@ where }) } -pub(crate) fn string_to_decimal_cast( - from: &GenericStringArray, +pub(crate) fn generic_string_to_decimal_cast<'a, T, S>( + from: &'a S, precision: u8, scale: i8, cast_options: &CastOptions, @@ -332,6 +339,7 @@ pub(crate) fn string_to_decimal_cast( where T: DecimalType, T::Native: DecimalCast + ArrowNativeTypeOp, + &'a S: StringArrayType<'a>, { if cast_options.safe { let iter = from.iter().map(|v| { @@ -375,6 +383,37 @@ where } } +pub(crate) fn string_to_decimal_cast( + from: &GenericStringArray, + precision: u8, + scale: i8, + cast_options: &CastOptions, +) -> Result, ArrowError> +where + T: DecimalType, + T::Native: DecimalCast + ArrowNativeTypeOp, +{ + generic_string_to_decimal_cast::>( + from, + precision, + scale, + cast_options, + ) +} + +pub(crate) fn string_view_to_decimal_cast( + from: &StringViewArray, + precision: u8, + scale: i8, + cast_options: &CastOptions, +) -> Result, ArrowError> +where + T: DecimalType, + T::Native: DecimalCast + ArrowNativeTypeOp, +{ + generic_string_to_decimal_cast::(from, precision, scale, cast_options) +} + /// Cast Utf8 to decimal pub(crate) fn cast_string_to_decimal( from: &dyn Array, @@ -399,14 +438,30 @@ where ))); } - Ok(Arc::new(string_to_decimal_cast::( - from.as_any() - .downcast_ref::>() - .unwrap(), - precision, - scale, - cast_options, - )?)) + let result = match from.data_type() { + DataType::Utf8View => string_view_to_decimal_cast::( + from.as_any().downcast_ref::().unwrap(), + precision, + scale, + cast_options, + )?, + DataType::Utf8 | DataType::LargeUtf8 => string_to_decimal_cast::( + from.as_any() + .downcast_ref::>() + .unwrap(), + precision, + scale, + cast_options, + )?, + other => { + return Err(ArrowError::ComputeError(format!( + "Cannot cast {:?} to decimal", + other + ))) + } + }; + + Ok(Arc::new(result)) } pub(crate) fn cast_floating_point_to_decimal128( diff --git a/arrow-cast/src/cast/mod.rs 
b/arrow-cast/src/cast/mod.rs index f7059be170f4..440d0a8becde 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -182,10 +182,10 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Decimal128(_, _) | Decimal256(_, _), UInt8 | UInt16 | UInt32 | UInt64) | // decimal to signed numeric (Decimal128(_, _) | Decimal256(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64) => true, - // decimal to Utf8 - (Decimal128(_, _) | Decimal256(_, _), Utf8 | LargeUtf8) => true, - // Utf8 to decimal - (Utf8 | LargeUtf8, Decimal128(_, _) | Decimal256(_, _)) => true, + // decimal to string + (Decimal128(_, _) | Decimal256(_, _), Utf8View | Utf8 | LargeUtf8) => true, + // string to decimal + (Utf8View | Utf8 | LargeUtf8, Decimal128(_, _) | Decimal256(_, _)) => true, (Struct(from_fields), Struct(to_fields)) => { from_fields.len() == to_fields.len() && from_fields.iter().zip(to_fields.iter()).all(|(f1, f2)| { @@ -197,13 +197,18 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Struct(_), _) => false, (_, Struct(_)) => false, (_, Boolean) => { - DataType::is_integer(from_type) || - DataType::is_floating(from_type) + DataType::is_integer(from_type) + || DataType::is_floating(from_type) + || from_type == &Utf8View || from_type == &Utf8 || from_type == &LargeUtf8 } (Boolean, _) => { - DataType::is_integer(to_type) || DataType::is_floating(to_type) || to_type == &Utf8 || to_type == &LargeUtf8 + DataType::is_integer(to_type) + || DataType::is_floating(to_type) + || to_type == &Utf8View + || to_type == &Utf8 + || to_type == &LargeUtf8 } (Binary, LargeBinary | Utf8 | LargeUtf8 | FixedSizeBinary(_) | BinaryView | Utf8View ) => true, @@ -230,8 +235,9 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { ) => true, (Utf8 | LargeUtf8, Utf8View) => true, (BinaryView, Binary | LargeBinary | Utf8 | LargeUtf8 | Utf8View ) => true, - (Utf8 | LargeUtf8, _) => to_type.is_numeric() && to_type != 
&Float16, + (Utf8View | Utf8 | LargeUtf8, _) => to_type.is_numeric() && to_type != &Float16, (_, Utf8 | LargeUtf8) => from_type.is_primitive(), + (_, Utf8View) => from_type.is_numeric(), (_, Binary | LargeBinary) => from_type.is_integer(), @@ -824,18 +830,20 @@ pub fn cast_with_options( (Map(_, ordered1), Map(_, ordered2)) if ordered1 == ordered2 => { cast_map_values(array.as_map(), to_type, cast_options, ordered1.to_owned()) } - (Decimal128(_, s1), Decimal128(p2, s2)) => { + (Decimal128(p1, s1), Decimal128(p2, s2)) => { cast_decimal_to_decimal_same_type::( array.as_primitive(), + *p1, *s1, *p2, *s2, cast_options, ) } - (Decimal256(_, s1), Decimal256(p2, s2)) => { + (Decimal256(p1, s1), Decimal256(p2, s2)) => { cast_decimal_to_decimal_same_type::( array.as_primitive(), + *p1, *s1, *p2, *s2, @@ -917,6 +925,7 @@ pub fn cast_with_options( Float64 => cast_decimal_to_float::(array, |x| { x as f64 / 10_f64.powi(*scale as i32) }), + Utf8View => value_to_string_view(array, cast_options), Utf8 => value_to_string::(array, cast_options), LargeUtf8 => value_to_string::(array, cast_options), Null => Ok(new_null_array(to_type, array.len())), @@ -982,6 +991,7 @@ pub fn cast_with_options( Float64 => cast_decimal_to_float::(array, |x| { x.to_f64().unwrap() / 10_f64.powi(*scale as i32) }), + Utf8View => value_to_string_view(array, cast_options), Utf8 => value_to_string::(array, cast_options), LargeUtf8 => value_to_string::(array, cast_options), Null => Ok(new_null_array(to_type, array.len())), @@ -1061,7 +1071,7 @@ pub fn cast_with_options( *scale, cast_options, ), - Utf8 => cast_string_to_decimal::( + Utf8View | Utf8 => cast_string_to_decimal::( array, *precision, *scale, @@ -1150,7 +1160,7 @@ pub fn cast_with_options( *scale, cast_options, ), - Utf8 => cast_string_to_decimal::( + Utf8View | Utf8 => cast_string_to_decimal::( array, *precision, *scale, @@ -1179,12 +1189,12 @@ pub fn cast_with_options( let array = StructArray::try_new(to_fields.clone(), fields, 
array.nulls().cloned())?; Ok(Arc::new(array) as ArrayRef) } - (Struct(_), _) => Err(ArrowError::CastError( - "Cannot cast from struct to other types except struct".to_string(), - )), - (_, Struct(_)) => Err(ArrowError::CastError( - "Cannot cast to struct from other types except struct".to_string(), - )), + (Struct(_), _) => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported" + ))), + (_, Struct(_)) => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported" + ))), (_, Boolean) => match from_type { UInt8 => cast_numeric_to_bool::(array), UInt16 => cast_numeric_to_bool::(array), @@ -1197,6 +1207,7 @@ pub fn cast_with_options( Float16 => cast_numeric_to_bool::(array), Float32 => cast_numeric_to_bool::(array), Float64 => cast_numeric_to_bool::(array), + Utf8View => cast_utf8view_to_boolean(array, cast_options), Utf8 => cast_utf8_to_boolean::(array, cast_options), LargeUtf8 => cast_utf8_to_boolean::(array, cast_options), _ => Err(ArrowError::CastError(format!( @@ -1215,6 +1226,7 @@ pub fn cast_with_options( Float16 => cast_bool_to_numeric::(array, cast_options), Float32 => cast_bool_to_numeric::(array, cast_options), Float64 => cast_bool_to_numeric::(array, cast_options), + Utf8View => value_to_string_view(array, cast_options), Utf8 => value_to_string::(array, cast_options), LargeUtf8 => value_to_string::(array, cast_options), _ => Err(ArrowError::CastError(format!( @@ -1462,6 +1474,9 @@ pub fn cast_with_options( (BinaryView, _) => Err(ArrowError::CastError(format!( "Casting from {from_type:?} to {to_type:?} not supported", ))), + (from_type, Utf8View) if from_type.is_primitive() => { + value_to_string_view(array, cast_options) + } (from_type, LargeUtf8) if from_type.is_primitive() => { value_to_string::(array, cast_options) } @@ -2485,12 +2500,11 @@ where #[cfg(test)] mod tests { + use super::*; use arrow_buffer::{Buffer, IntervalDayTime, NullBuffer}; use chrono::NaiveDate; use half::f16; 
- use super::*; - macro_rules! generate_cast_test_case { ($INPUT_ARRAY: expr, $OUTPUT_TYPE_ARRAY: ident, $OUTPUT_TYPE: expr, $OUTPUT_VALUES: expr) => { let output = @@ -2682,13 +2696,16 @@ mod tests { // negative test let array = vec![Some(123456), None]; let array = create_decimal_array(array, 10, 0).unwrap(); - let result = cast(&array, &DataType::Decimal128(2, 2)); - assert!(result.is_ok()); - let array = result.unwrap(); - let array: &Decimal128Array = array.as_primitive(); - let err = array.validate_decimal_precision(2); + let result_safe = cast(&array, &DataType::Decimal128(2, 2)); + assert!(result_safe.is_ok()); + let options = CastOptions { + safe: false, + ..Default::default() + }; + + let result_unsafe = cast_with_options(&array, &DataType::Decimal128(2, 2), &options); assert_eq!("Invalid argument error: 12345600 is too large to store in a Decimal128 of precision 2. Max is 99", - err.unwrap_err().to_string()); + result_unsafe.unwrap_err().to_string()); } #[test] @@ -3637,7 +3654,7 @@ mod tests { let array = Int32Array::from(vec![5, 6, 7, 8, 9]); let b = cast( &array, - &DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + &DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))), ) .unwrap(); assert_eq!(5, b.len()); @@ -3661,7 +3678,7 @@ mod tests { let array = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]); let b = cast( &array, - &DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + &DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))), ) .unwrap(); assert_eq!(5, b.len()); @@ -3689,7 +3706,7 @@ mod tests { let array = array.slice(2, 4); let b = cast( &array, - &DataType::List(Arc::new(Field::new("item", DataType::Float64, true))), + &DataType::List(Arc::new(Field::new_list_field(DataType::Float64, true))), ) .unwrap(); assert_eq!(4, b.len()); @@ -3708,6 +3725,54 @@ mod tests { assert_eq!(10.0, c.value(3)); } + #[test] + fn test_cast_int_to_utf8view() { + let inputs = 
vec![ + Arc::new(Int8Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef, + Arc::new(Int16Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef, + Arc::new(Int32Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef, + Arc::new(Int64Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef, + Arc::new(UInt8Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef, + Arc::new(UInt16Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef, + Arc::new(UInt32Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef, + Arc::new(UInt64Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef, + ]; + let expected: ArrayRef = Arc::new(StringViewArray::from(vec![ + None, + Some("8"), + Some("9"), + Some("10"), + ])); + + for array in inputs { + assert!(can_cast_types(array.data_type(), &DataType::Utf8View)); + let arr = cast(&array, &DataType::Utf8View).unwrap(); + assert_eq!(expected.as_ref(), arr.as_ref()); + } + } + + #[test] + fn test_cast_float_to_utf8view() { + let inputs = vec![ + Arc::new(Float16Array::from(vec![ + Some(f16::from_f64(1.5)), + Some(f16::from_f64(2.5)), + None, + ])) as ArrayRef, + Arc::new(Float32Array::from(vec![Some(1.5), Some(2.5), None])) as ArrayRef, + Arc::new(Float64Array::from(vec![Some(1.5), Some(2.5), None])) as ArrayRef, + ]; + + let expected: ArrayRef = + Arc::new(StringViewArray::from(vec![Some("1.5"), Some("2.5"), None])); + + for array in inputs { + assert!(can_cast_types(array.data_type(), &DataType::Utf8View)); + let arr = cast(&array, &DataType::Utf8View).unwrap(); + assert_eq!(expected.as_ref(), arr.as_ref()); + } + } + #[test] fn test_cast_utf8_to_i32() { let array = StringArray::from(vec!["5", "6", "seven", "8", "9.1"]); @@ -3720,6 +3785,41 @@ mod tests { assert!(!c.is_valid(4)); } + #[test] + fn test_cast_utf8view_to_i32() { + let array = StringViewArray::from(vec!["5", "6", "seven", "8", "9.1"]); + let b = cast(&array, &DataType::Int32).unwrap(); + let c = 
b.as_primitive::(); + assert_eq!(5, c.value(0)); + assert_eq!(6, c.value(1)); + assert!(!c.is_valid(2)); + assert_eq!(8, c.value(3)); + assert!(!c.is_valid(4)); + } + + #[test] + fn test_cast_utf8view_to_f32() { + let array = StringViewArray::from(vec!["3", "4.56", "seven", "8.9"]); + let b = cast(&array, &DataType::Float32).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3.0, c.value(0)); + assert_eq!(4.56, c.value(1)); + assert!(!c.is_valid(2)); + assert_eq!(8.9, c.value(3)); + } + + #[test] + fn test_cast_utf8view_to_decimal128() { + let array = StringViewArray::from(vec![None, Some("4"), Some("5.6"), Some("7.89")]); + let arr = Arc::new(array) as ArrayRef; + generate_cast_test_case!( + &arr, + Decimal128Array, + &DataType::Decimal128(4, 2), + vec![None, Some(400_i128), Some(560_i128), Some(789_i128)] + ); + } + #[test] fn test_cast_with_options_utf8_to_i32() { let array = StringArray::from(vec!["5", "6", "seven", "8", "9.1"]); @@ -3751,6 +3851,14 @@ mod tests { assert_eq!(*as_boolean_array(&casted), expected); } + #[test] + fn test_cast_utf8view_to_bool() { + let strings = StringViewArray::from(vec!["true", "false", "invalid", " Y ", ""]); + let casted = cast(&strings, &DataType::Boolean).unwrap(); + let expected = BooleanArray::from(vec![Some(true), Some(false), None, Some(true), None]); + assert_eq!(*as_boolean_array(&casted), expected); + } + #[test] fn test_cast_with_options_utf8_to_bool() { let strings = StringArray::from(vec!["true", "false", "invalid", " Y ", ""]); @@ -3782,6 +3890,16 @@ mod tests { assert!(!c.is_valid(2)); } + #[test] + fn test_cast_bool_to_utf8view() { + let array = BooleanArray::from(vec![Some(true), Some(false), None]); + let b = cast(&array, &DataType::Utf8View).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!("true", c.value(0)); + assert_eq!("false", c.value(1)); + assert!(!c.is_valid(2)); + } + #[test] fn test_cast_bool_to_utf8() { let array = BooleanArray::from(vec![Some(true), Some(false), None]); @@ 
-3975,7 +4093,7 @@ mod tests { // Construct a list array from the above two // [[0,0,0], [-1, -2, -1], [2, 100000000]] - let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let list_data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); let list_data = ArrayData::builder(list_data_type) .len(3) .add_buffer(value_offsets) @@ -3986,7 +4104,7 @@ mod tests { let cast_array = cast( &list_array, - &DataType::List(Arc::new(Field::new("item", DataType::UInt16, true))), + &DataType::List(Arc::new(Field::new_list_field(DataType::UInt16, true))), ) .unwrap(); @@ -4026,7 +4144,7 @@ mod tests { let value_offsets = Buffer::from_slice_ref([0, 3, 6, 9]); // Construct a list array from the above two - let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let list_data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); let list_data = ArrayData::builder(list_data_type) .len(3) .add_buffer(value_offsets) @@ -4037,8 +4155,7 @@ mod tests { let actual = cast( &list_array, - &DataType::List(Arc::new(Field::new( - "item", + &DataType::List(Arc::new(Field::new_list_field( DataType::Timestamp(TimeUnit::Microsecond, None), true, ))), @@ -4048,11 +4165,10 @@ mod tests { let expected = cast( &cast( &list_array, - &DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))), ) .unwrap(), - &DataType::List(Arc::new(Field::new( - "item", + &DataType::List(Arc::new(Field::new_list_field( DataType::Timestamp(TimeUnit::Microsecond, None), true, ))), @@ -5146,41 +5262,43 @@ mod tests { assert_eq!("2018-12-25T00:00:00", c.value(1)); } + macro_rules! 
assert_cast_timestamp_to_string { + ($array:expr, $datatype:expr, $output_array_type: ty, $expected:expr) => {{ + let out = cast(&$array, &$datatype).unwrap(); + let actual = out + .as_any() + .downcast_ref::<$output_array_type>() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!(actual, $expected); + }}; + ($array:expr, $datatype:expr, $output_array_type: ty, $options:expr, $expected:expr) => {{ + let out = cast_with_options(&$array, &$datatype, &$options).unwrap(); + let actual = out + .as_any() + .downcast_ref::<$output_array_type>() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!(actual, $expected); + }}; + } + #[test] fn test_cast_timestamp_to_strings() { // "2018-12-25T00:00:02.001", "1997-05-19T00:00:03.005", None let array = TimestampMillisecondArray::from(vec![Some(864000003005), Some(1545696002001), None]); - let out = cast(&array, &DataType::Utf8).unwrap(); - let out = out - .as_any() - .downcast_ref::() - .unwrap() - .into_iter() - .collect::>(); - assert_eq!( - out, - vec![ - Some("1997-05-19T00:00:03.005"), - Some("2018-12-25T00:00:02.001"), - None - ] - ); - let out = cast(&array, &DataType::LargeUtf8).unwrap(); - let out = out - .as_any() - .downcast_ref::() - .unwrap() - .into_iter() - .collect::>(); - assert_eq!( - out, - vec![ - Some("1997-05-19T00:00:03.005"), - Some("2018-12-25T00:00:02.001"), - None - ] - ); + let expected = vec![ + Some("1997-05-19T00:00:03.005"), + Some("2018-12-25T00:00:02.001"), + None, + ]; + + assert_cast_timestamp_to_string!(array, DataType::Utf8View, StringViewArray, expected); + assert_cast_timestamp_to_string!(array, DataType::Utf8, StringArray, expected); + assert_cast_timestamp_to_string!(array, DataType::LargeUtf8, LargeStringArray, expected); } #[test] @@ -5193,73 +5311,65 @@ mod tests { .with_timestamp_format(Some(ts_format)) .with_timestamp_tz_format(Some(ts_format)), }; + // "2018-12-25T00:00:02.001", "1997-05-19T00:00:03.005", None let array_without_tz = 
TimestampMillisecondArray::from(vec![Some(864000003005), Some(1545696002001), None]); - let out = cast_with_options(&array_without_tz, &DataType::Utf8, &cast_options).unwrap(); - let out = out - .as_any() - .downcast_ref::() - .unwrap() - .into_iter() - .collect::>(); - assert_eq!( - out, - vec![ - Some("1997-05-19 00:00:03.005000"), - Some("2018-12-25 00:00:02.001000"), - None - ] + let expected = vec![ + Some("1997-05-19 00:00:03.005000"), + Some("2018-12-25 00:00:02.001000"), + None, + ]; + assert_cast_timestamp_to_string!( + array_without_tz, + DataType::Utf8View, + StringViewArray, + cast_options, + expected ); - let out = - cast_with_options(&array_without_tz, &DataType::LargeUtf8, &cast_options).unwrap(); - let out = out - .as_any() - .downcast_ref::() - .unwrap() - .into_iter() - .collect::>(); - assert_eq!( - out, - vec![ - Some("1997-05-19 00:00:03.005000"), - Some("2018-12-25 00:00:02.001000"), - None - ] + assert_cast_timestamp_to_string!( + array_without_tz, + DataType::Utf8, + StringArray, + cast_options, + expected + ); + assert_cast_timestamp_to_string!( + array_without_tz, + DataType::LargeUtf8, + LargeStringArray, + cast_options, + expected ); let array_with_tz = TimestampMillisecondArray::from(vec![Some(864000003005), Some(1545696002001), None]) .with_timezone(tz.to_string()); - let out = cast_with_options(&array_with_tz, &DataType::Utf8, &cast_options).unwrap(); - let out = out - .as_any() - .downcast_ref::() - .unwrap() - .into_iter() - .collect::>(); - assert_eq!( - out, - vec![ - Some("1997-05-19 05:45:03.005000"), - Some("2018-12-25 05:45:02.001000"), - None - ] + let expected = vec![ + Some("1997-05-19 05:45:03.005000"), + Some("2018-12-25 05:45:02.001000"), + None, + ]; + assert_cast_timestamp_to_string!( + array_with_tz, + DataType::Utf8View, + StringViewArray, + cast_options, + expected ); - let out = cast_with_options(&array_with_tz, &DataType::LargeUtf8, &cast_options).unwrap(); - let out = out - .as_any() - .downcast_ref::() - 
.unwrap() - .into_iter() - .collect::>(); - assert_eq!( - out, - vec![ - Some("1997-05-19 05:45:03.005000"), - Some("2018-12-25 05:45:02.001000"), - None - ] + assert_cast_timestamp_to_string!( + array_with_tz, + DataType::Utf8, + StringArray, + cast_options, + expected + ); + assert_cast_timestamp_to_string!( + array_with_tz, + DataType::LargeUtf8, + LargeStringArray, + cast_options, + expected ); } @@ -7085,12 +7195,12 @@ mod tests { cast_from_null_to_other(&data_type); // Cast null from and to list - let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); cast_from_null_to_other(&data_type); - let data_type = DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, true))); + let data_type = DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, true))); cast_from_null_to_other(&data_type); let data_type = - DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 4); + DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int32, true)), 4); cast_from_null_to_other(&data_type); // Cast null from and to dictionary @@ -7207,11 +7317,11 @@ mod tests { assert_eq!(actual.data_type(), to_array.data_type()); let invalid_target = - DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Binary, true)), 2); + DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Binary, true)), 2); assert!(!can_cast_types(from_array.data_type(), &invalid_target)); let invalid_size = - DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float16, true)), 5); + DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Float16, true)), 5); assert!(!can_cast_types(from_array.data_type(), &invalid_size)); } @@ -7364,7 +7474,7 @@ mod tests { [(Some([Some(5)]))], 1, )) as ArrayRef; - let to_field_inner = Arc::new(Field::new("item", DataType::Float32, false)); + let to_field_inner = 
Arc::new(Field::new_list_field(DataType::Float32, false)); let to_field = Arc::new(Field::new( "dummy", DataType::FixedSizeList(to_field_inner.clone(), 1), @@ -7454,7 +7564,7 @@ mod tests { // 4. Nulls that are correctly sized (same as target list size) // Non-null case - let field = Arc::new(Field::new("item", DataType::Int32, true)); + let field = Arc::new(Field::new_list_field(DataType::Int32, true)); let values = vec![ Some(vec![Some(1), Some(2), Some(3)]), Some(vec![Some(4), Some(5), Some(6)]), @@ -7530,7 +7640,7 @@ mod tests { let res = cast_with_options( array.as_ref(), - &DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 3), + &DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int32, true)), 3), &CastOptions { safe: false, ..Default::default() @@ -7544,7 +7654,7 @@ mod tests { // too short and truncate lists that are too long. let res = cast( array.as_ref(), - &DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 3), + &DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int32, true)), 3), ) .unwrap(); let expected = Arc::new(FixedSizeListArray::from_iter_primitive::( @@ -7566,7 +7676,7 @@ mod tests { ])) as ArrayRef; let res = cast_with_options( array.as_ref(), - &DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 3), + &DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int32, true)), 3), &CastOptions { safe: false, ..Default::default() @@ -7591,7 +7701,7 @@ mod tests { )) as ArrayRef; let actual = cast( array.as_ref(), - &DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 2), + &DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int32, true)), 2), ) .unwrap(); assert_eq!(expected.as_ref(), actual.as_ref()); @@ -7614,14 +7724,14 @@ mod tests { )) as ArrayRef; let actual = cast( array.as_ref(), - &DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int64, true)), 2), + 
&DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int64, true)), 2), ) .unwrap(); assert_eq!(expected.as_ref(), actual.as_ref()); let res = cast_with_options( array.as_ref(), - &DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int16, true)), 2), + &DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int16, true)), 2), &CastOptions { safe: false, ..Default::default() @@ -7633,7 +7743,7 @@ mod tests { #[test] fn test_cast_list_to_fsl_empty() { - let field = Arc::new(Field::new("item", DataType::Int32, true)); + let field = Arc::new(Field::new_list_field(DataType::Int32, true)); let array = new_empty_array(&DataType::List(field.clone())); let target_type = DataType::FixedSizeList(field.clone(), 3); @@ -7656,7 +7766,7 @@ mod tests { let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]); // Construct a list array from the above two - let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let list_data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); let list_data = ArrayData::builder(list_data_type) .len(3) .add_buffer(value_offsets) @@ -7680,7 +7790,7 @@ mod tests { // Construct a list array from the above two let list_data_type = - DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, true))); + DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, true))); let list_data = ArrayData::builder(list_data_type) .len(3) .add_buffer(value_offsets) @@ -7699,7 +7809,7 @@ mod tests { .unwrap(); let list_data_type = - DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 4); + DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int32, true)), 4); let list_data = ArrayData::builder(list_data_type) .len(2) .add_child_data(value_data) @@ -7717,7 +7827,7 @@ mod tests { .unwrap(); let list_data_type = - DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int64, true)), 4); + 
DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int64, true)), 4); let list_data = ArrayData::builder(list_data_type) .len(2) .add_child_data(value_data) @@ -7979,7 +8089,7 @@ mod tests { let array1 = make_list_array().slice(1, 2); let array2 = Arc::new(make_list_array()) as ArrayRef; - let dt = DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, true))); + let dt = DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, true))); let out1 = cast(&array1, &dt).unwrap(); let out2 = cast(&array2, &dt).unwrap(); @@ -7992,7 +8102,7 @@ mod tests { let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]); let value_data = str_array.into_data(); - let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))); + let list_data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Utf8, true))); let list_data = ArrayData::builder(list_data_type) .len(3) .add_buffer(value_offsets) @@ -8354,7 +8464,7 @@ mod tests { let input_type = DataType::Decimal128(10, 3); let output_type = DataType::Decimal256(10, 5); assert!(can_cast_types(&input_type, &output_type)); - let array = vec![Some(i128::MAX), Some(i128::MIN)]; + let array = vec![Some(123456), Some(-123456)]; let input_decimal_array = create_decimal_array(array, 10, 3).unwrap(); let array = Arc::new(input_decimal_array) as ArrayRef; @@ -8364,8 +8474,8 @@ mod tests { Decimal256Array, &output_type, vec![ - Some(i256::from_i128(i128::MAX).mul_wrapping(hundred)), - Some(i256::from_i128(i128::MIN).mul_wrapping(hundred)) + Some(i256::from_i128(123456).mul_wrapping(hundred)), + Some(i256::from_i128(-123456).mul_wrapping(hundred)) ] ); } @@ -9114,7 +9224,31 @@ mod tests { } #[test] - fn test_cast_decimal_to_utf8() { + fn test_cast_decimal_to_string() { + assert!(can_cast_types( + &DataType::Decimal128(10, 4), + &DataType::Utf8View + )); + assert!(can_cast_types( + &DataType::Decimal256(38, 10), + &DataType::Utf8View + )); + + macro_rules! 
assert_decimal_values { + ($array:expr) => { + let c = $array; + assert_eq!("1123.454", c.value(0)); + assert_eq!("2123.456", c.value(1)); + assert_eq!("-3123.453", c.value(2)); + assert_eq!("-3123.456", c.value(3)); + assert_eq!("0.000", c.value(4)); + assert_eq!("0.123", c.value(5)); + assert_eq!("1234.567", c.value(6)); + assert_eq!("-1234.567", c.value(7)); + assert!(c.is_null(8)); + }; + } + fn test_decimal_to_string( output_type: DataType, array: PrimitiveArray, @@ -9122,18 +9256,19 @@ mod tests { let b = cast(&array, &output_type).unwrap(); assert_eq!(b.data_type(), &output_type); - let c = b.as_string::(); - - assert_eq!("1123.454", c.value(0)); - assert_eq!("2123.456", c.value(1)); - assert_eq!("-3123.453", c.value(2)); - assert_eq!("-3123.456", c.value(3)); - assert_eq!("0.000", c.value(4)); - assert_eq!("0.123", c.value(5)); - assert_eq!("1234.567", c.value(6)); - assert_eq!("-1234.567", c.value(7)); - assert!(c.is_null(8)); + match b.data_type() { + DataType::Utf8View => { + let c = b.as_string_view(); + assert_decimal_values!(c); + } + DataType::Utf8 | DataType::LargeUtf8 => { + let c = b.as_string::(); + assert_decimal_values!(c); + } + _ => (), + } } + let array128: Vec> = vec![ Some(1123454), Some(2123456), @@ -9145,22 +9280,33 @@ mod tests { Some(-123456789), None, ]; + let array256: Vec> = array128 + .iter() + .map(|num| num.map(i256::from_i128)) + .collect(); - let array256: Vec> = array128.iter().map(|v| v.map(i256::from_i128)).collect(); - - test_decimal_to_string::( + test_decimal_to_string::( + DataType::Utf8View, + create_decimal_array(array128.clone(), 7, 3).unwrap(), + ); + test_decimal_to_string::( DataType::Utf8, create_decimal_array(array128.clone(), 7, 3).unwrap(), ); - test_decimal_to_string::( + test_decimal_to_string::( DataType::LargeUtf8, create_decimal_array(array128, 7, 3).unwrap(), ); - test_decimal_to_string::( + + test_decimal_to_string::( + DataType::Utf8View, + create_decimal256_array(array256.clone(), 7, 3).unwrap(), + ); 
+ test_decimal_to_string::( DataType::Utf8, create_decimal256_array(array256.clone(), 7, 3).unwrap(), ); - test_decimal_to_string::( + test_decimal_to_string::( DataType::LargeUtf8, create_decimal256_array(array256, 7, 3).unwrap(), ); @@ -9793,4 +9939,98 @@ mod tests { "Cast non-nullable to non-nullable struct field returning null should fail", ); } + + #[test] + fn test_cast_struct_to_non_struct() { + let boolean = Arc::new(BooleanArray::from(vec![true, false])); + let struct_array = StructArray::from(vec![( + Arc::new(Field::new("a", DataType::Boolean, false)), + boolean.clone() as ArrayRef, + )]); + let to_type = DataType::Utf8; + let result = cast(&struct_array, &to_type); + assert_eq!( + r#"Cast error: Casting from Struct([Field { name: "a", data_type: Boolean, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }]) to Utf8 not supported"#, + result.unwrap_err().to_string() + ); + } + + #[test] + fn test_cast_non_struct_to_struct() { + let array = StringArray::from(vec!["a", "b"]); + let to_type = DataType::Struct(vec![Field::new("a", DataType::Boolean, false)].into()); + let result = cast(&array, &to_type); + assert_eq!( + r#"Cast error: Casting from Utf8 to Struct([Field { name: "a", data_type: Boolean, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }]) not supported"#, + result.unwrap_err().to_string() + ); + } + + #[test] + fn test_decimal_to_decimal_throw_error_on_precision_overflow_same_scale() { + let array = vec![Some(123456789)]; + let array = create_decimal_array(array, 24, 2).unwrap(); + let input_type = DataType::Decimal128(24, 2); + let output_type = DataType::Decimal128(6, 2); + assert!(can_cast_types(&input_type, &output_type)); + + let options = CastOptions { + safe: false, + ..Default::default() + }; + let result = cast_with_options(&array, &output_type, &options); + assert_eq!(result.unwrap_err().to_string(), + "Invalid argument error: 123456790 is too large to store in a Decimal128 of precision 6. 
Max is 999999"); + } + + #[test] + fn test_decimal_to_decimal_throw_error_on_precision_overflow_lower_scale() { + let array = vec![Some(123456789)]; + let array = create_decimal_array(array, 24, 4).unwrap(); + let input_type = DataType::Decimal128(24, 4); + let output_type = DataType::Decimal128(6, 2); + assert!(can_cast_types(&input_type, &output_type)); + + let options = CastOptions { + safe: false, + ..Default::default() + }; + let result = cast_with_options(&array, &output_type, &options); + assert_eq!(result.unwrap_err().to_string(), + "Invalid argument error: 1234568 is too large to store in a Decimal128 of precision 6. Max is 999999"); + } + + #[test] + fn test_decimal_to_decimal_throw_error_on_precision_overflow_greater_scale() { + let array = vec![Some(123456789)]; + let array = create_decimal_array(array, 24, 2).unwrap(); + let input_type = DataType::Decimal128(24, 2); + let output_type = DataType::Decimal128(6, 3); + assert!(can_cast_types(&input_type, &output_type)); + + let options = CastOptions { + safe: false, + ..Default::default() + }; + let result = cast_with_options(&array, &output_type, &options); + assert_eq!(result.unwrap_err().to_string(), + "Invalid argument error: 1234567890 is too large to store in a Decimal128 of precision 6. Max is 999999"); + } + + #[test] + fn test_decimal_to_decimal_throw_error_on_precision_overflow_diff_type() { + let array = vec![Some(123456789)]; + let array = create_decimal_array(array, 24, 2).unwrap(); + let input_type = DataType::Decimal128(24, 2); + let output_type = DataType::Decimal256(6, 2); + assert!(can_cast_types(&input_type, &output_type)); + + let options = CastOptions { + safe: false, + ..Default::default() + }; + let result = cast_with_options(&array, &output_type, &options); + assert_eq!(result.unwrap_err().to_string(), + "Invalid argument error: 123456789 is too large to store in a Decimal256 of precision 6. 
Max is 999999"); + } } diff --git a/arrow-cast/src/cast/string.rs b/arrow-cast/src/cast/string.rs index 7d0e7e21c859..7f22c4fd64de 100644 --- a/arrow-cast/src/cast/string.rs +++ b/arrow-cast/src/cast/string.rs @@ -38,6 +38,30 @@ pub(crate) fn value_to_string( Ok(Arc::new(builder.finish())) } +pub(crate) fn value_to_string_view( + array: &dyn Array, + options: &CastOptions, +) -> Result { + let mut builder = StringViewBuilder::with_capacity(array.len()); + let formatter = ArrayFormatter::try_new(array, &options.format_options)?; + let nulls = array.nulls(); + // buffer to avoid reallocating on each value + // TODO: replace with write to builder after https://github.com/apache/arrow-rs/issues/6373 + let mut buffer = String::new(); + for i in 0..array.len() { + match nulls.map(|x| x.is_null(i)).unwrap_or_default() { + true => builder.append_null(), + false => { + // write to buffer first and then copy into target array + buffer.clear(); + formatter.value(i).write(&mut buffer)?; + builder.append_value(&buffer) + } + } + } + Ok(Arc::new(builder.finish())) +} + /// Parse UTF-8 pub(crate) fn parse_string( array: &dyn Array, @@ -344,19 +368,14 @@ pub(crate) fn cast_binary_to_string( } } -/// Casts Utf8 to Boolean -pub(crate) fn cast_utf8_to_boolean( - from: &dyn Array, +/// Casts string to boolean +fn cast_string_to_boolean<'a, StrArray>( + array: &StrArray, cast_options: &CastOptions, ) -> Result where - OffsetSize: OffsetSizeTrait, + StrArray: StringArrayType<'a>, { - let array = from - .as_any() - .downcast_ref::>() - .unwrap(); - let output_array = array .iter() .map(|value| match value { @@ -378,3 +397,27 @@ where Ok(Arc::new(output_array)) } + +pub(crate) fn cast_utf8_to_boolean( + from: &dyn Array, + cast_options: &CastOptions, +) -> Result +where + OffsetSize: OffsetSizeTrait, +{ + let array = from + .as_any() + .downcast_ref::>() + .unwrap(); + + cast_string_to_boolean(&array, cast_options) +} + +pub(crate) fn cast_utf8view_to_boolean( + from: &dyn Array, + 
cast_options: &CastOptions, +) -> Result { + let array = from.as_any().downcast_ref::().unwrap(); + + cast_string_to_boolean(&array, cast_options) +} diff --git a/arrow-cast/src/parse.rs b/arrow-cast/src/parse.rs index 4bd94c13fe8d..4e93e9787cc8 100644 --- a/arrow-cast/src/parse.rs +++ b/arrow-cast/src/parse.rs @@ -497,6 +497,10 @@ parser_primitive!(Int64Type); parser_primitive!(Int32Type); parser_primitive!(Int16Type); parser_primitive!(Int8Type); +parser_primitive!(DurationNanosecondType); +parser_primitive!(DurationMicrosecondType); +parser_primitive!(DurationMillisecondType); +parser_primitive!(DurationSecondType); impl Parser for TimestampNanosecondType { fn parse(string: &str) -> Option { @@ -877,7 +881,7 @@ pub fn parse_decimal( for (_, b) in bs.by_ref() { if !b.is_ascii_digit() { if *b == b'e' || *b == b'E' { - result = match parse_e_notation::( + result = parse_e_notation::( s, digits as u16, fractionals as i16, @@ -885,10 +889,7 @@ pub fn parse_decimal( point_index, precision as u16, scale as i16, - ) { - Err(e) => return Err(e), - Ok(v) => v, - }; + )?; is_e_notation = true; @@ -922,7 +923,7 @@ pub fn parse_decimal( } } b'e' | b'E' => { - result = match parse_e_notation::( + result = parse_e_notation::( s, digits as u16, fractionals as i16, @@ -930,10 +931,7 @@ pub fn parse_decimal( index, precision as u16, scale as i16, - ) { - Err(e) => return Err(e), - Ok(v) => v, - }; + )?; is_e_notation = true; diff --git a/arrow-cast/src/pretty.rs b/arrow-cast/src/pretty.rs index 4a3cbda283a5..ad3b952c327d 100644 --- a/arrow-cast/src/pretty.rs +++ b/arrow-cast/src/pretty.rs @@ -296,7 +296,7 @@ mod tests { fn test_pretty_format_fixed_size_list() { // define a schema. 
let field_type = - DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 3); + DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int32, true)), 3); let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type, true)])); let keys_builder = Int32Array::builder(3); diff --git a/arrow-csv/Cargo.toml b/arrow-csv/Cargo.toml index be213c9363c2..8823924eb55b 100644 --- a/arrow-csv/Cargo.toml +++ b/arrow-csv/Cargo.toml @@ -35,18 +35,16 @@ bench = false [dependencies] arrow-array = { workspace = true } -arrow-buffer = { workspace = true } arrow-cast = { workspace = true } -arrow-data = { workspace = true } arrow-schema = { workspace = true } chrono = { workspace = true } csv = { version = "1.1", default-features = false } csv-core = { version = "0.1" } lazy_static = { version = "1.4", default-features = false } -lexical-core = { version = "1.0", default-features = false } regex = { version = "1.7.0", default-features = false, features = ["std", "unicode", "perf"] } [dev-dependencies] +arrow-buffer = { workspace = true } tempfile = "3.3" futures = "0.3" tokio = { version = "1.27", default-features = false, features = ["io-util"] } diff --git a/arrow-csv/LICENSE.txt b/arrow-csv/LICENSE.txt new file mode 120000 index 000000000000..4ab43736a839 --- /dev/null +++ b/arrow-csv/LICENSE.txt @@ -0,0 +1 @@ +../LICENSE.txt \ No newline at end of file diff --git a/arrow-csv/NOTICE.txt b/arrow-csv/NOTICE.txt new file mode 120000 index 000000000000..eb9f24e040b5 --- /dev/null +++ b/arrow-csv/NOTICE.txt @@ -0,0 +1 @@ +../NOTICE.txt \ No newline at end of file diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs index c91b436f6cce..d3d518316397 100644 --- a/arrow-csv/src/reader/mod.rs +++ b/arrow-csv/src/reader/mod.rs @@ -136,7 +136,7 @@ use lazy_static::lazy_static; use regex::{Regex, RegexSet}; use std::fmt::{self, Debug}; use std::fs::File; -use std::io::{BufRead, BufReader as StdBufReader, Read, Seek, SeekFrom}; +use 
std::io::{BufRead, BufReader as StdBufReader, Read}; use std::sync::Arc; use crate::map_csv_error; @@ -241,7 +241,7 @@ pub struct Format { } impl Format { - /// Specify whether the CSV file has a header, defaults to `true` + /// Specify whether the CSV file has a header, defaults to `false` /// /// When `true`, the first row of the CSV file is treated as a header row pub fn with_header(mut self, has_header: bool) -> Self { @@ -399,51 +399,6 @@ impl Format { } } -/// Infer the schema of a CSV file by reading through the first n records of the file, -/// with `max_read_records` controlling the maximum number of records to read. -/// -/// If `max_read_records` is not set, the whole file is read to infer its schema. -/// -/// Return inferred schema and number of records used for inference. This function does not change -/// reader cursor offset. -/// -/// The inferred schema will always have each field set as nullable. -#[deprecated(note = "Use Format::infer_schema")] -#[allow(deprecated)] -pub fn infer_file_schema( - mut reader: R, - delimiter: u8, - max_read_records: Option, - has_header: bool, -) -> Result<(Schema, usize), ArrowError> { - let saved_offset = reader.stream_position()?; - let r = infer_reader_schema(&mut reader, delimiter, max_read_records, has_header)?; - // return the reader seek back to the start - reader.seek(SeekFrom::Start(saved_offset))?; - Ok(r) -} - -/// Infer schema of CSV records provided by struct that implements `Read` trait. -/// -/// `max_read_records` controlling the maximum number of records to read. If `max_read_records` is -/// not set, all records are read to infer the schema. -/// -/// Return inferred schema and number of records used for inference. 
-#[deprecated(note = "Use Format::infer_schema")] -pub fn infer_reader_schema( - reader: R, - delimiter: u8, - max_read_records: Option, - has_header: bool, -) -> Result<(Schema, usize), ArrowError> { - let format = Format { - delimiter: Some(delimiter), - header: has_header, - ..Default::default() - }; - format.infer_schema(reader, max_read_records) -} - /// Infer schema from a list of CSV files by reading through first n records /// with `max_read_records` controlling the maximum number of records to read. /// @@ -824,42 +779,66 @@ fn parse( match key_type.as_ref() { DataType::Int8 => Ok(Arc::new( rows.iter() - .map(|row| row.get(i)) + .map(|row| { + let s = row.get(i); + (!null_regex.is_null(s)).then_some(s) + }) .collect::>(), ) as ArrayRef), DataType::Int16 => Ok(Arc::new( rows.iter() - .map(|row| row.get(i)) + .map(|row| { + let s = row.get(i); + (!null_regex.is_null(s)).then_some(s) + }) .collect::>(), ) as ArrayRef), DataType::Int32 => Ok(Arc::new( rows.iter() - .map(|row| row.get(i)) + .map(|row| { + let s = row.get(i); + (!null_regex.is_null(s)).then_some(s) + }) .collect::>(), ) as ArrayRef), DataType::Int64 => Ok(Arc::new( rows.iter() - .map(|row| row.get(i)) + .map(|row| { + let s = row.get(i); + (!null_regex.is_null(s)).then_some(s) + }) .collect::>(), ) as ArrayRef), DataType::UInt8 => Ok(Arc::new( rows.iter() - .map(|row| row.get(i)) + .map(|row| { + let s = row.get(i); + (!null_regex.is_null(s)).then_some(s) + }) .collect::>(), ) as ArrayRef), DataType::UInt16 => Ok(Arc::new( rows.iter() - .map(|row| row.get(i)) + .map(|row| { + let s = row.get(i); + (!null_regex.is_null(s)).then_some(s) + }) .collect::>(), ) as ArrayRef), DataType::UInt32 => Ok(Arc::new( rows.iter() - .map(|row| row.get(i)) + .map(|row| { + let s = row.get(i); + (!null_regex.is_null(s)).then_some(s) + }) .collect::>(), ) as ArrayRef), DataType::UInt64 => Ok(Arc::new( rows.iter() - .map(|row| row.get(i)) + .map(|row| { + let s = row.get(i); + (!null_regex.is_null(s)).then_some(s) + 
}) .collect::>(), ) as ArrayRef), _ => Err(ArrowError::ParseError(format!( @@ -1101,14 +1080,6 @@ impl ReaderBuilder { } } - /// Set whether the CSV file has headers - #[deprecated(note = "Use with_header")] - #[doc(hidden)] - pub fn has_header(mut self, has_header: bool) -> Self { - self.format.header = has_header; - self - } - /// Set whether the CSV file has a header pub fn with_header(mut self, has_header: bool) -> Self { self.format.header = has_header; @@ -1236,7 +1207,7 @@ impl ReaderBuilder { mod tests { use super::*; - use std::io::{Cursor, Write}; + use std::io::{Cursor, Seek, SeekFrom, Write}; use tempfile::NamedTempFile; use arrow_array::cast::AsArray; @@ -1528,6 +1499,40 @@ mod tests { assert_eq!(strings.value(29), "Uckfield, East Sussex, UK"); } + #[test] + fn test_csv_with_nullable_dictionary() { + let offset_type = vec![ + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + ]; + for data_type in offset_type { + let file = File::open("test/data/dictionary_nullable_test.csv").unwrap(); + let dictionary_type = + DataType::Dictionary(Box::new(data_type), Box::new(DataType::Utf8)); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("name", dictionary_type.clone(), true), + ])); + + let mut csv = ReaderBuilder::new(schema) + .build(file.try_clone().unwrap()) + .unwrap(); + + let batch = csv.next().unwrap().unwrap(); + assert_eq!(3, batch.num_rows()); + assert_eq!(2, batch.num_columns()); + + let names = arrow_cast::cast(batch.column(1), &dictionary_type).unwrap(); + assert!(!names.is_null(2)); + assert!(names.is_null(1)); + } + } #[test] fn test_nulls() { let schema = Arc::new(Schema::new(vec![ diff --git a/arrow-csv/src/writer.rs b/arrow-csv/src/writer.rs index eae2133a4623..c5a0a0b76d59 100644 --- a/arrow-csv/src/writer.rs +++ b/arrow-csv/src/writer.rs @@ -256,14 +256,6 @@ impl WriterBuilder { 
Self::default() } - /// Set whether to write headers - #[deprecated(note = "Use Self::with_header")] - #[doc(hidden)] - pub fn has_headers(mut self, has_headers: bool) -> Self { - self.has_header = has_headers; - self - } - /// Set whether to write the CSV file with a header pub fn with_header(mut self, header: bool) -> Self { self.has_header = header; @@ -397,17 +389,6 @@ impl WriterBuilder { self.null_value.as_deref().unwrap_or(DEFAULT_NULL_VALUE) } - /// Use RFC3339 format for date/time/timestamps (default) - #[deprecated(note = "Use WriterBuilder::default()")] - pub fn with_rfc3339(mut self) -> Self { - self.date_format = None; - self.datetime_format = None; - self.time_format = None; - self.timestamp_format = None; - self.timestamp_tz_format = None; - self - } - /// Create a new `Writer` pub fn build(self, writer: W) -> Writer { let mut builder = csv::WriterBuilder::new(); diff --git a/arrow-csv/test/data/dictionary_nullable_test.csv b/arrow-csv/test/data/dictionary_nullable_test.csv new file mode 100644 index 000000000000..c9ada5293b70 --- /dev/null +++ b/arrow-csv/test/data/dictionary_nullable_test.csv @@ -0,0 +1,3 @@ +id,name +1, +2,bob diff --git a/arrow-data/LICENSE.txt b/arrow-data/LICENSE.txt new file mode 120000 index 000000000000..4ab43736a839 --- /dev/null +++ b/arrow-data/LICENSE.txt @@ -0,0 +1 @@ +../LICENSE.txt \ No newline at end of file diff --git a/arrow-data/NOTICE.txt b/arrow-data/NOTICE.txt new file mode 120000 index 000000000000..eb9f24e040b5 --- /dev/null +++ b/arrow-data/NOTICE.txt @@ -0,0 +1 @@ +../NOTICE.txt \ No newline at end of file diff --git a/arrow-data/src/data.rs b/arrow-data/src/data.rs index 8af2a91cf159..a35b5e8629e9 100644 --- a/arrow-data/src/data.rs +++ b/arrow-data/src/data.rs @@ -30,11 +30,6 @@ use std::sync::Arc; use crate::{equal, validate_binary_view, validate_string_view}; -/// A collection of [`Buffer`] -#[doc(hidden)] -#[deprecated(note = "Use [Buffer]")] -pub type Buffers<'a> = &'a [Buffer]; - #[inline] pub(crate) 
fn contains_nulls( null_bit_buffer: Option<&NullBuffer>, diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml index 702cb1360c2d..fbb295036a9b 100644 --- a/arrow-flight/Cargo.toml +++ b/arrow-flight/Cargo.toml @@ -43,11 +43,11 @@ base64 = { version = "0.22", default-features = false, features = ["std"] } bytes = { version = "1", default-features = false } futures = { version = "0.3", default-features = false, features = ["alloc"] } once_cell = { version = "1", optional = true } -paste = { version = "1.0" } +paste = { version = "1.0" , optional = true } prost = { version = "0.13.1", default-features = false, features = ["prost-derive"] } # For Timestamp type prost-types = { version = "0.13.1", default-features = false } -tokio = { version = "1.0", default-features = false, features = ["macros", "rt", "rt-multi-thread"] } +tokio = { version = "1.0", default-features = false, features = ["macros", "rt", "rt-multi-thread"], optional = true } tonic = { version = "0.12.3", default-features = false, features = ["transport", "codegen", "prost"] } # CLI-related dependencies @@ -61,11 +61,10 @@ all-features = true [features] default = [] -flight-sql-experimental = ["arrow-arith", "arrow-data", "arrow-ord", "arrow-row", "arrow-select", "arrow-string", "once_cell"] +flight-sql-experimental = ["dep:arrow-arith", "dep:arrow-data", "dep:arrow-ord", "dep:arrow-row", "dep:arrow-select", "dep:arrow-string", "dep:once_cell", "dep:paste"] tls = ["tonic/tls"] - # Enable CLI tools -cli = ["anyhow", "arrow-array/chrono-tz", "arrow-cast/prettyprint", "clap", "tracing-log", "tracing-subscriber", "tonic/tls-webpki-roots"] +cli = ["arrow-array/chrono-tz", "arrow-cast/prettyprint", "tonic/tls-webpki-roots", "dep:anyhow", "dep:clap", "dep:tracing-log", "dep:tracing-subscriber"] [dev-dependencies] arrow-cast = { workspace = true, features = ["prettyprint"] } @@ -75,6 +74,9 @@ http-body = "1.0.0" hyper-util = "0.1" pin-project-lite = "0.2" tempfile = "3.3" +tracing-log = { version = 
"0.2" } +tracing-subscriber = { version = "0.3.1", default-features = false, features = ["ansi", "env-filter", "fmt"] } +tokio = { version = "1.0", default-features = false, features = ["macros", "rt", "rt-multi-thread"] } tokio-stream = { version = "0.1", features = ["net"] } tower = { version = "0.5.0", features = ["util"] } uuid = { version = "1.10.0", features = ["v4"] } diff --git a/arrow-flight/LICENSE.txt b/arrow-flight/LICENSE.txt new file mode 120000 index 000000000000..4ab43736a839 --- /dev/null +++ b/arrow-flight/LICENSE.txt @@ -0,0 +1 @@ +../LICENSE.txt \ No newline at end of file diff --git a/arrow-flight/NOTICE.txt b/arrow-flight/NOTICE.txt new file mode 120000 index 000000000000..eb9f24e040b5 --- /dev/null +++ b/arrow-flight/NOTICE.txt @@ -0,0 +1 @@ +../NOTICE.txt \ No newline at end of file diff --git a/arrow-flight/README.md b/arrow-flight/README.md index df74bc012a1c..3ffc8780c2f8 100644 --- a/arrow-flight/README.md +++ b/arrow-flight/README.md @@ -31,14 +31,14 @@ Add this to your Cargo.toml: ```toml [dependencies] -arrow-flight = "53.2.0" +arrow-flight = "54.0.0" ``` Apache Arrow Flight is a gRPC based protocol for exchanging Arrow data between processes. See the blog post [Introducing Apache Arrow Flight: A Framework for Fast Data Transport](https://arrow.apache.org/blog/2019/10/13/introducing-arrow-flight/) for more information. This crate provides a Rust implementation of the [Flight.proto](../format/Flight.proto) gRPC protocol and -[examples](https://github.com/apache/arrow-rs/tree/master/arrow-flight/examples) +[examples](https://github.com/apache/arrow-rs/tree/main/arrow-flight/examples) that demonstrate how to build a Flight server implemented with [tonic](https://docs.rs/crate/tonic/latest). 
## Feature Flags diff --git a/arrow-flight/gen/Cargo.toml b/arrow-flight/gen/Cargo.toml index c7fe89beb93a..6358227a8912 100644 --- a/arrow-flight/gen/Cargo.toml +++ b/arrow-flight/gen/Cargo.toml @@ -32,6 +32,5 @@ publish = false [dependencies] # Pin specific version of the tonic-build dependencies to avoid auto-generated # (and checked in) arrow.flight.protocol.rs from changing -proc-macro2 = { version = "=1.0.89", default-features = false } -prost-build = { version = "=0.13.3", default-features = false } +prost-build = { version = "=0.13.4", default-features = false } tonic-build = { version = "=0.12.3", default-features = false, features = ["transport", "prost"] } diff --git a/arrow-flight/src/decode.rs b/arrow-flight/src/decode.rs index 7bafc384306b..760fc926fca6 100644 --- a/arrow-flight/src/decode.rs +++ b/arrow-flight/src/decode.rs @@ -295,7 +295,7 @@ impl FlightDataDecoder { )); }; - let buffer = Buffer::from_bytes(data.data_body.into()); + let buffer = Buffer::from(data.data_body); let dictionary_batch = message.header_as_dictionary_batch().ok_or_else(|| { FlightError::protocol( "Could not get dictionary batch from DictionaryBatch message", diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index ec4fe323b267..19fe42474405 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -535,8 +535,10 @@ fn prepare_field_for_flight( ) .with_metadata(field.metadata().clone()) } else { + #[allow(deprecated)] let dict_id = dictionary_tracker.set_dict_id(field.as_ref()); + #[allow(deprecated)] Field::new_dict( field.name(), field.data_type().clone(), @@ -583,7 +585,9 @@ fn prepare_schema_for_flight( ) .with_metadata(field.metadata().clone()) } else { + #[allow(deprecated)] let dict_id = dictionary_tracker.set_dict_id(field.as_ref()); + #[allow(deprecated)] Field::new_dict( field.name(), field.data_type().clone(), @@ -650,10 +654,12 @@ struct FlightIpcEncoder { impl FlightIpcEncoder { fn new(options: IpcWriteOptions, 
error_on_replacement: bool) -> Self { + #[allow(deprecated)] let preserve_dict_id = options.preserve_dict_id(); Self { options, data_gen: IpcDataGenerator::default(), + #[allow(deprecated)] dictionary_tracker: DictionaryTracker::new_with_preserve_dict_id( error_on_replacement, preserve_dict_id, @@ -934,7 +940,7 @@ mod tests { let mut decoder = FlightDataDecoder::new(encoder); let expected_schema = Schema::new(vec![Field::new_list( "dict_list", - Field::new("item", DataType::Utf8, true), + Field::new_list_field(DataType::Utf8, true), true, )]); @@ -1038,7 +1044,7 @@ mod tests { "struct", vec![Field::new_list( "dict_list", - Field::new("item", DataType::Utf8, true), + Field::new_list_field(DataType::Utf8, true), true, )], true, @@ -1087,12 +1093,16 @@ mod tests { ))], ); - struct_builder.field_builder::>>>(0).unwrap().append_value(vec![Some("a"), None, Some("b")]); + struct_builder.field_builder::>>>(0) + .unwrap() + .append_value(vec![Some("a"), None, Some("b")]); struct_builder.append(true); let arr1 = struct_builder.finish(); - struct_builder.field_builder::>>>(0).unwrap().append_value(vec![Some("c"), None, Some("d")]); + struct_builder.field_builder::>>>(0) + .unwrap() + .append_value(vec![Some("c"), None, Some("d")]); struct_builder.append(true); let arr2 = struct_builder.finish(); @@ -1214,12 +1224,16 @@ mod tests { let hydrated_struct_fields = vec![Field::new_list( "dict_list", - Field::new("item", DataType::Utf8, true), + Field::new_list_field(DataType::Utf8, true), true, )]; let hydrated_union_fields = vec![ - Field::new_list("dict_list", Field::new("item", DataType::Utf8, true), true), + Field::new_list( + "dict_list", + Field::new_list_field(DataType::Utf8, true), + true, + ), Field::new_struct("struct", hydrated_struct_fields.clone(), true), Field::new("string", DataType::Utf8, true), ]; @@ -1300,6 +1314,11 @@ mod tests { .into_iter() .collect::(); + let mut field_types = union_fields.iter().map(|(_, field)| field.data_type()); + let dict_list_ty = 
field_types.next().unwrap(); + let struct_ty = field_types.next().unwrap(); + let string_ty = field_types.next().unwrap(); + let struct_fields = vec![Field::new_list( "dict_list", Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true), @@ -1318,9 +1337,9 @@ mod tests { type_id_buffer, None, vec![ - Arc::new(arr1) as Arc, - new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1), - new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1), + Arc::new(arr1), + new_null_array(struct_ty, 1), + new_null_array(string_ty, 1), ], ) .unwrap(); @@ -1336,9 +1355,9 @@ mod tests { type_id_buffer, None, vec![ - new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1), + new_null_array(dict_list_ty, 1), Arc::new(arr2), - new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1), + new_null_array(string_ty, 1), ], ) .unwrap(); @@ -1349,8 +1368,8 @@ mod tests { type_id_buffer, None, vec![ - new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1), - new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1), + new_null_array(dict_list_ty, 1), + new_null_array(struct_ty, 1), Arc::new(StringArray::from(vec!["e"])), ], ) @@ -1528,6 +1547,7 @@ mod tests { async fn verify_flight_round_trip(mut batches: Vec) { let expected_schema = batches.first().unwrap().schema(); + #[allow(deprecated)] let encoder = FlightDataEncoderBuilder::default() .with_options(IpcWriteOptions::default().with_preserve_dict_id(false)) .with_dictionary_handling(DictionaryHandling::Resend) @@ -1555,6 +1575,7 @@ mod tests { HashMap::from([("some_key".to_owned(), "some_value".to_owned())]), ); + #[allow(deprecated)] let mut dictionary_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true); let got = prepare_schema_for_flight(&schema, &mut dictionary_tracker, false); @@ -1573,12 +1594,21 @@ mod tests { hydrate_dictionaries(&batch, batch.schema()).expect("failed to optimize"); } - pub fn make_flight_data( + fn 
make_flight_data( + batch: &RecordBatch, + options: &IpcWriteOptions, + ) -> (Vec, FlightData) { + flight_data_from_arrow_batch(batch, options) + } + + fn flight_data_from_arrow_batch( batch: &RecordBatch, options: &IpcWriteOptions, ) -> (Vec, FlightData) { let data_gen = IpcDataGenerator::default(); - let mut dictionary_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true); + #[allow(deprecated)] + let mut dictionary_tracker = + DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id()); let (encoded_dictionaries, encoded_batch) = data_gen .encoded_batch(batch, &mut dictionary_tracker, options) @@ -1741,7 +1771,7 @@ mod tests { let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap(); - verify_encoded_split(batch, 160).await; + verify_encoded_split(batch, 48).await; } #[tokio::test] @@ -1803,7 +1833,7 @@ mod tests { .flight_descriptor .as_ref() .map(|descriptor| { - let path_len: usize = descriptor.path.iter().map(|p| p.as_bytes().len()).sum(); + let path_len: usize = descriptor.path.iter().map(|p| p.len()).sum(); std::mem::size_of_val(descriptor) + descriptor.cmd.len() + path_len }) diff --git a/arrow-flight/src/lib.rs b/arrow-flight/src/lib.rs index 9f18416c06ec..1dd2700794f3 100644 --- a/arrow-flight/src/lib.rs +++ b/arrow-flight/src/lib.rs @@ -38,6 +38,8 @@ //! 
[Flight SQL]: https://arrow.apache.org/docs/format/FlightSql.html #![allow(rustdoc::invalid_html_tags)] #![warn(missing_docs)] +// The unused_crate_dependencies lint does not work well for crates defining additional examples/bin targets +#![allow(unused_crate_dependencies)] use arrow_ipc::{convert, writer, writer::EncodedData, writer::IpcWriteOptions}; use arrow_schema::{ArrowError, Schema}; @@ -141,6 +143,7 @@ pub struct IpcMessage(pub Bytes); fn flight_schema_as_encoded_data(arrow_schema: &Schema, options: &IpcWriteOptions) -> EncodedData { let data_gen = writer::IpcDataGenerator::default(); + #[allow(deprecated)] let mut dict_tracker = writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id()); data_gen.schema_to_bytes_with_dictionary_tracker(arrow_schema, &mut dict_tracker, options) diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs index e45e505b2b61..6d3ac3dbe610 100644 --- a/arrow-flight/src/sql/client.rs +++ b/arrow-flight/src/sql/client.rs @@ -228,8 +228,8 @@ impl FlightSqlServiceClient { .await .map_err(status_to_arrow_error)? .unwrap(); - let any = Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?; - let result: DoPutUpdateResult = any.unpack()?.unwrap(); + let result: DoPutUpdateResult = + Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?; Ok(result.record_count) } @@ -274,8 +274,8 @@ impl FlightSqlServiceClient { .await .map_err(status_to_arrow_error)? .unwrap(); - let any = Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?; - let result: DoPutUpdateResult = any.unpack()?.unwrap(); + let result: DoPutUpdateResult = + Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?; Ok(result.record_count) } @@ -593,8 +593,8 @@ impl PreparedStatement { .await .map_err(status_to_arrow_error)? 
.unwrap(); - let any = Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?; - let result: DoPutUpdateResult = any.unpack()?.unwrap(); + let result: DoPutUpdateResult = + Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?; Ok(result.record_count) } @@ -721,7 +721,7 @@ pub fn arrow_data_from_flight_data( let dictionaries_by_field = HashMap::new(); let record_batch = read_record_batch( - &Buffer::from_bytes(flight_data.data_body.into()), + &Buffer::from(flight_data.data_body), ipc_record_batch, arrow_schema_ref.clone(), &dictionaries_by_field, diff --git a/arrow-flight/src/sql/metadata/catalogs.rs b/arrow-flight/src/sql/metadata/catalogs.rs index 327fed81077b..e27c63c3932f 100644 --- a/arrow-flight/src/sql/metadata/catalogs.rs +++ b/arrow-flight/src/sql/metadata/catalogs.rs @@ -68,7 +68,8 @@ impl GetCatalogsBuilder { /// builds a `RecordBatch` with the correct schema for a /// [`CommandGetCatalogs`] response pub fn build(self) -> Result { - let Self { catalogs } = self; + let Self { mut catalogs } = self; + catalogs.sort_unstable(); let batch = RecordBatch::try_new( Arc::clone(&GET_CATALOG_SCHEMA), @@ -98,3 +99,30 @@ static GET_CATALOG_SCHEMA: Lazy = Lazy::new(|| { false, )])) }); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_catalogs_are_sorted() { + let batch = ["a_catalog", "c_catalog", "b_catalog"] + .into_iter() + .fold(GetCatalogsBuilder::new(), |mut builder, catalog| { + builder.append(catalog); + builder + }) + .build() + .unwrap(); + let catalogs = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .flatten() + .collect::>(); + assert!(catalogs.is_sorted()); + assert_eq!(catalogs, ["a_catalog", "b_catalog", "c_catalog"]); + } +} diff --git a/arrow-flight/src/sql/metadata/sql_info.rs b/arrow-flight/src/sql/metadata/sql_info.rs index 2ea30df7fc2f..58b228530942 100644 --- a/arrow-flight/src/sql/metadata/sql_info.rs +++ b/arrow-flight/src/sql/metadata/sql_info.rs @@ -172,7 
+172,7 @@ static UNION_TYPE: Lazy = Lazy::new(|| { // treat list as nullable b/c that is what the builders make Field::new( "string_list", - DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), + DataType::List(Arc::new(Field::new_list_field(DataType::Utf8, true))), true, ), Field::new( @@ -184,7 +184,7 @@ static UNION_TYPE: Lazy = Lazy::new(|| { Field::new("keys", DataType::Int32, false), Field::new( "values", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))), true, ), ])), diff --git a/arrow-flight/src/sql/metadata/xdbc_info.rs b/arrow-flight/src/sql/metadata/xdbc_info.rs index 485bedaebfb0..a3a18ca10888 100644 --- a/arrow-flight/src/sql/metadata/xdbc_info.rs +++ b/arrow-flight/src/sql/metadata/xdbc_info.rs @@ -330,7 +330,7 @@ static GET_XDBC_INFO_SCHEMA: Lazy = Lazy::new(|| { Field::new("literal_suffix", DataType::Utf8, true), Field::new( "create_params", - DataType::List(Arc::new(Field::new("item", DataType::Utf8, false))), + DataType::List(Arc::new(Field::new_list_field(DataType::Utf8, false))), true, ), Field::new("nullable", DataType::Int32, false), diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index 6b9befa63600..8ab8a16dbb50 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -719,7 +719,7 @@ where let record_count = self.do_put_statement_update(command, request).await?; let result = DoPutUpdateResult { record_count }; let output = futures::stream::iter(vec![Ok(PutResult { - app_metadata: result.as_any().encode_to_vec().into(), + app_metadata: result.encode_to_vec().into(), })]); Ok(Response::new(Box::pin(output))) } @@ -727,7 +727,7 @@ where let record_count = self.do_put_statement_ingest(command, request).await?; let result = DoPutUpdateResult { record_count }; let output = futures::stream::iter(vec![Ok(PutResult { - app_metadata: result.as_any().encode_to_vec().into(), + app_metadata: 
result.encode_to_vec().into(), })]); Ok(Response::new(Box::pin(output))) } @@ -744,7 +744,7 @@ where let record_count = self.do_put_substrait_plan(command, request).await?; let result = DoPutUpdateResult { record_count }; let output = futures::stream::iter(vec![Ok(PutResult { - app_metadata: result.as_any().encode_to_vec().into(), + app_metadata: result.encode_to_vec().into(), })]); Ok(Response::new(Box::pin(output))) } @@ -754,7 +754,7 @@ where .await?; let result = DoPutUpdateResult { record_count }; let output = futures::stream::iter(vec![Ok(PutResult { - app_metadata: result.as_any().encode_to_vec().into(), + app_metadata: result.encode_to_vec().into(), })]); Ok(Response::new(Box::pin(output))) } diff --git a/arrow-flight/src/utils.rs b/arrow-flight/src/utils.rs index f6129ddfe248..428dde73ca6c 100644 --- a/arrow-flight/src/utils.rs +++ b/arrow-flight/src/utils.rs @@ -17,8 +17,7 @@ //! Utilities to assist with reading and writing Arrow data as Flight messages -use crate::{FlightData, IpcMessage, SchemaAsIpc, SchemaResult}; -use bytes::Bytes; +use crate::{FlightData, SchemaAsIpc}; use std::collections::HashMap; use std::sync::Arc; @@ -28,30 +27,6 @@ use arrow_ipc::convert::fb_to_schema; use arrow_ipc::{reader, root_as_message, writer, writer::IpcWriteOptions}; use arrow_schema::{ArrowError, Schema, SchemaRef}; -/// Convert a `RecordBatch` to a vector of `FlightData` representing the bytes of the dictionaries -/// and a `FlightData` representing the bytes of the batch's values -#[deprecated( - since = "30.0.0", - note = "Use IpcDataGenerator directly with DictionaryTracker to avoid re-sending dictionaries" -)] -pub fn flight_data_from_arrow_batch( - batch: &RecordBatch, - options: &IpcWriteOptions, -) -> (Vec, FlightData) { - let data_gen = writer::IpcDataGenerator::default(); - let mut dictionary_tracker = - writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id()); - - let (encoded_dictionaries, encoded_batch) = data_gen - 
.encoded_batch(batch, &mut dictionary_tracker, options) - .expect("DictionaryTracker configured above to not error on replacement"); - - let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect(); - let flight_batch = encoded_batch.into(); - - (flight_dictionaries, flight_batch) -} - /// Convert a slice of wire protocol `FlightData`s into a vector of `RecordBatch`es pub fn flight_data_to_batches(flight_data: &[FlightData]) -> Result, ArrowError> { let schema = flight_data.first().ok_or_else(|| { @@ -104,41 +79,6 @@ pub fn flight_data_to_arrow_batch( })? } -/// Convert a `Schema` to `SchemaResult` by converting to an IPC message -#[deprecated( - since = "4.4.0", - note = "Use From trait, e.g.: SchemaAsIpc::new(schema, options).try_into()" -)] -pub fn flight_schema_from_arrow_schema( - schema: &Schema, - options: &IpcWriteOptions, -) -> Result { - SchemaAsIpc::new(schema, options).try_into() -} - -/// Convert a `Schema` to `FlightData` by converting to an IPC message -#[deprecated( - since = "4.4.0", - note = "Use From trait, e.g.: SchemaAsIpc::new(schema, options).into()" -)] -pub fn flight_data_from_arrow_schema(schema: &Schema, options: &IpcWriteOptions) -> FlightData { - SchemaAsIpc::new(schema, options).into() -} - -/// Convert a `Schema` to bytes in the format expected in `FlightInfo.schema` -#[deprecated( - since = "4.4.0", - note = "Use TryFrom trait, e.g.: SchemaAsIpc::new(schema, options).try_into()" -)] -pub fn ipc_message_from_arrow_schema( - schema: &Schema, - options: &IpcWriteOptions, -) -> Result { - let message = SchemaAsIpc::new(schema, options).try_into()?; - let IpcMessage(vals) = message; - Ok(vals) -} - /// Convert `RecordBatch`es to wire protocol `FlightData`s pub fn batches_to_flight_data( schema: &Schema, @@ -150,6 +90,7 @@ pub fn batches_to_flight_data( let mut flight_data = vec![]; let data_gen = writer::IpcDataGenerator::default(); + #[allow(deprecated)] let mut dictionary_tracker = 
writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id()); diff --git a/arrow-integration-test/LICENSE.txt b/arrow-integration-test/LICENSE.txt new file mode 120000 index 000000000000..4ab43736a839 --- /dev/null +++ b/arrow-integration-test/LICENSE.txt @@ -0,0 +1 @@ +../LICENSE.txt \ No newline at end of file diff --git a/arrow-integration-test/NOTICE.txt b/arrow-integration-test/NOTICE.txt new file mode 120000 index 000000000000..eb9f24e040b5 --- /dev/null +++ b/arrow-integration-test/NOTICE.txt @@ -0,0 +1 @@ +../NOTICE.txt \ No newline at end of file diff --git a/arrow-integration-test/src/field.rs b/arrow-integration-test/src/field.rs index 32edc4165938..4b896ed391be 100644 --- a/arrow-integration-test/src/field.rs +++ b/arrow-integration-test/src/field.rs @@ -252,6 +252,7 @@ pub fn field_from_json(json: &serde_json::Value) -> Result { _ => data_type, }; + #[allow(deprecated)] let mut field = Field::new_dict(name, data_type, nullable, dict_id, dict_is_ordered); field.set_metadata(metadata); Ok(field) @@ -274,17 +275,21 @@ pub fn field_to_json(field: &Field) -> serde_json::Value { }; match field.data_type() { - DataType::Dictionary(ref index_type, ref value_type) => serde_json::json!({ - "name": field.name(), - "nullable": field.is_nullable(), - "type": data_type_to_json(value_type), - "children": children, - "dictionary": { - "id": field.dict_id().unwrap(), - "indexType": data_type_to_json(index_type), - "isOrdered": field.dict_is_ordered().unwrap(), - } - }), + DataType::Dictionary(ref index_type, ref value_type) => { + #[allow(deprecated)] + let dict_id = field.dict_id().unwrap(); + serde_json::json!({ + "name": field.name(), + "nullable": field.is_nullable(), + "type": data_type_to_json(value_type), + "children": children, + "dictionary": { + "id": dict_id, + "indexType": data_type_to_json(index_type), + "isOrdered": field.dict_is_ordered().unwrap(), + } + }) + } _ => serde_json::json!({ "name": field.name(), "nullable": 
field.is_nullable(), diff --git a/arrow-integration-test/src/lib.rs b/arrow-integration-test/src/lib.rs index ea5b545f2e81..f025009c22de 100644 --- a/arrow-integration-test/src/lib.rs +++ b/arrow-integration-test/src/lib.rs @@ -787,6 +787,7 @@ pub fn array_from_json( Ok(Arc::new(array)) } DataType::Dictionary(key_type, value_type) => { + #[allow(deprecated)] let dict_id = field.dict_id().ok_or_else(|| { ArrowError::JsonError(format!("Unable to find dict_id for field {field:?}")) })?; @@ -930,10 +931,12 @@ pub fn dictionary_array_from_json( let null_buf = create_null_buf(&json_col); // build the key data into a buffer, then construct values separately + #[allow(deprecated)] let key_field = Field::new_dict( "key", dict_key.clone(), field.is_nullable(), + #[allow(deprecated)] field .dict_id() .expect("Dictionary fields must have a dict_id value"), @@ -1192,7 +1195,7 @@ mod tests { Field::new("utf8s", DataType::Utf8, true), Field::new( "lists", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))), true, ), Field::new( @@ -1249,7 +1252,7 @@ mod tests { let value_data = Int32Array::from(vec![None, Some(2), None, None]); let value_offsets = Buffer::from_slice_ref([0, 3, 4, 4]); - let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let list_data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); let list_data = ArrayData::builder(list_data_type) .len(3) .add_buffer(value_offsets) diff --git a/arrow-integration-test/src/schema.rs b/arrow-integration-test/src/schema.rs index 541a1ec746ac..512f0aed8e54 100644 --- a/arrow-integration-test/src/schema.rs +++ b/arrow-integration-test/src/schema.rs @@ -150,7 +150,7 @@ mod tests { Field::new("c21", DataType::Interval(IntervalUnit::MonthDayNano), false), Field::new( "c22", - DataType::List(Arc::new(Field::new("item", DataType::Boolean, true))), + 
DataType::List(Arc::new(Field::new_list_field(DataType::Boolean, true))), false, ), Field::new( @@ -189,6 +189,7 @@ mod tests { Field::new("c30", DataType::Duration(TimeUnit::Millisecond), false), Field::new("c31", DataType::Duration(TimeUnit::Microsecond), false), Field::new("c32", DataType::Duration(TimeUnit::Nanosecond), false), + #[allow(deprecated)] Field::new_dict( "c33", DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), diff --git a/arrow-integration-testing/Cargo.toml b/arrow-integration-testing/Cargo.toml index 7be56d919852..8654b4b92734 100644 --- a/arrow-integration-testing/Cargo.toml +++ b/arrow-integration-testing/Cargo.toml @@ -36,20 +36,17 @@ logging = ["tracing-subscriber"] [dependencies] arrow = { path = "../arrow", default-features = false, features = ["test_utils", "ipc", "ipc_compression", "json", "ffi"] } arrow-flight = { path = "../arrow-flight", default-features = false } -arrow-buffer = { path = "../arrow-buffer", default-features = false } arrow-integration-test = { path = "../arrow-integration-test", default-features = false } -async-trait = { version = "0.1.41", default-features = false } clap = { version = "4", default-features = false, features = ["std", "derive", "help", "error-context", "usage"] } futures = { version = "0.3", default-features = false } -hex = { version = "0.4", default-features = false, features = ["std"] } prost = { version = "0.13", default-features = false } serde = { version = "1.0", default-features = false, features = ["rc", "derive"] } serde_json = { version = "1.0", default-features = false, features = ["std"] } -tokio = { version = "1.0", default-features = false } +tokio = { version = "1.0", default-features = false, features = [ "rt-multi-thread"] } tonic = { version = "0.12", default-features = false } tracing-subscriber = { version = "0.3.1", default-features = false, features = ["fmt"], optional = true } -num = { version = "0.4", default-features = false, features = ["std"] } 
flate2 = { version = "1", default-features = false, features = ["rust_backend"] } [dev-dependencies] +arrow-buffer = { path = "../arrow-buffer", default-features = false } tempfile = { version = "3", default-features = false } diff --git a/arrow-integration-testing/README.md b/arrow-integration-testing/README.md index dcf39c27fbc5..86c79f5030ce 100644 --- a/arrow-integration-testing/README.md +++ b/arrow-integration-testing/README.md @@ -53,7 +53,7 @@ pip install -e dev/archery[integration] ### Build the C++ binaries: -Follow the [C++ Direction](https://github.com/apache/arrow/tree/master/docs/source/developers/cpp) and build the integration test binaries with a command like this: +Follow the [C++ Direction](https://github.com/apache/arrow/tree/main/docs/source/developers/cpp) and build the integration test binaries with a command like this: ``` # build cpp binaries diff --git a/arrow-integration-testing/src/bin/arrow-file-to-stream.rs b/arrow-integration-testing/src/bin/arrow-file-to-stream.rs index 3e027faef91f..661f0a047db4 100644 --- a/arrow-integration-testing/src/bin/arrow-file-to-stream.rs +++ b/arrow-integration-testing/src/bin/arrow-file-to-stream.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +// The unused_crate_dependencies lint does not work well for crates defining additional examples/bin targets +#![allow(unused_crate_dependencies)] + use arrow::error::Result; use arrow::ipc::reader::FileReader; use arrow::ipc::writer::StreamWriter; diff --git a/arrow-integration-testing/src/bin/arrow-json-integration-test.rs b/arrow-integration-testing/src/bin/arrow-json-integration-test.rs index cc3dd2110e36..6a901cc63bab 100644 --- a/arrow-integration-testing/src/bin/arrow-json-integration-test.rs +++ b/arrow-integration-testing/src/bin/arrow-json-integration-test.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. 
+// The unused_crate_dependencies lint does not work well for crates defining additional examples/bin targets +#![allow(unused_crate_dependencies)] + use arrow::error::{ArrowError, Result}; use arrow::ipc::reader::FileReader; use arrow::ipc::writer::FileWriter; diff --git a/arrow-integration-testing/src/bin/arrow-stream-to-file.rs b/arrow-integration-testing/src/bin/arrow-stream-to-file.rs index 07ac5c7ddd42..8b4bb332781c 100644 --- a/arrow-integration-testing/src/bin/arrow-stream-to-file.rs +++ b/arrow-integration-testing/src/bin/arrow-stream-to-file.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +// The unused_crate_dependencies lint does not work well for crates defining additional examples/bin targets +#![allow(unused_crate_dependencies)] + use std::io; use arrow::error::Result; diff --git a/arrow-integration-testing/src/bin/flight-test-integration-client.rs b/arrow-integration-testing/src/bin/flight-test-integration-client.rs index b8bbb952837b..0d16fe3b403f 100644 --- a/arrow-integration-testing/src/bin/flight-test-integration-client.rs +++ b/arrow-integration-testing/src/bin/flight-test-integration-client.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +// The unused_crate_dependencies lint does not work well for crates defining additional examples/bin targets +#![allow(unused_crate_dependencies)] + use arrow_integration_testing::flight_client_scenarios; use clap::Parser; type Error = Box; diff --git a/arrow-integration-testing/src/bin/flight-test-integration-server.rs b/arrow-integration-testing/src/bin/flight-test-integration-server.rs index 5310d07d4f8e..94be71309799 100644 --- a/arrow-integration-testing/src/bin/flight-test-integration-server.rs +++ b/arrow-integration-testing/src/bin/flight-test-integration-server.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. 
+// The unused_crate_dependencies lint does not work well for crates defining additional examples/bin targets +#![allow(unused_crate_dependencies)] + use arrow_integration_testing::flight_server_scenarios; use clap::Parser; diff --git a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs index c8289ff446a0..406419028d00 100644 --- a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs +++ b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs @@ -29,7 +29,7 @@ use arrow::{ }; use arrow_flight::{ flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient, - utils::flight_data_to_arrow_batch, FlightData, FlightDescriptor, Location, SchemaAsIpc, Ticket, + utils::flight_data_to_arrow_batch, FlightData, FlightDescriptor, IpcMessage, Location, Ticket, }; use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt}; use tonic::{Request, Streaming}; @@ -72,7 +72,20 @@ async fn upload_data( let (mut upload_tx, upload_rx) = mpsc::channel(10); let options = arrow::ipc::writer::IpcWriteOptions::default(); - let mut schema_flight_data: FlightData = SchemaAsIpc::new(&schema, &options).into(); + #[allow(deprecated)] + let mut dict_tracker = + writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id()); + let data_gen = writer::IpcDataGenerator::default(); + let data = IpcMessage( + data_gen + .schema_to_bytes_with_dictionary_tracker(&schema, &mut dict_tracker, &options) + .ipc_message + .into(), + ); + let mut schema_flight_data = FlightData { + data_header: data.0, + ..Default::default() + }; // arrow_flight::utils::flight_data_from_arrow_schema(&schema, &options); schema_flight_data.flight_descriptor = Some(descriptor.clone()); upload_tx.send(schema_flight_data).await?; @@ -82,7 +95,14 @@ async fn upload_data( if let Some((counter, first_batch)) = original_data_iter.next() { let metadata = 
counter.to_string().into_bytes(); // Preload the first batch into the channel before starting the request - send_batch(&mut upload_tx, &metadata, first_batch, &options).await?; + send_batch( + &mut upload_tx, + &metadata, + first_batch, + &options, + &mut dict_tracker, + ) + .await?; let outer = client.do_put(Request::new(upload_rx)).await?; let mut inner = outer.into_inner(); @@ -97,7 +117,14 @@ async fn upload_data( // Stream the rest of the batches for (counter, batch) in original_data_iter { let metadata = counter.to_string().into_bytes(); - send_batch(&mut upload_tx, &metadata, batch, &options).await?; + send_batch( + &mut upload_tx, + &metadata, + batch, + &options, + &mut dict_tracker, + ) + .await?; let r = inner .next() @@ -124,12 +151,12 @@ async fn send_batch( metadata: &[u8], batch: &RecordBatch, options: &writer::IpcWriteOptions, + dictionary_tracker: &mut writer::DictionaryTracker, ) -> Result { let data_gen = writer::IpcDataGenerator::default(); - let mut dictionary_tracker = writer::DictionaryTracker::new_with_preserve_dict_id(false, true); let (encoded_dictionaries, encoded_batch) = data_gen - .encoded_batch(batch, &mut dictionary_tracker, options) + .encoded_batch(batch, dictionary_tracker, options) .expect("DictionaryTracker configured above to not error on replacement"); let dictionary_flight_data: Vec = diff --git a/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs index 0f404b2ae289..92989a20393e 100644 --- a/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs +++ b/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs @@ -119,18 +119,32 @@ impl FlightService for FlightServiceImpl { .ok_or_else(|| Status::not_found(format!("Could not find flight. 
{key}")))?; let options = arrow::ipc::writer::IpcWriteOptions::default(); + #[allow(deprecated)] + let mut dictionary_tracker = + writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id()); + let data_gen = writer::IpcDataGenerator::default(); + let data = IpcMessage( + data_gen + .schema_to_bytes_with_dictionary_tracker( + &flight.schema, + &mut dictionary_tracker, + &options, + ) + .ipc_message + .into(), + ); + let schema_flight_data = FlightData { + data_header: data.0, + ..Default::default() + }; - let schema = std::iter::once(Ok(SchemaAsIpc::new(&flight.schema, &options).into())); + let schema = std::iter::once(Ok(schema_flight_data)); let batches = flight .chunks .iter() .enumerate() .flat_map(|(counter, batch)| { - let data_gen = writer::IpcDataGenerator::default(); - let mut dictionary_tracker = - writer::DictionaryTracker::new_with_preserve_dict_id(false, true); - let (encoded_dictionaries, encoded_batch) = data_gen .encoded_batch(batch, &mut dictionary_tracker, &options) .expect("DictionaryTracker configured above to not error on replacement"); diff --git a/arrow-integration-testing/src/lib.rs b/arrow-integration-testing/src/lib.rs index c8ce01e9f13b..e669690ef4f5 100644 --- a/arrow-integration-testing/src/lib.rs +++ b/arrow-integration-testing/src/lib.rs @@ -17,6 +17,8 @@ //! 
Common code used in the integration test binaries +// The unused_crate_dependencies lint does not work well for crates defining additional examples/bin targets +#![allow(unused_crate_dependencies)] #![warn(missing_docs)] use serde_json::Value; diff --git a/arrow-ipc/Cargo.toml b/arrow-ipc/Cargo.toml index 94b89a55f2fb..cf91b3a3415f 100644 --- a/arrow-ipc/Cargo.toml +++ b/arrow-ipc/Cargo.toml @@ -36,7 +36,6 @@ bench = false [dependencies] arrow-array = { workspace = true } arrow-buffer = { workspace = true } -arrow-cast = { workspace = true } arrow-data = { workspace = true } arrow-schema = { workspace = true } flatbuffers = { version = "24.3.25", default-features = false } diff --git a/arrow-ipc/LICENSE.txt b/arrow-ipc/LICENSE.txt new file mode 120000 index 000000000000..4ab43736a839 --- /dev/null +++ b/arrow-ipc/LICENSE.txt @@ -0,0 +1 @@ +../LICENSE.txt \ No newline at end of file diff --git a/arrow-ipc/NOTICE.txt b/arrow-ipc/NOTICE.txt new file mode 120000 index 000000000000..eb9f24e040b5 --- /dev/null +++ b/arrow-ipc/NOTICE.txt @@ -0,0 +1 @@ +../NOTICE.txt \ No newline at end of file diff --git a/arrow-ipc/src/convert.rs b/arrow-ipc/src/convert.rs index 18f5193bf038..37c5a19439c1 100644 --- a/arrow-ipc/src/convert.rs +++ b/arrow-ipc/src/convert.rs @@ -165,6 +165,7 @@ pub fn schema_to_fb_offset<'a>( impl From> for Field { fn from(field: crate::Field) -> Field { let arrow_field = if let Some(dictionary) = field.dictionary() { + #[allow(deprecated)] Field::new_dict( field.name().unwrap(), get_data_type(field, true), @@ -519,6 +520,7 @@ pub(crate) fn build_field<'a>( match dictionary_tracker { Some(tracker) => Some(get_fb_dictionary( index_type, + #[allow(deprecated)] tracker.set_dict_id(field), field .dict_is_ordered() @@ -527,6 +529,7 @@ pub(crate) fn build_field<'a>( )), None => Some(get_fb_dictionary( index_type, + #[allow(deprecated)] field .dict_id() .expect("Dictionary type must have a dictionary id"), @@ -1026,10 +1029,14 @@ mod tests { 
Field::new("utf8_view", DataType::Utf8View, false), Field::new("binary", DataType::Binary, false), Field::new("binary_view", DataType::BinaryView, false), - Field::new_list("list[u8]", Field::new("item", DataType::UInt8, false), true), + Field::new_list( + "list[u8]", + Field::new_list_field(DataType::UInt8, false), + true, + ), Field::new_fixed_size_list( "fixed_size_list[u8]", - Field::new("item", DataType::UInt8, false), + Field::new_list_field(DataType::UInt8, false), 2, true, ), @@ -1139,6 +1146,7 @@ mod tests { ), true, ), + #[allow(deprecated)] Field::new_dict( "dictionary", DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), @@ -1146,6 +1154,7 @@ mod tests { 123, true, ), + #[allow(deprecated)] Field::new_dict( "dictionary", DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::UInt32)), diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs index 0820e3590827..9ff4da30ed8c 100644 --- a/arrow-ipc/src/reader.rs +++ b/arrow-ipc/src/reader.rs @@ -196,6 +196,7 @@ fn create_array( let index_node = reader.next_node(field)?; let index_buffers = [reader.next_buffer()?, reader.next_buffer()?]; + #[allow(deprecated)] let dict_id = field.dict_id().ok_or_else(|| { ArrowError::ParseError(format!("Field {field} does not have dict id")) })?; @@ -617,6 +618,7 @@ fn read_dictionary_impl( } let id = batch.id(); + #[allow(deprecated)] let fields_using_this_dictionary = schema.fields_with_dict_id(id); let first_field = fields_using_this_dictionary.first().ok_or_else(|| { ArrowError::InvalidArgumentError(format!("dictionary id {id} not found in schema")) @@ -1395,7 +1397,7 @@ impl RecordBatchReader for StreamReader { #[cfg(test)] mod tests { - use crate::writer::{unslice_run_array, DictionaryTracker, IpcDataGenerator}; + use crate::writer::{unslice_run_array, DictionaryTracker, IpcDataGenerator, IpcWriteOptions}; use super::*; @@ -1407,10 +1409,10 @@ mod tests { fn create_test_projection_schema() -> Schema { // define field types - 
let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let list_data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); let fixed_size_list_data_type = - DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 3); + DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int32, false)), 3); let union_fields = UnionFields::new( vec![0, 1], @@ -1424,7 +1426,7 @@ mod tests { let struct_fields = Fields::from(vec![ Field::new("id", DataType::Int32, false), - Field::new_list("list", Field::new("item", DataType::Int8, true), false), + Field::new_list("list", Field::new_list_field(DataType::Int8, true), false), ]); let struct_data_type = DataType::Struct(struct_fields); @@ -1702,6 +1704,42 @@ mod tests { assert_eq!(batch, roundtrip_ipc(&batch)); } + #[test] + fn test_roundtrip_nested_dict_no_preserve_dict_id() { + let inner: DictionaryArray = vec!["a", "b", "a"].into_iter().collect(); + + let array = Arc::new(inner) as ArrayRef; + + let dctfield = Arc::new(Field::new("dict", array.data_type().clone(), false)); + + let s = StructArray::from(vec![(dctfield, array)]); + let struct_array = Arc::new(s) as ArrayRef; + + let schema = Arc::new(Schema::new(vec![Field::new( + "struct", + struct_array.data_type().clone(), + false, + )])); + + let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap(); + + let mut buf = Vec::new(); + let mut writer = crate::writer::FileWriter::try_new_with_options( + &mut buf, + batch.schema_ref(), + #[allow(deprecated)] + IpcWriteOptions::default().with_preserve_dict_id(false), + ) + .unwrap(); + writer.write(&batch).unwrap(); + writer.finish().unwrap(); + drop(writer); + + let mut reader = FileReader::try_new(std::io::Cursor::new(buf), None).unwrap(); + + assert_eq!(batch, reader.next().unwrap().unwrap()); + } + fn check_union_with_builder(mut builder: UnionBuilder) { builder.append::("a", 1).unwrap(); builder.append_null::("a").unwrap(); @@ 
-1743,7 +1781,7 @@ mod tests { #[test] fn test_roundtrip_struct_empty_fields() { - let nulls = NullBuffer::from(&[true, true, false][..]); + let nulls = NullBuffer::from(&[true, true, false]); let rb = RecordBatch::try_from_iter([( "", Arc::new(StructArray::new_empty_fields(nulls.len(), Some(nulls))) as _, @@ -1834,6 +1872,7 @@ mod tests { let key_dict_keys = Int8Array::from_iter_values([0, 0, 2, 1, 1, 3]); let key_dict_array = DictionaryArray::new(key_dict_keys, values); + #[allow(deprecated)] let keys_field = Arc::new(Field::new_dict( "keys", DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)), @@ -1841,6 +1880,7 @@ mod tests { 1, false, )); + #[allow(deprecated)] let values_field = Arc::new(Field::new_dict( "values", DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)), @@ -1921,6 +1961,7 @@ mod tests { #[test] fn test_roundtrip_stream_dict_of_list_of_dict() { // list + #[allow(deprecated)] let list_data_type = DataType::List(Arc::new(Field::new_dict( "item", DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)), @@ -1932,6 +1973,7 @@ mod tests { test_roundtrip_stream_dict_of_list_of_dict_impl::(list_data_type, offsets); // large list + #[allow(deprecated)] let list_data_type = DataType::LargeList(Arc::new(Field::new_dict( "item", DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)), @@ -1950,6 +1992,7 @@ mod tests { let dict_array = DictionaryArray::new(keys, Arc::new(values)); let dict_data = dict_array.into_data(); + #[allow(deprecated)] let list_data_type = DataType::FixedSizeList( Arc::new(Field::new_dict( "item", @@ -2040,6 +2083,7 @@ mod tests { let key_dict_keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1, 3]); let key_dict_array = DictionaryArray::new(key_dict_keys, utf8_view_array.clone()); + #[allow(deprecated)] let keys_field = Arc::new(Field::new_dict( "keys", DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8View)), @@ -2050,6 +2094,7 @@ mod 
tests { let value_dict_keys = Int8Array::from_iter_values([0, 3, 0, 1, 2, 0, 1]); let value_dict_array = DictionaryArray::new(value_dict_keys, bin_view_array); + #[allow(deprecated)] let values_field = Arc::new(Field::new_dict( "values", DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::BinaryView)), @@ -2115,6 +2160,7 @@ mod tests { .unwrap(); let gen = IpcDataGenerator {}; + #[allow(deprecated)] let mut dict_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true); let (_, encoded) = gen .encoded_batch(&batch, &mut dict_tracker, &Default::default()) @@ -2152,6 +2198,7 @@ mod tests { .unwrap(); let gen = IpcDataGenerator {}; + #[allow(deprecated)] let mut dict_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true); let (_, encoded) = gen .encoded_batch(&batch, &mut dict_tracker, &Default::default()) @@ -2291,6 +2338,7 @@ mod tests { ["a", "b"] .iter() .map(|name| { + #[allow(deprecated)] Field::new_dict( name.to_string(), DataType::Dictionary( @@ -2325,6 +2373,7 @@ mod tests { let mut writer = crate::writer::StreamWriter::try_new_with_options( &mut buf, batch.schema().as_ref(), + #[allow(deprecated)] crate::writer::IpcWriteOptions::default().with_preserve_dict_id(false), ) .expect("Failed to create StreamWriter"); diff --git a/arrow-ipc/src/reader/stream.rs b/arrow-ipc/src/reader/stream.rs index de5f5bdd629f..9b0eea9b6198 100644 --- a/arrow-ipc/src/reader/stream.rs +++ b/arrow-ipc/src/reader/stream.rs @@ -324,6 +324,7 @@ mod tests { "test1", DataType::RunEndEncoded( Arc::new(Field::new("run_ends".to_string(), DataType::Int32, false)), + #[allow(deprecated)] Arc::new(Field::new_dict( "values".to_string(), DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), @@ -353,6 +354,7 @@ mod tests { let mut writer = StreamWriter::try_new_with_options( &mut buffer, &schema, + #[allow(deprecated)] IpcWriteOptions::default().with_preserve_dict_id(false), ) .expect("Failed to create StreamWriter"); diff --git 
a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index b5c4dd95ed9f..ee5b9a54cc90 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -23,6 +23,7 @@ use std::cmp::min; use std::collections::HashMap; use std::io::{BufWriter, Write}; +use std::mem::size_of; use std::sync::Arc; use flatbuffers::FlatBufferBuilder; @@ -63,7 +64,11 @@ pub struct IpcWriteOptions { /// Flag indicating whether the writer should preserve the dictionary IDs defined in the /// schema or generate unique dictionary IDs internally during encoding. /// - /// Defaults to `true` + /// Defaults to `false` + #[deprecated( + since = "54.0.0", + note = "The ability to preserve dictionary IDs will be removed. With it, all fields related to it." + )] preserve_dict_id: bool, } @@ -107,12 +112,13 @@ impl IpcWriteOptions { | crate::MetadataVersion::V3 => Err(ArrowError::InvalidArgumentError( "Writing IPC metadata version 3 and lower not supported".to_string(), )), + #[allow(deprecated)] crate::MetadataVersion::V4 => Ok(Self { alignment, write_legacy_ipc_format, metadata_version, batch_compression_type: None, - preserve_dict_id: true, + preserve_dict_id: false, }), crate::MetadataVersion::V5 => { if write_legacy_ipc_format { @@ -120,12 +126,13 @@ impl IpcWriteOptions { "Legacy IPC format only supported on metadata version 4".to_string(), )) } else { + #[allow(deprecated)] Ok(Self { alignment, write_legacy_ipc_format, metadata_version, batch_compression_type: None, - preserve_dict_id: true, + preserve_dict_id: false, }) } } @@ -137,7 +144,12 @@ impl IpcWriteOptions { /// Return whether the writer is configured to preserve the dictionary IDs /// defined in the schema + #[deprecated( + since = "54.0.0", + note = "The ability to preserve dictionary IDs will be removed. With it, all functions related to it." 
+ )] pub fn preserve_dict_id(&self) -> bool { + #[allow(deprecated)] self.preserve_dict_id } @@ -148,6 +160,11 @@ impl IpcWriteOptions { /// to the dictionary batches in order to encode them correctly /// /// The default will change to `false` in future releases + #[deprecated( + since = "54.0.0", + note = "The ability to preserve dictionary IDs will be removed. With it, all functions related to it." + )] + #[allow(deprecated)] pub fn with_preserve_dict_id(mut self, preserve_dict_id: bool) -> Self { self.preserve_dict_id = preserve_dict_id; self @@ -156,12 +173,13 @@ impl IpcWriteOptions { impl Default for IpcWriteOptions { fn default() -> Self { + #[allow(deprecated)] Self { alignment: 64, write_legacy_ipc_format: false, metadata_version: crate::MetadataVersion::V5, batch_compression_type: None, - preserve_dict_id: true, + preserve_dict_id: false, } } } @@ -419,6 +437,7 @@ impl IpcDataGenerator { // It's importnat to only take the dict_id at this point, because the dict ID // sequence is assigned depth-first, so we need to first encode children and have // them take their assigned dict IDs before we take the dict ID for this field. + #[allow(deprecated)] let dict_id = dict_id_seq .next() .or_else(|| field.dict_id()) @@ -766,6 +785,10 @@ pub struct DictionaryTracker { written: HashMap, dict_ids: Vec, error_on_replacement: bool, + #[deprecated( + since = "54.0.0", + note = "The ability to preserve dictionary IDs will be removed. With it, all fields related to it." 
+ )] preserve_dict_id: bool, } @@ -781,11 +804,12 @@ impl DictionaryTracker { /// the last seen dictionary ID (or using `0` if no other dictionary IDs have been /// seen) pub fn new(error_on_replacement: bool) -> Self { + #[allow(deprecated)] Self { written: HashMap::new(), dict_ids: Vec::new(), error_on_replacement, - preserve_dict_id: true, + preserve_dict_id: false, } } @@ -794,7 +818,12 @@ impl DictionaryTracker { /// If `error_on_replacement` /// is true, an error will be generated if an update to an /// existing dictionary is attempted. + #[deprecated( + since = "54.0.0", + note = "The ability to preserve dictionary IDs will be removed. With it, all functions related to it." + )] pub fn new_with_preserve_dict_id(error_on_replacement: bool, preserve_dict_id: bool) -> Self { + #[allow(deprecated)] Self { written: HashMap::new(), dict_ids: Vec::new(), @@ -810,8 +839,14 @@ impl DictionaryTracker { /// /// If `preserve_dict_id` is false, this will return the value of the last `dict_id` assigned incremented by 1 /// or 0 in the case where no dictionary IDs have yet been assigned + #[deprecated( + since = "54.0.0", + note = "The ability to preserve dictionary IDs will be removed. With it, all functions related to it." 
+ )] pub fn set_dict_id(&mut self, field: &Field) -> i64 { + #[allow(deprecated)] let next = if self.preserve_dict_id { + #[allow(deprecated)] field.dict_id().expect("no dict_id in field") } else { self.dict_ids @@ -935,7 +970,9 @@ impl FileWriter { writer.write_all(&super::ARROW_MAGIC)?; writer.write_all(&PADDING[..pad_len])?; // write the schema, set the written bytes to the schema + header + #[allow(deprecated)] let preserve_dict_id = write_options.preserve_dict_id; + #[allow(deprecated)] let mut dictionary_tracker = DictionaryTracker::new_with_preserve_dict_id(true, preserve_dict_id); let encoded_message = data_gen.schema_to_bytes_with_dictionary_tracker( @@ -1012,8 +1049,13 @@ impl FileWriter { let mut fbb = FlatBufferBuilder::new(); let dictionaries = fbb.create_vector(&self.dictionary_blocks); let record_batches = fbb.create_vector(&self.record_blocks); + #[allow(deprecated)] + let preserve_dict_id = self.write_options.preserve_dict_id; + #[allow(deprecated)] + let mut dictionary_tracker = + DictionaryTracker::new_with_preserve_dict_id(true, preserve_dict_id); let schema = IpcSchemaEncoder::new() - .with_dictionary_tracker(&mut self.dictionary_tracker) + .with_dictionary_tracker(&mut dictionary_tracker) .schema_to_fb_offset(&mut fbb, &self.schema); let fb_custom_metadata = (!self.custom_metadata.is_empty()) .then(|| crate::convert::metadata_to_fb(&mut fbb, &self.custom_metadata)); @@ -1140,7 +1182,9 @@ impl StreamWriter { write_options: IpcWriteOptions, ) -> Result { let data_gen = IpcDataGenerator::default(); + #[allow(deprecated)] let preserve_dict_id = write_options.preserve_dict_id; + #[allow(deprecated)] let mut dictionary_tracker = DictionaryTracker::new_with_preserve_dict_id(false, preserve_dict_id); @@ -1427,7 +1471,13 @@ fn reencode_offsets( let end_offset = offset_slice.last().unwrap(); let offsets = match start_offset.as_usize() { - 0 => offsets.clone(), + 0 => { + let size = size_of::(); + offsets.slice_with_length( + data.offset() * size, + 
(data.offset() + data.len() + 1) * size, + ) + } _ => offset_slice.iter().map(|x| *x - *start_offset).collect(), }; @@ -2022,6 +2072,7 @@ mod tests { let array = Arc::new(inner) as ArrayRef; // Dict field with id 2 + #[allow(deprecated)] let dctfield = Field::new_dict("dict", array.data_type().clone(), false, 2, false); let union_fields = [(0, Arc::new(dctfield))].into_iter().collect(); @@ -2039,6 +2090,7 @@ mod tests { let batch = RecordBatch::try_new(schema, vec![Arc::new(union)]).unwrap(); let gen = IpcDataGenerator {}; + #[allow(deprecated)] let mut dict_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true); gen.encoded_batch(&batch, &mut dict_tracker, &Default::default()) .unwrap(); @@ -2055,6 +2107,7 @@ mod tests { let array = Arc::new(inner) as ArrayRef; // Dict field with id 2 + #[allow(deprecated)] let dctfield = Arc::new(Field::new_dict( "dict", array.data_type().clone(), @@ -2075,6 +2128,7 @@ mod tests { let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap(); let gen = IpcDataGenerator {}; + #[allow(deprecated)] let mut dict_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true); gen.encoded_batch(&batch, &mut dict_tracker, &Default::default()) .unwrap(); @@ -2514,6 +2568,36 @@ mod tests { ls.finish() } + fn generate_nested_list_data_starting_at_zero() -> GenericListArray { + let mut ls = + GenericListBuilder::::new(GenericListBuilder::::new(UInt32Builder::new())); + + for _i in 0..999 { + ls.values().append(true); + ls.append(true); + } + + for j in 0..10 { + for value in [j, j, j, j] { + ls.values().values().append_value(value); + } + ls.values().append(true) + } + ls.append(true); + + for i in 0..9_000 { + for j in 0..10 { + for value in [i + j, i + j, i + j, i + j] { + ls.values().values().append_value(value); + } + ls.values().append(true) + } + ls.append(true); + } + + ls.finish() + } + fn generate_map_array_data() -> MapArray { let keys_builder = UInt32Builder::new(); let values_builder = 
UInt32Builder::new(); @@ -2553,7 +2637,7 @@ mod tests { #[test] fn encode_lists() { - let val_inner = Field::new("item", DataType::UInt32, true); + let val_inner = Field::new_list_field(DataType::UInt32, true); let val_list_field = Field::new("val", DataType::List(Arc::new(val_inner)), false); let schema = Arc::new(Schema::new(vec![val_list_field])); @@ -2565,7 +2649,7 @@ mod tests { #[test] fn encode_empty_list() { - let val_inner = Field::new("item", DataType::UInt32, true); + let val_inner = Field::new_list_field(DataType::UInt32, true); let val_list_field = Field::new("val", DataType::List(Arc::new(val_inner)), false); let schema = Arc::new(Schema::new(vec![val_list_field])); @@ -2580,7 +2664,7 @@ mod tests { #[test] fn encode_large_lists() { - let val_inner = Field::new("item", DataType::UInt32, true); + let val_inner = Field::new_list_field(DataType::UInt32, true); let val_list_field = Field::new("val", DataType::LargeList(Arc::new(val_inner)), false); let schema = Arc::new(Schema::new(vec![val_list_field])); @@ -2594,8 +2678,8 @@ mod tests { #[test] fn encode_nested_lists() { - let inner_int = Arc::new(Field::new("item", DataType::UInt32, true)); - let inner_list_field = Arc::new(Field::new("item", DataType::List(inner_int), true)); + let inner_int = Arc::new(Field::new_list_field(DataType::UInt32, true)); + let inner_list_field = Arc::new(Field::new_list_field(DataType::List(inner_int), true)); let list_field = Field::new("val", DataType::List(inner_list_field), true); let schema = Arc::new(Schema::new(vec![list_field])); @@ -2605,6 +2689,19 @@ mod tests { roundtrip_ensure_sliced_smaller(in_batch, 1000); } + #[test] + fn encode_nested_lists_starting_at_zero() { + let inner_int = Arc::new(Field::new("item", DataType::UInt32, true)); + let inner_list_field = Arc::new(Field::new("item", DataType::List(inner_int), true)); + let list_field = Field::new("val", DataType::List(inner_list_field), true); + let schema = Arc::new(Schema::new(vec![list_field])); + + let 
values = Arc::new(generate_nested_list_data_starting_at_zero::()); + + let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap(); + roundtrip_ensure_sliced_smaller(in_batch, 1); + } + #[test] fn encode_map_array() { let keys = Arc::new(Field::new("keys", DataType::UInt32, false)); diff --git a/arrow-json/Cargo.toml b/arrow-json/Cargo.toml index 517bb03d2064..564cb9433b3d 100644 --- a/arrow-json/Cargo.toml +++ b/arrow-json/Cargo.toml @@ -48,7 +48,6 @@ chrono = { workspace = true } lexical-core = { version = "1.0", default-features = false} [dev-dependencies] -tempfile = "3.3" flate2 = { version = "1", default-features = false, features = ["rust_backend"] } serde = { version = "1.0", default-features = false, features = ["derive"] } futures = "0.3" diff --git a/arrow-json/LICENSE.txt b/arrow-json/LICENSE.txt new file mode 120000 index 000000000000..4ab43736a839 --- /dev/null +++ b/arrow-json/LICENSE.txt @@ -0,0 +1 @@ +../LICENSE.txt \ No newline at end of file diff --git a/arrow-json/NOTICE.txt b/arrow-json/NOTICE.txt new file mode 120000 index 000000000000..eb9f24e040b5 --- /dev/null +++ b/arrow-json/NOTICE.txt @@ -0,0 +1 @@ +../NOTICE.txt \ No newline at end of file diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index bcacf6f706b8..f857e8813c7e 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -244,13 +244,6 @@ impl ReaderBuilder { Self { batch_size, ..self } } - /// Sets if the decoder should coerce primitive values (bool and number) into string - /// when the Schema's column is Utf8 or LargeUtf8. - #[deprecated(note = "Use with_coerce_primitive")] - pub fn coerce_primitive(self, coerce_primitive: bool) -> Self { - self.with_coerce_primitive(coerce_primitive) - } - /// Sets if the decoder should coerce primitive values (bool and number) into string /// when the Schema's column is Utf8 or LargeUtf8. 
pub fn with_coerce_primitive(self, coerce_primitive: bool) -> Self { @@ -691,6 +684,10 @@ fn make_decoder( DataType::Time32(TimeUnit::Millisecond) => primitive_decoder!(Time32MillisecondType, data_type), DataType::Time64(TimeUnit::Microsecond) => primitive_decoder!(Time64MicrosecondType, data_type), DataType::Time64(TimeUnit::Nanosecond) => primitive_decoder!(Time64NanosecondType, data_type), + DataType::Duration(TimeUnit::Nanosecond) => primitive_decoder!(DurationNanosecondType, data_type), + DataType::Duration(TimeUnit::Microsecond) => primitive_decoder!(DurationMicrosecondType, data_type), + DataType::Duration(TimeUnit::Millisecond) => primitive_decoder!(DurationMillisecondType, data_type), + DataType::Duration(TimeUnit::Second) => primitive_decoder!(DurationSecondType, data_type), DataType::Decimal128(p, s) => Ok(Box::new(DecimalArrayDecoder::::new(p, s))), DataType::Decimal256(p, s) => Ok(Box::new(DecimalArrayDecoder::::new(p, s))), DataType::Boolean => Ok(Box::::default()), @@ -1330,6 +1327,37 @@ mod tests { test_time::(); } + fn test_duration() { + let buf = r#" + {"a": 1, "b": "2"} + {"a": 3, "b": null} + "#; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", T::DATA_TYPE, true), + Field::new("b", T::DATA_TYPE, true), + ])); + + let batches = do_read(buf, 1024, true, false, schema); + assert_eq!(batches.len(), 1); + + let col_a = batches[0].column_by_name("a").unwrap().as_primitive::(); + assert_eq!(col_a.null_count(), 0); + assert_eq!(col_a.values(), &[1, 3].map(T::Native::usize_as)); + + let col2 = batches[0].column_by_name("b").unwrap().as_primitive::(); + assert_eq!(col2.null_count(), 1); + assert_eq!(col2.values(), &[2, 0].map(T::Native::usize_as)); + } + + #[test] + fn test_durations() { + test_duration::(); + test_duration::(); + test_duration::(); + test_duration::(); + } + #[test] fn test_delta_checkpoint() { let json = "{\"protocol\":{\"minReaderVersion\":1,\"minWriterVersion\":2}}"; @@ -1726,12 +1754,12 @@ mod tests { 
assert_eq!(&DataType::Int64, a.1.data_type()); let b = schema.column_with_name("b").unwrap(); assert_eq!( - &DataType::List(Arc::new(Field::new("item", DataType::Float64, true))), + &DataType::List(Arc::new(Field::new_list_field(DataType::Float64, true))), b.1.data_type() ); let c = schema.column_with_name("c").unwrap(); assert_eq!( - &DataType::List(Arc::new(Field::new("item", DataType::Boolean, true))), + &DataType::List(Arc::new(Field::new_list_field(DataType::Boolean, true))), c.1.data_type() ); let d = schema.column_with_name("d").unwrap(); @@ -1770,7 +1798,7 @@ mod tests { let schema = Arc::new(Schema::new(vec![Field::new( "items", - DataType::List(FieldRef::new(Field::new("item", DataType::Null, true))), + DataType::List(FieldRef::new(Field::new_list_field(DataType::Null, true))), true, )])); @@ -1794,9 +1822,8 @@ mod tests { let schema = Arc::new(Schema::new(vec![Field::new( "items", - DataType::List(FieldRef::new(Field::new( - "item", - DataType::List(FieldRef::new(Field::new("item", DataType::Null, true))), + DataType::List(FieldRef::new(Field::new_list_field( + DataType::List(FieldRef::new(Field::new_list_field(DataType::Null, true))), true, ))), true, diff --git a/arrow-json/src/reader/schema.rs b/arrow-json/src/reader/schema.rs index ace7b0ea5cb6..07eb40106de0 100644 --- a/arrow-json/src/reader/schema.rs +++ b/arrow-json/src/reader/schema.rs @@ -77,7 +77,7 @@ impl InferredType { /// Shorthand for building list data type of `ty` fn list_type_of(ty: DataType) -> DataType { - DataType::List(Arc::new(Field::new("item", ty, true))) + DataType::List(Arc::new(Field::new_list_field(ty, true))) } /// Coerce data type during inference diff --git a/arrow-json/src/writer/mod.rs b/arrow-json/src/writer/mod.rs index a37aa5ff8c2c..ee6d83a0a1f0 100644 --- a/arrow-json/src/writer/mod.rs +++ b/arrow-json/src/writer/mod.rs @@ -1771,7 +1771,7 @@ mod tests { #[test] fn test_writer_fixed_size_list() { let size = 3; - let field = FieldRef::new(Field::new("item", 
DataType::Int32, true)); + let field = FieldRef::new(Field::new_list_field(DataType::Int32, true)); let schema = SchemaRef::new(Schema::new(vec![Field::new( "list", DataType::FixedSizeList(field, size), diff --git a/arrow-ord/Cargo.toml b/arrow-ord/Cargo.toml index c9c30074fe6e..8d74d2f97d72 100644 --- a/arrow-ord/Cargo.toml +++ b/arrow-ord/Cargo.toml @@ -39,8 +39,7 @@ arrow-buffer = { workspace = true } arrow-data = { workspace = true } arrow-schema = { workspace = true } arrow-select = { workspace = true } -num = { version = "0.4", default-features = false, features = ["std"] } -half = { version = "2.1", default-features = false, features = ["num-traits"] } [dev-dependencies] +half = { version = "2.1", default-features = false, features = ["num-traits"] } rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } diff --git a/arrow-ord/LICENSE.txt b/arrow-ord/LICENSE.txt new file mode 120000 index 000000000000..4ab43736a839 --- /dev/null +++ b/arrow-ord/LICENSE.txt @@ -0,0 +1 @@ +../LICENSE.txt \ No newline at end of file diff --git a/arrow-ord/NOTICE.txt b/arrow-ord/NOTICE.txt new file mode 120000 index 000000000000..eb9f24e040b5 --- /dev/null +++ b/arrow-ord/NOTICE.txt @@ -0,0 +1 @@ +../NOTICE.txt \ No newline at end of file diff --git a/arrow-ord/src/cmp.rs b/arrow-ord/src/cmp.rs index f571e26c444c..2727ff996150 100644 --- a/arrow-ord/src/cmp.rs +++ b/arrow-ord/src/cmp.rs @@ -656,7 +656,10 @@ pub fn compare_byte_view( /// /// # Safety /// The left/right_idx must within range of each array -#[deprecated(note = "Use `GenericByteViewArray::compare_unchecked` instead")] +#[deprecated( + since = "52.2.0", + note = "Use `GenericByteViewArray::compare_unchecked` instead" +)] pub unsafe fn compare_byte_view_unchecked( left: &GenericByteViewArray, left_idx: usize, diff --git a/arrow-ord/src/comparison.rs b/arrow-ord/src/comparison.rs index d60bc3b8de88..bb82f54d4918 100644 --- a/arrow-ord/src/comparison.rs +++ b/arrow-ord/src/comparison.rs @@ 
-821,7 +821,7 @@ mod tests { .into_data(); let value_offsets = Buffer::from_slice_ref([0i64, 3, 6, 6, 9]); let list_data_type = - DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, true))); + DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, true))); let list_data = ArrayData::builder(list_data_type) .len(4) .add_buffer(value_offsets) diff --git a/arrow-ord/src/ord.rs b/arrow-ord/src/ord.rs index 6430c8f0e405..55e397cd8aa4 100644 --- a/arrow-ord/src/ord.rs +++ b/arrow-ord/src/ord.rs @@ -265,7 +265,7 @@ fn compare_struct( Ok(f) } -#[deprecated(note = "Use make_comparator")] +#[deprecated(since = "52.0.0", note = "Use make_comparator")] #[doc(hidden)] pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result { make_comparator(left, right, SortOptions::default()) @@ -394,7 +394,7 @@ pub fn make_comparator( } #[cfg(test)] -pub mod tests { +mod tests { use super::*; use arrow_array::builder::{Int32Builder, ListBuilder}; use arrow_buffer::{i256, IntervalDayTime, OffsetBuffer}; @@ -849,7 +849,7 @@ pub mod tests { fn test_struct() { let fields = Fields::from(vec![ Field::new("a", DataType::Int32, true), - Field::new_list("b", Field::new("item", DataType::Int32, true), true), + Field::new_list("b", Field::new_list_field(DataType::Int32, true), true), ]); let a = Int32Array::from(vec![Some(1), Some(2), None, None]); diff --git a/arrow-ord/src/partition.rs b/arrow-ord/src/partition.rs index 8c87eefadbf0..ec1647393239 100644 --- a/arrow-ord/src/partition.rs +++ b/arrow-ord/src/partition.rs @@ -24,7 +24,6 @@ use arrow_buffer::BooleanBuffer; use arrow_schema::ArrowError; use crate::cmp::distinct; -use crate::sort::SortColumn; /// A computed set of partitions, see [`partition`] #[derive(Debug, Clone)] @@ -160,21 +159,6 @@ fn find_boundaries(v: &dyn Array) -> Result { Ok(distinct(&v1, &v2)?.values().clone()) } -/// Use [`partition`] instead. 
Given a list of already sorted columns, find -/// partition ranges that would partition lexicographically equal values across -/// columns. -/// -/// The returned vec would be of size k where k is cardinality of the sorted values; Consecutive -/// values will be connected: (a, b) and (b, c), where start = 0 and end = n for the first and last -/// range. -#[deprecated(note = "Use partition")] -pub fn lexicographical_partition_ranges( - columns: &[SortColumn], -) -> Result> + '_, ArrowError> { - let cols: Vec<_> = columns.iter().map(|x| x.values.clone()).collect(); - Ok(partition(&cols)?.ranges().into_iter()) -} - #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/arrow-ord/src/rank.rs b/arrow-ord/src/rank.rs index ecc693bab4e4..e61cebef38ec 100644 --- a/arrow-ord/src/rank.rs +++ b/arrow-ord/src/rank.rs @@ -24,6 +24,15 @@ use arrow_buffer::NullBuffer; use arrow_schema::{ArrowError, DataType, SortOptions}; use std::cmp::Ordering; +/// Whether `arrow_ord::rank` can rank an array of given data type. +pub(crate) fn can_rank(data_type: &DataType) -> bool { + data_type.is_primitive() + || matches!( + data_type, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary + ) +} + /// Assigns a rank to each value in `array` based on its position in the sorted order /// /// Where values are equal, they will be assigned the highest of their ranks, diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs index 60fc4a918525..51a6659e631b 100644 --- a/arrow-ord/src/sort.rs +++ b/arrow-ord/src/sort.rs @@ -30,7 +30,7 @@ use arrow_select::take::take; use std::cmp::Ordering; use std::sync::Arc; -use crate::rank::rank; +use crate::rank::{can_rank, rank}; pub use arrow_schema::SortOptions; /// Sort the `ArrayRef` using `SortOptions`. @@ -190,15 +190,6 @@ fn partition_validity(array: &dyn Array) -> (Vec, Vec) { } } -/// Whether `arrow_ord::rank` can rank an array of given data type. 
-fn can_rank(data_type: &DataType) -> bool { - data_type.is_primitive() - || matches!( - data_type, - DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary - ) -} - /// Whether `sort_to_indices` can sort an array of given data type. fn can_sort_to_indices(data_type: &DataType) -> bool { data_type.is_primitive() diff --git a/arrow-pyarrow-integration-testing/Cargo.toml b/arrow-pyarrow-integration-testing/Cargo.toml index 0834f2d13384..03d08df30959 100644 --- a/arrow-pyarrow-integration-testing/Cargo.toml +++ b/arrow-pyarrow-integration-testing/Cargo.toml @@ -34,4 +34,4 @@ crate-type = ["cdylib"] [dependencies] arrow = { path = "../arrow", features = ["pyarrow"] } -pyo3 = { version = "0.22", features = ["extension-module"] } +pyo3 = { version = "0.23", features = ["extension-module"] } diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs index e12c1389e66f..d4908fff0897 100644 --- a/arrow-pyarrow-integration-testing/src/lib.rs +++ b/arrow-pyarrow-integration-testing/src/lib.rs @@ -43,7 +43,7 @@ fn to_py_err(err: ArrowError) -> PyErr { #[pyfunction] fn double(array: &Bound, py: Python) -> PyResult { // import - let array = make_array(ArrayData::from_pyarrow_bound(&array)?); + let array = make_array(ArrayData::from_pyarrow_bound(array)?); // perform some operation let array = array diff --git a/arrow-row/Cargo.toml b/arrow-row/Cargo.toml index 3754afb4dbc6..90d99684d265 100644 --- a/arrow-row/Cargo.toml +++ b/arrow-row/Cargo.toml @@ -33,12 +33,6 @@ name = "arrow_row" path = "src/lib.rs" bench = false -[target.'cfg(target_arch = "wasm32")'.dependencies] -ahash = { version = "0.8", default-features = false, features = ["compile-time-rng"] } - -[target.'cfg(not(target_arch = "wasm32"))'.dependencies] -ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } - [dependencies] arrow-array = { workspace = true } arrow-buffer = { workspace = true } diff --git 
a/arrow-row/LICENSE.txt b/arrow-row/LICENSE.txt new file mode 120000 index 000000000000..4ab43736a839 --- /dev/null +++ b/arrow-row/LICENSE.txt @@ -0,0 +1 @@ +../LICENSE.txt \ No newline at end of file diff --git a/arrow-row/NOTICE.txt b/arrow-row/NOTICE.txt new file mode 120000 index 000000000000..eb9f24e040b5 --- /dev/null +++ b/arrow-row/NOTICE.txt @@ -0,0 +1 @@ +../NOTICE.txt \ No newline at end of file diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index 5780bdbfefb9..d0fad12210db 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -2317,7 +2317,7 @@ mod tests { let values_len = offsets.last().unwrap().to_usize().unwrap(); let values = values(values_len); let nulls = NullBuffer::from_iter((0..len).map(|_| rng.gen_bool(valid_percent))); - let field = Arc::new(Field::new("item", values.data_type().clone(), true)); + let field = Arc::new(Field::new_list_field(values.data_type().clone(), true)); ListArray::new(field, offsets, values, Some(nulls)) } diff --git a/arrow-schema/Cargo.toml b/arrow-schema/Cargo.toml index 628d4a683cac..1e1f9fbde0e4 100644 --- a/arrow-schema/Cargo.toml +++ b/arrow-schema/Cargo.toml @@ -47,3 +47,8 @@ features = ["ffi"] [dev-dependencies] serde_json = "1.0" bincode = { version = "1.3.3", default-features = false } +criterion = { version = "0.5", default-features = false } + +[[bench]] +name = "ffi" +harness = false \ No newline at end of file diff --git a/arrow-schema/LICENSE.txt b/arrow-schema/LICENSE.txt new file mode 120000 index 000000000000..4ab43736a839 --- /dev/null +++ b/arrow-schema/LICENSE.txt @@ -0,0 +1 @@ +../LICENSE.txt \ No newline at end of file diff --git a/arrow-schema/NOTICE.txt b/arrow-schema/NOTICE.txt new file mode 120000 index 000000000000..eb9f24e040b5 --- /dev/null +++ b/arrow-schema/NOTICE.txt @@ -0,0 +1 @@ +../NOTICE.txt \ No newline at end of file diff --git a/arrow-schema/benches/ffi.rs b/arrow-schema/benches/ffi.rs new file mode 100644 index 000000000000..1285acb883ea --- /dev/null +++ 
b/arrow-schema/benches/ffi.rs @@ -0,0 +1,38 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::ffi::FFI_ArrowSchema; +use arrow_schema::{DataType, Field}; +use criterion::*; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let fields = vec![ + Arc::new(Field::new("c1", DataType::Utf8, false)), + Arc::new(Field::new("c2", DataType::Utf8, false)), + Arc::new(Field::new("c3", DataType::Utf8, false)), + Arc::new(Field::new("c4", DataType::Utf8, false)), + Arc::new(Field::new("c5", DataType::Utf8, false)), + ]; + let data_type = DataType::Struct(fields.into()); + c.bench_function("ffi_arrow_schema_try_from", |b| { + b.iter(|| FFI_ArrowSchema::try_from(&data_type)); + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs index ff5832dfa68c..7cd53b13c73e 100644 --- a/arrow-schema/src/datatype.rs +++ b/arrow-schema/src/datatype.rs @@ -40,7 +40,7 @@ use crate::{ArrowError, Field, FieldRef, Fields, UnionFields}; /// # use arrow_schema::{DataType, Field}; /// # use std::sync::Arc; /// // create a new list of 32-bit signed integers directly -/// let list_data_type = 
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); +/// let list_data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); /// // Create the same list type with constructor /// let list_data_type2 = DataType::new_list(DataType::Int32, true); /// assert_eq!(list_data_type, list_data_type2); @@ -196,6 +196,14 @@ pub enum DataType { /// DataType::Timestamp(TimeUnit::Second, Some("literal".into())); /// DataType::Timestamp(TimeUnit::Second, Some("string".to_string().into())); /// ``` + /// + /// Timezone string parsing + /// ----------------------- + /// When feature `chrono-tz` is not enabled, allowed timezone strings are fixed offsets of the form "+09:00", "-09" or "+0930". + /// + /// When feature `chrono-tz` is enabled, additional strings supported by [chrono_tz](https://docs.rs/chrono-tz/latest/chrono_tz/) + /// are also allowed, which include [IANA database](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones) + /// timezones. Timestamp(TimeUnit, Option>), /// A signed 32-bit date representing the elapsed time since UNIX epoch (1970-01-01) /// in days. 
@@ -837,21 +845,21 @@ mod tests { #[test] fn test_list_datatype_equality() { // tests that list type equality is checked while ignoring list names - let list_a = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let list_a = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); let list_b = DataType::List(Arc::new(Field::new("array", DataType::Int32, true))); - let list_c = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); - let list_d = DataType::List(Arc::new(Field::new("item", DataType::UInt32, true))); + let list_c = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))); + let list_d = DataType::List(Arc::new(Field::new_list_field(DataType::UInt32, true))); assert!(list_a.equals_datatype(&list_b)); assert!(!list_a.equals_datatype(&list_c)); assert!(!list_b.equals_datatype(&list_c)); assert!(!list_a.equals_datatype(&list_d)); let list_e = - DataType::FixedSizeList(Arc::new(Field::new("item", list_a.clone(), false)), 3); + DataType::FixedSizeList(Arc::new(Field::new_list_field(list_a.clone(), false)), 3); let list_f = DataType::FixedSizeList(Arc::new(Field::new("array", list_b.clone(), false)), 3); let list_g = DataType::FixedSizeList( - Arc::new(Field::new("item", DataType::FixedSizeBinary(3), true)), + Arc::new(Field::new_list_field(DataType::FixedSizeBinary(3), true)), 3, ); assert!(list_e.equals_datatype(&list_f)); diff --git a/arrow-schema/src/datatype_parse.rs b/arrow-schema/src/datatype_parse.rs index 4378950329f3..bf557d8941dc 100644 --- a/arrow-schema/src/datatype_parse.rs +++ b/arrow-schema/src/datatype_parse.rs @@ -90,8 +90,8 @@ impl<'a> Parser<'a> { self.expect_token(Token::LParen)?; let data_type = self.parse_next_type()?; self.expect_token(Token::RParen)?; - Ok(DataType::List(Arc::new(Field::new( - "item", data_type, true, + Ok(DataType::List(Arc::new(Field::new_list_field( + data_type, true, )))) } @@ -100,8 +100,8 @@ impl<'a> Parser<'a> { 
self.expect_token(Token::LParen)?; let data_type = self.parse_next_type()?; self.expect_token(Token::RParen)?; - Ok(DataType::LargeList(Arc::new(Field::new( - "item", data_type, true, + Ok(DataType::LargeList(Arc::new(Field::new_list_field( + data_type, true, )))) } @@ -113,7 +113,7 @@ impl<'a> Parser<'a> { let data_type = self.parse_next_type()?; self.expect_token(Token::RParen)?; Ok(DataType::FixedSizeList( - Arc::new(Field::new("item", data_type, true)), + Arc::new(Field::new_list_field(data_type, true)), length, )) } diff --git a/arrow-schema/src/ffi.rs b/arrow-schema/src/ffi.rs index 70650d769cf6..96c80974982c 100644 --- a/arrow-schema/src/ffi.rs +++ b/arrow-schema/src/ffi.rs @@ -38,6 +38,7 @@ use crate::{ ArrowError, DataType, Field, FieldRef, IntervalUnit, Schema, TimeUnit, UnionFields, UnionMode, }; use bitflags::bitflags; +use std::borrow::Cow; use std::sync::Arc; use std::{ collections::HashMap, @@ -685,57 +686,59 @@ impl TryFrom<&DataType> for FFI_ArrowSchema { } } -fn get_format_string(dtype: &DataType) -> Result { +fn get_format_string(dtype: &DataType) -> Result, ArrowError> { match dtype { - DataType::Null => Ok("n".to_string()), - DataType::Boolean => Ok("b".to_string()), - DataType::Int8 => Ok("c".to_string()), - DataType::UInt8 => Ok("C".to_string()), - DataType::Int16 => Ok("s".to_string()), - DataType::UInt16 => Ok("S".to_string()), - DataType::Int32 => Ok("i".to_string()), - DataType::UInt32 => Ok("I".to_string()), - DataType::Int64 => Ok("l".to_string()), - DataType::UInt64 => Ok("L".to_string()), - DataType::Float16 => Ok("e".to_string()), - DataType::Float32 => Ok("f".to_string()), - DataType::Float64 => Ok("g".to_string()), - DataType::BinaryView => Ok("vz".to_string()), - DataType::Binary => Ok("z".to_string()), - DataType::LargeBinary => Ok("Z".to_string()), - DataType::Utf8View => Ok("vu".to_string()), - DataType::Utf8 => Ok("u".to_string()), - DataType::LargeUtf8 => Ok("U".to_string()), - DataType::FixedSizeBinary(num_bytes) => 
Ok(format!("w:{num_bytes}")), - DataType::FixedSizeList(_, num_elems) => Ok(format!("+w:{num_elems}")), - DataType::Decimal128(precision, scale) => Ok(format!("d:{precision},{scale}")), - DataType::Decimal256(precision, scale) => Ok(format!("d:{precision},{scale},256")), - DataType::Date32 => Ok("tdD".to_string()), - DataType::Date64 => Ok("tdm".to_string()), - DataType::Time32(TimeUnit::Second) => Ok("tts".to_string()), - DataType::Time32(TimeUnit::Millisecond) => Ok("ttm".to_string()), - DataType::Time64(TimeUnit::Microsecond) => Ok("ttu".to_string()), - DataType::Time64(TimeUnit::Nanosecond) => Ok("ttn".to_string()), - DataType::Timestamp(TimeUnit::Second, None) => Ok("tss:".to_string()), - DataType::Timestamp(TimeUnit::Millisecond, None) => Ok("tsm:".to_string()), - DataType::Timestamp(TimeUnit::Microsecond, None) => Ok("tsu:".to_string()), - DataType::Timestamp(TimeUnit::Nanosecond, None) => Ok("tsn:".to_string()), - DataType::Timestamp(TimeUnit::Second, Some(tz)) => Ok(format!("tss:{tz}")), - DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => Ok(format!("tsm:{tz}")), - DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => Ok(format!("tsu:{tz}")), - DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => Ok(format!("tsn:{tz}")), - DataType::Duration(TimeUnit::Second) => Ok("tDs".to_string()), - DataType::Duration(TimeUnit::Millisecond) => Ok("tDm".to_string()), - DataType::Duration(TimeUnit::Microsecond) => Ok("tDu".to_string()), - DataType::Duration(TimeUnit::Nanosecond) => Ok("tDn".to_string()), - DataType::Interval(IntervalUnit::YearMonth) => Ok("tiM".to_string()), - DataType::Interval(IntervalUnit::DayTime) => Ok("tiD".to_string()), - DataType::Interval(IntervalUnit::MonthDayNano) => Ok("tin".to_string()), - DataType::List(_) => Ok("+l".to_string()), - DataType::LargeList(_) => Ok("+L".to_string()), - DataType::Struct(_) => Ok("+s".to_string()), - DataType::Map(_, _) => Ok("+m".to_string()), - DataType::RunEndEncoded(_, _) => Ok("+r".to_string()), + 
DataType::Null => Ok("n".into()), + DataType::Boolean => Ok("b".into()), + DataType::Int8 => Ok("c".into()), + DataType::UInt8 => Ok("C".into()), + DataType::Int16 => Ok("s".into()), + DataType::UInt16 => Ok("S".into()), + DataType::Int32 => Ok("i".into()), + DataType::UInt32 => Ok("I".into()), + DataType::Int64 => Ok("l".into()), + DataType::UInt64 => Ok("L".into()), + DataType::Float16 => Ok("e".into()), + DataType::Float32 => Ok("f".into()), + DataType::Float64 => Ok("g".into()), + DataType::BinaryView => Ok("vz".into()), + DataType::Binary => Ok("z".into()), + DataType::LargeBinary => Ok("Z".into()), + DataType::Utf8View => Ok("vu".into()), + DataType::Utf8 => Ok("u".into()), + DataType::LargeUtf8 => Ok("U".into()), + DataType::FixedSizeBinary(num_bytes) => Ok(Cow::Owned(format!("w:{num_bytes}"))), + DataType::FixedSizeList(_, num_elems) => Ok(Cow::Owned(format!("+w:{num_elems}"))), + DataType::Decimal128(precision, scale) => Ok(Cow::Owned(format!("d:{precision},{scale}"))), + DataType::Decimal256(precision, scale) => { + Ok(Cow::Owned(format!("d:{precision},{scale},256"))) + } + DataType::Date32 => Ok("tdD".into()), + DataType::Date64 => Ok("tdm".into()), + DataType::Time32(TimeUnit::Second) => Ok("tts".into()), + DataType::Time32(TimeUnit::Millisecond) => Ok("ttm".into()), + DataType::Time64(TimeUnit::Microsecond) => Ok("ttu".into()), + DataType::Time64(TimeUnit::Nanosecond) => Ok("ttn".into()), + DataType::Timestamp(TimeUnit::Second, None) => Ok("tss:".into()), + DataType::Timestamp(TimeUnit::Millisecond, None) => Ok("tsm:".into()), + DataType::Timestamp(TimeUnit::Microsecond, None) => Ok("tsu:".into()), + DataType::Timestamp(TimeUnit::Nanosecond, None) => Ok("tsn:".into()), + DataType::Timestamp(TimeUnit::Second, Some(tz)) => Ok(Cow::Owned(format!("tss:{tz}"))), + DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => Ok(Cow::Owned(format!("tsm:{tz}"))), + DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => Ok(Cow::Owned(format!("tsu:{tz}"))), + 
DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => Ok(Cow::Owned(format!("tsn:{tz}"))), + DataType::Duration(TimeUnit::Second) => Ok("tDs".into()), + DataType::Duration(TimeUnit::Millisecond) => Ok("tDm".into()), + DataType::Duration(TimeUnit::Microsecond) => Ok("tDu".into()), + DataType::Duration(TimeUnit::Nanosecond) => Ok("tDn".into()), + DataType::Interval(IntervalUnit::YearMonth) => Ok("tiM".into()), + DataType::Interval(IntervalUnit::DayTime) => Ok("tiD".into()), + DataType::Interval(IntervalUnit::MonthDayNano) => Ok("tin".into()), + DataType::List(_) => Ok("+l".into()), + DataType::LargeList(_) => Ok("+L".into()), + DataType::Struct(_) => Ok("+s".into()), + DataType::Map(_, _) => Ok("+m".into()), + DataType::RunEndEncoded(_, _) => Ok("+r".into()), DataType::Dictionary(key_data_type, _) => get_format_string(key_data_type), DataType::Union(fields, mode) => { let formats = fields @@ -743,8 +746,8 @@ fn get_format_string(dtype: &DataType) -> Result { .map(|(t, _)| t.to_string()) .collect::>(); match mode { - UnionMode::Dense => Ok(format!("{}:{}", "+ud", formats.join(","))), - UnionMode::Sparse => Ok(format!("{}:{}", "+us", formats.join(","))), + UnionMode::Dense => Ok(Cow::Owned(format!("{}:{}", "+ud", formats.join(",")))), + UnionMode::Sparse => Ok(Cow::Owned(format!("{}:{}", "+us", formats.join(",")))), } } other => Err(ArrowError::CDataInterface(format!( @@ -920,6 +923,7 @@ mod tests { #[test] fn test_dictionary_ordered() { + #[allow(deprecated)] let schema = Schema::new(vec![Field::new_dict( "dict", DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), diff --git a/arrow-schema/src/field.rs b/arrow-schema/src/field.rs index b532ea8616b6..13bb7abf51b4 100644 --- a/arrow-schema/src/field.rs +++ b/arrow-schema/src/field.rs @@ -38,6 +38,10 @@ pub struct Field { name: String, data_type: DataType, nullable: bool, + #[deprecated( + since = "54.0.0", + note = "The ability to preserve dictionary IDs will be removed. 
With it, all fields related to it." + )] dict_id: i64, dict_is_ordered: bool, /// A map of key-value pairs containing additional custom meta data. @@ -117,8 +121,12 @@ impl Hash for Field { } impl Field { + /// Default list member field name + pub const LIST_FIELD_DEFAULT_NAME: &'static str = "item"; + /// Creates a new field with the given name, type, and nullability pub fn new(name: impl Into, data_type: DataType, nullable: bool) -> Self { + #[allow(deprecated)] Field { name: name.into(), data_type, @@ -144,10 +152,14 @@ impl Field { /// ); /// ``` pub fn new_list_field(data_type: DataType, nullable: bool) -> Self { - Self::new("item", data_type, nullable) + Self::new(Self::LIST_FIELD_DEFAULT_NAME, data_type, nullable) } /// Creates a new field that has additional dictionary information + #[deprecated( + since = "54.0.0", + note = "The ability to preserve dictionary IDs will be removed. With the dict_id field disappearing this function signature will change by removing the dict_id parameter." + )] pub fn new_dict( name: impl Into, data_type: DataType, @@ -155,6 +167,7 @@ impl Field { dict_id: i64, dict_is_ordered: bool, ) -> Self { + #[allow(deprecated)] Field { name: name.into(), data_type, @@ -383,25 +396,49 @@ impl Field { /// Returns a vector containing all (potentially nested) `Field` instances selected by the /// dictionary ID they use #[inline] + #[deprecated( + since = "54.0.0", + note = "The ability to preserve dictionary IDs will be removed. With it, all fields related to it." + )] pub(crate) fn fields_with_dict_id(&self, id: i64) -> Vec<&Field> { self.fields() .into_iter() .filter(|&field| { - matches!(field.data_type(), DataType::Dictionary(_, _)) && field.dict_id == id + #[allow(deprecated)] + let matching_dict_id = field.dict_id == id; + matches!(field.data_type(), DataType::Dictionary(_, _)) && matching_dict_id }) .collect() } /// Returns the dictionary ID, if this is a dictionary type. 
#[inline] + #[deprecated( + since = "54.0.0", + note = "The ability to preserve dictionary IDs will be removed. With it, all fields related to it." + )] pub const fn dict_id(&self) -> Option { match self.data_type { + #[allow(deprecated)] DataType::Dictionary(_, _) => Some(self.dict_id), _ => None, } } /// Returns whether this `Field`'s dictionary is ordered, if this is a dictionary type. + /// + /// # Example + /// ``` + /// # use arrow_schema::{DataType, Field}; + /// // non dictionaries do not have a dict is ordered flat + /// let field = Field::new("c1", DataType::Int64, false); + /// assert_eq!(field.dict_is_ordered(), None); + /// // by default dictionary is not ordered + /// let field = Field::new("c1", DataType::Dictionary(Box::new(DataType::Int64), Box::new(DataType::Utf8)), false); + /// assert_eq!(field.dict_is_ordered(), Some(false)); + /// let field = field.with_dict_is_ordered(true); + /// assert_eq!(field.dict_is_ordered(), Some(true)); + /// ``` #[inline] pub const fn dict_is_ordered(&self) -> Option { match self.data_type { @@ -410,6 +447,18 @@ impl Field { } } + /// Set the is ordered field for this `Field`, if it is a dictionary. + /// + /// Does nothing if this is not a dictionary type. + /// + /// See [`Field::dict_is_ordered`] for more information. + pub fn with_dict_is_ordered(mut self, dict_is_ordered: bool) -> Self { + if matches!(self.data_type, DataType::Dictionary(_, _)) { + self.dict_is_ordered = dict_is_ordered; + }; + self + } + /// Merge this field into self if it is compatible. /// /// Struct fields are merged recursively. 
@@ -425,6 +474,7 @@ impl Field { /// assert!(field.is_nullable()); /// ``` pub fn try_merge(&mut self, from: &Field) -> Result<(), ArrowError> { + #[allow(deprecated)] if from.dict_id != self.dict_id { return Err(ArrowError::SchemaError(format!( "Fail to merge schema field '{}' because from dict_id = {} does not match {}", @@ -567,9 +617,11 @@ impl Field { /// * self.metadata is a superset of other.metadata /// * all other fields are equal pub fn contains(&self, other: &Field) -> bool { + #[allow(deprecated)] + let matching_dict_id = self.dict_id == other.dict_id; self.name == other.name && self.data_type.contains(&other.data_type) - && self.dict_id == other.dict_id + && matching_dict_id && self.dict_is_ordered == other.dict_is_ordered // self need to be nullable or both of them are not nullable && (self.nullable || !other.nullable) @@ -618,6 +670,7 @@ mod test { fn test_new_dict_with_string() { // Fields should allow owned Strings to support reuse let s = "c1"; + #[allow(deprecated)] Field::new_dict(s, DataType::Int64, false, 4, false); } @@ -735,6 +788,7 @@ mod test { #[test] fn test_fields_with_dict_id() { + #[allow(deprecated)] let dict1 = Field::new_dict( "dict1", DataType::Dictionary(DataType::Utf8.into(), DataType::Int32.into()), @@ -742,6 +796,7 @@ mod test { 10, false, ); + #[allow(deprecated)] let dict2 = Field::new_dict( "dict2", DataType::Dictionary(DataType::Int32.into(), DataType::Int8.into()), @@ -778,9 +833,11 @@ mod test { false, ); + #[allow(deprecated)] for field in field.fields_with_dict_id(10) { assert_eq!(dict1, *field); } + #[allow(deprecated)] for field in field.fields_with_dict_id(20) { assert_eq!(dict2, *field); } @@ -795,6 +852,7 @@ mod test { #[test] fn test_field_comparison_case() { // dictionary-encoding properties not used for field comparison + #[allow(deprecated)] let dict1 = Field::new_dict( "dict1", DataType::Dictionary(DataType::Utf8.into(), DataType::Int32.into()), @@ -802,6 +860,7 @@ mod test { 10, false, ); + 
#[allow(deprecated)] let dict2 = Field::new_dict( "dict1", DataType::Dictionary(DataType::Utf8.into(), DataType::Int32.into()), @@ -813,6 +872,7 @@ mod test { assert_eq!(dict1, dict2); assert_eq!(get_field_hash(&dict1), get_field_hash(&dict2)); + #[allow(deprecated)] let dict1 = Field::new_dict( "dict0", DataType::Dictionary(DataType::Utf8.into(), DataType::Int32.into()), diff --git a/arrow-schema/src/fields.rs b/arrow-schema/src/fields.rs index 5b9ce2a6da61..904b933cd299 100644 --- a/arrow-schema/src/fields.rs +++ b/arrow-schema/src/fields.rs @@ -18,7 +18,7 @@ use std::ops::Deref; use std::sync::Arc; -use crate::{ArrowError, DataType, Field, FieldRef, SchemaBuilder}; +use crate::{ArrowError, DataType, Field, FieldRef}; /// A cheaply cloneable, owned slice of [`FieldRef`] /// @@ -256,33 +256,6 @@ impl Fields { .collect(); Ok(filtered) } - - /// Remove a field by index and return it. - /// - /// # Panic - /// - /// Panics if `index` is out of bounds. - /// - /// # Example - /// ``` - /// use arrow_schema::{DataType, Field, Fields}; - /// let mut fields = Fields::from(vec![ - /// Field::new("a", DataType::Boolean, false), - /// Field::new("b", DataType::Int8, false), - /// Field::new("c", DataType::Utf8, false), - /// ]); - /// assert_eq!(fields.len(), 3); - /// assert_eq!(fields.remove(1), Field::new("b", DataType::Int8, false).into()); - /// assert_eq!(fields.len(), 2); - /// ``` - #[deprecated(note = "Use SchemaBuilder::remove")] - #[doc(hidden)] - pub fn remove(&mut self, index: usize) -> FieldRef { - let mut builder = SchemaBuilder::from(Fields::from(&*self.0)); - let field = builder.remove(index); - *self = builder.finish().fields; - field - } } impl Default for Fields { @@ -496,7 +469,12 @@ mod tests { Field::new("floats", DataType::Struct(floats.clone()), true), true, ), - Field::new_fixed_size_list("f", Field::new("item", DataType::Int32, false), 3, false), + Field::new_fixed_size_list( + "f", + Field::new_list_field(DataType::Int32, false), + 3, + false, + 
), Field::new_map( "g", "entries", diff --git a/arrow-schema/src/schema.rs b/arrow-schema/src/schema.rs index cc3a8a308a83..6c79da53f981 100644 --- a/arrow-schema/src/schema.rs +++ b/arrow-schema/src/schema.rs @@ -389,7 +389,12 @@ impl Schema { /// Returns a vector of immutable references to all [`Field`] instances selected by /// the dictionary ID they use. + #[deprecated( + since = "54.0.0", + note = "The ability to preserve dictionary IDs will be removed. With it, all functions related to it." + )] pub fn fields_with_dict_id(&self, dict_id: i64) -> Vec<&Field> { + #[allow(deprecated)] self.fields .iter() .flat_map(|f| f.fields_with_dict_id(dict_id)) @@ -434,33 +439,6 @@ impl Schema { .iter() .all(|(k, v1)| self.metadata.get(k).map(|v2| v1 == v2).unwrap_or_default()) } - - /// Remove field by index and return it. Recommend to use [`SchemaBuilder`] - /// if you are looking to remove multiple columns, as this will save allocations. - /// - /// # Panic - /// - /// Panics if `index` is out of bounds. 
- /// - /// # Example - /// - /// ``` - /// use arrow_schema::{DataType, Field, Schema}; - /// let mut schema = Schema::new(vec![ - /// Field::new("a", DataType::Boolean, false), - /// Field::new("b", DataType::Int8, false), - /// Field::new("c", DataType::Utf8, false), - /// ]); - /// assert_eq!(schema.fields.len(), 3); - /// assert_eq!(schema.remove(1), Field::new("b", DataType::Int8, false).into()); - /// assert_eq!(schema.fields.len(), 2); - /// ``` - #[deprecated(note = "Use SchemaBuilder::remove")] - #[doc(hidden)] - #[allow(deprecated)] - pub fn remove(&mut self, index: usize) -> FieldRef { - self.fields.remove(index) - } } impl fmt::Display for Schema { @@ -665,7 +643,9 @@ mod tests { assert_eq!(first_name.name(), "first_name"); assert_eq!(first_name.data_type(), &DataType::Utf8); assert!(!first_name.is_nullable()); - assert_eq!(first_name.dict_id(), None); + #[allow(deprecated)] + let dict_id = first_name.dict_id(); + assert_eq!(dict_id, None); assert_eq!(first_name.dict_is_ordered(), None); let metadata = first_name.metadata(); @@ -682,7 +662,9 @@ mod tests { interests.data_type(), &DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) ); - assert_eq!(interests.dict_id(), Some(123)); + #[allow(deprecated)] + let dict_id = interests.dict_id(); + assert_eq!(dict_id, Some(123)); assert_eq!(interests.dict_is_ordered(), Some(true)); } @@ -718,6 +700,7 @@ mod tests { fn schema_field_with_dict_id() { let schema = person_schema(); + #[allow(deprecated)] let fields_dict_123: Vec<_> = schema .fields_with_dict_id(123) .iter() @@ -725,7 +708,9 @@ mod tests { .collect(); assert_eq!(fields_dict_123, vec!["interests"]); - assert!(schema.fields_with_dict_id(456).is_empty()); + #[allow(deprecated)] + let is_empty = schema.fields_with_dict_id(456).is_empty(); + assert!(is_empty); } fn person_schema() -> Schema { @@ -745,6 +730,7 @@ mod tests { ])), false, ), + #[allow(deprecated)] Field::new_dict( "interests", 
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), diff --git a/arrow-select/LICENSE.txt b/arrow-select/LICENSE.txt new file mode 120000 index 000000000000..4ab43736a839 --- /dev/null +++ b/arrow-select/LICENSE.txt @@ -0,0 +1 @@ +../LICENSE.txt \ No newline at end of file diff --git a/arrow-select/NOTICE.txt b/arrow-select/NOTICE.txt new file mode 120000 index 000000000000..eb9f24e040b5 --- /dev/null +++ b/arrow-select/NOTICE.txt @@ -0,0 +1 @@ +../NOTICE.txt \ No newline at end of file diff --git a/arrow-select/src/concat.rs b/arrow-select/src/concat.rs index 129b90ee0470..4855e0087cc6 100644 --- a/arrow-select/src/concat.rs +++ b/arrow-select/src/concat.rs @@ -34,9 +34,9 @@ use crate::dictionary::{merge_dictionary_values, should_merge_dictionary_values} use arrow_array::cast::AsArray; use arrow_array::types::*; use arrow_array::*; -use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer}; +use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer, OffsetBuffer}; use arrow_data::transform::{Capacities, MutableArrayData}; -use arrow_schema::{ArrowError, DataType, SchemaRef}; +use arrow_schema::{ArrowError, DataType, FieldRef, SchemaRef}; use std::sync::Arc; fn binary_capacity(arrays: &[&dyn Array]) -> Capacities { @@ -129,6 +129,54 @@ fn concat_dictionaries( Ok(Arc::new(array)) } +fn concat_lists( + arrays: &[&dyn Array], + field: &FieldRef, +) -> Result { + let mut output_len = 0; + let mut list_has_nulls = false; + + let lists = arrays + .iter() + .map(|x| x.as_list::()) + .inspect(|l| { + output_len += l.len(); + list_has_nulls |= l.null_count() != 0; + }) + .collect::>(); + + let lists_nulls = list_has_nulls.then(|| { + let mut nulls = BooleanBufferBuilder::new(output_len); + for l in &lists { + match l.nulls() { + Some(n) => nulls.append_buffer(n.inner()), + None => nulls.append_n(l.len(), true), + } + } + NullBuffer::new(nulls.finish()) + }); + + let values: Vec<&dyn Array> = lists + .iter() + .map(|x| 
x.values().as_ref()) + .collect::>(); + + let concatenated_values = concat(values.as_slice())?; + + // Merge value offsets from the lists + let value_offset_buffer = + OffsetBuffer::::from_lengths(lists.iter().flat_map(|x| x.offsets().lengths())); + + let array = GenericListArray::::try_new( + Arc::clone(field), + value_offset_buffer, + concatenated_values, + lists_nulls, + )?; + + Ok(Arc::new(array)) +} + macro_rules! dict_helper { ($t:ty, $arrays:expr) => { return Ok(Arc::new(concat_dictionaries::<$t>($arrays)?) as _) @@ -163,14 +211,20 @@ pub fn concat(arrays: &[&dyn Array]) -> Result { "It is not possible to concatenate arrays of different data types.".to_string(), )); } - if let DataType::Dictionary(k, _) = d { - downcast_integer! { - k.as_ref() => (dict_helper, arrays), - _ => unreachable!("illegal dictionary key type {k}") - }; - } else { - let capacity = get_capacity(arrays, d); - concat_fallback(arrays, capacity) + + match d { + DataType::Dictionary(k, _) => { + downcast_integer! 
{ + k.as_ref() => (dict_helper, arrays), + _ => unreachable!("illegal dictionary key type {k}") + } + } + DataType::List(field) => concat_lists::(arrays, field), + DataType::LargeList(field) => concat_lists::(arrays, field), + _ => { + let capacity = get_capacity(arrays, d); + concat_fallback(arrays, capacity) + } } } @@ -228,8 +282,9 @@ pub fn concat_batches<'a>( #[cfg(test)] mod tests { use super::*; - use arrow_array::builder::StringDictionaryBuilder; + use arrow_array::builder::{GenericListBuilder, StringDictionaryBuilder}; use arrow_schema::{Field, Schema}; + use std::fmt::Debug; #[test] fn test_concat_empty_vec() { @@ -851,4 +906,118 @@ mod tests { assert_eq!(array.null_count(), 10); assert_eq!(array.logical_null_count(), 10); } + + #[test] + fn concat_dictionary_list_array_simple() { + let scalars = vec![ + create_single_row_list_of_dict(vec![Some("a")]), + create_single_row_list_of_dict(vec![Some("a")]), + create_single_row_list_of_dict(vec![Some("b")]), + ]; + + let arrays = scalars + .iter() + .map(|a| a as &(dyn Array)) + .collect::>(); + let concat_res = concat(arrays.as_slice()).unwrap(); + + let expected_list = create_list_of_dict(vec![ + // Row 1 + Some(vec![Some("a")]), + Some(vec![Some("a")]), + Some(vec![Some("b")]), + ]); + + let list = concat_res.as_list::(); + + // Assert that the list is equal to the expected list + list.iter().zip(expected_list.iter()).for_each(|(a, b)| { + assert_eq!(a, b); + }); + + assert_dictionary_has_unique_values::<_, StringArray>( + list.values().as_dictionary::(), + ); + } + + #[test] + fn concat_many_dictionary_list_arrays() { + let number_of_unique_values = 8; + let scalars = (0..80000) + .map(|i| { + create_single_row_list_of_dict(vec![Some( + (i % number_of_unique_values).to_string(), + )]) + }) + .collect::>(); + + let arrays = scalars + .iter() + .map(|a| a as &(dyn Array)) + .collect::>(); + let concat_res = concat(arrays.as_slice()).unwrap(); + + let expected_list = create_list_of_dict( + (0..80000) + 
.map(|i| Some(vec![Some((i % number_of_unique_values).to_string())])) + .collect::>(), + ); + + let list = concat_res.as_list::(); + + // Assert that the list is equal to the expected list + list.iter().zip(expected_list.iter()).for_each(|(a, b)| { + assert_eq!(a, b); + }); + + assert_dictionary_has_unique_values::<_, StringArray>( + list.values().as_dictionary::(), + ); + } + + fn create_single_row_list_of_dict( + list_items: Vec>>, + ) -> GenericListArray { + let rows = list_items.into_iter().map(Some).collect(); + + create_list_of_dict(vec![rows]) + } + + fn create_list_of_dict( + rows: Vec>>>>, + ) -> GenericListArray { + let mut builder = + GenericListBuilder::::new(StringDictionaryBuilder::::new()); + + for row in rows { + builder.append_option(row); + } + + builder.finish() + } + + fn assert_dictionary_has_unique_values<'a, K, V>(array: &'a DictionaryArray) + where + K: ArrowDictionaryKeyType, + V: Sync + Send + 'static, + &'a V: ArrayAccessor + IntoIterator, + + <&'a V as ArrayAccessor>::Item: Default + Clone + PartialEq + Debug + Ord, + <&'a V as IntoIterator>::Item: Clone + PartialEq + Debug + Ord, + { + let dict = array.downcast_dict::().unwrap(); + let mut values = dict.values().into_iter().collect::>(); + + // remove duplicates must be sorted first so we can compare + values.sort(); + + let mut unique_values = values.clone(); + + unique_values.dedup(); + + assert_eq!( + values, unique_values, + "There are duplicates in the value list (the value list here is sorted which is only for the assertion)" + ); + } } diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs index 4c6a5c0668f1..c91732848653 100644 --- a/arrow-select/src/filter.rs +++ b/arrow-select/src/filter.rs @@ -431,17 +431,17 @@ where R::Native: AddAssign, { let run_ends: &RunEndBuffer = array.run_ends(); - let mut values_filter = BooleanBufferBuilder::new(run_ends.len()); let mut new_run_ends = vec![R::default_value(); run_ends.len()]; let mut start = 0u64; - let mut i = 0; + 
let mut j = 0; let mut count = R::default_value(); let filter_values = predicate.filter.values(); + let run_ends = run_ends.inner(); - for mut end in run_ends.inner().into_iter().map(|i| (*i).into() as u64) { + let pred: BooleanArray = BooleanBuffer::collect_bool(run_ends.len(), |i| { let mut keep = false; - + let mut end = run_ends[i].into() as u64; let difference = end.saturating_sub(filter_values.len() as u64); end -= difference; @@ -450,23 +450,18 @@ where count += R::Native::from(pred); keep |= pred } - // this is to avoid branching - new_run_ends[i] = count; - i += keep as usize; + new_run_ends[j] = count; + j += keep as usize; - values_filter.append(keep); start = end; - } - - new_run_ends.truncate(i); + keep + }) + .into(); - if values_filter.is_empty() { - new_run_ends.clear(); - } + new_run_ends.truncate(j); let values = array.values(); - let pred = BooleanArray::new(values_filter.finish(), None); let values = filter(&values, &pred)?; let run_ends = PrimitiveArray::::new(new_run_ends.into(), None); @@ -522,14 +517,14 @@ fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer { unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() } } IterationStrategy::SlicesIterator => { - let mut builder = BooleanBufferBuilder::new(bit_util::ceil(predicate.count, 8)); + let mut builder = BooleanBufferBuilder::new(predicate.count); for (start, end) in SlicesIterator::new(&predicate.filter) { builder.append_packed_range(start + offset..end + offset, src) } builder.into() } IterationStrategy::Slices(slices) => { - let mut builder = BooleanBufferBuilder::new(bit_util::ceil(predicate.count, 8)); + let mut builder = BooleanBufferBuilder::new(predicate.count); for (start, end) in slices { builder.append_packed_range(*start + offset..*end + offset, src) } @@ -1336,7 +1331,7 @@ mod tests { let value_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8, 8]); let list_data_type = - DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, 
false))); + DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false))); let list_data = ArrayData::builder(list_data_type) .len(4) .add_buffer(value_offsets) @@ -1360,7 +1355,7 @@ mod tests { let value_offsets = Buffer::from_slice_ref([0i64, 3, 3]); let list_data_type = - DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, false))); + DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false))); let expected = ArrayData::builder(list_data_type) .len(2) .add_buffer(value_offsets) diff --git a/arrow-select/src/interleave.rs b/arrow-select/src/interleave.rs index a0520e969a6b..4a47017b79ab 100644 --- a/arrow-select/src/interleave.rs +++ b/arrow-select/src/interleave.rs @@ -265,6 +265,67 @@ fn interleave_fallback( Ok(make_array(array_data.freeze())) } +/// Interleave rows by index from multiple [`RecordBatch`] instances and return a new [`RecordBatch`]. +/// +/// This function will call [`interleave`] on each array of the [`RecordBatch`] instances and assemble a new [`RecordBatch`]. 
+/// +/// # Example +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{StringArray, Int32Array, RecordBatch, UInt32Array}; +/// # use arrow_schema::{DataType, Field, Schema}; +/// # use arrow_select::interleave::interleave_record_batch; +/// +/// let schema = Arc::new(Schema::new(vec![ +/// Field::new("a", DataType::Int32, true), +/// Field::new("b", DataType::Utf8, true), +/// ])); +/// +/// let batch1 = RecordBatch::try_new( +/// schema.clone(), +/// vec![ +/// Arc::new(Int32Array::from(vec![0, 1, 2])), +/// Arc::new(StringArray::from(vec!["a", "b", "c"])), +/// ], +/// ).unwrap(); +/// +/// let batch2 = RecordBatch::try_new( +/// schema.clone(), +/// vec![ +/// Arc::new(Int32Array::from(vec![3, 4, 5])), +/// Arc::new(StringArray::from(vec!["d", "e", "f"])), +/// ], +/// ).unwrap(); +/// +/// let indices = vec![(0, 1), (1, 2), (0, 0), (1, 1)]; +/// let interleaved = interleave_record_batch(&[&batch1, &batch2], &indices).unwrap(); +/// +/// let expected = RecordBatch::try_new( +/// schema, +/// vec![ +/// Arc::new(Int32Array::from(vec![1, 5, 0, 4])), +/// Arc::new(StringArray::from(vec!["b", "f", "a", "e"])), +/// ], +/// ).unwrap(); +/// assert_eq!(interleaved, expected); +/// ``` +pub fn interleave_record_batch( + record_batches: &[&RecordBatch], + indices: &[(usize, usize)], +) -> Result { + let schema = record_batches[0].schema(); + let columns = (0..schema.fields().len()) + .map(|i| { + let column_values: Vec<&dyn Array> = record_batches + .iter() + .map(|batch| batch.column(i).as_ref()) + .collect(); + interleave(&column_values, indices) + }) + .collect::, _>>()?; + RecordBatch::try_new(schema, columns) +} + #[cfg(test)] mod tests { use super::*; diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs index 07630a49fa11..71a7c77a8f92 100644 --- a/arrow-select/src/take.rs +++ b/arrow-select/src/take.rs @@ -1606,7 +1606,7 @@ mod tests { let value_offsets = Buffer::from_slice_ref(&value_offsets); // Construct a list array from the above two 
let list_data_type = - DataType::$list_data_type(Arc::new(Field::new("item", DataType::Int32, false))); + DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, false))); let list_data = ArrayData::builder(list_data_type.clone()) .len(4) .add_buffer(value_offsets) @@ -1672,7 +1672,7 @@ mod tests { let value_offsets = Buffer::from_slice_ref(&value_offsets); // Construct a list array from the above two let list_data_type = - DataType::$list_data_type(Arc::new(Field::new("item", DataType::Int32, true))); + DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true))); let list_data = ArrayData::builder(list_data_type.clone()) .len(4) .add_buffer(value_offsets) @@ -1739,7 +1739,7 @@ mod tests { let value_offsets = Buffer::from_slice_ref(&value_offsets); // Construct a list array from the above two let list_data_type = - DataType::$list_data_type(Arc::new(Field::new("item", DataType::Int32, true))); + DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true))); let list_data = ArrayData::builder(list_data_type.clone()) .len(4) .add_buffer(value_offsets) @@ -1904,7 +1904,8 @@ mod tests { // Construct offsets let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]); // Construct a list array from the above two - let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let list_data_type = + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))); let list_data = ArrayData::builder(list_data_type) .len(3) .add_buffer(value_offsets) @@ -2222,7 +2223,7 @@ mod tests { fn test_take_fixed_size_list_null_indices() { let indices = Int32Array::from_iter([Some(0), None]); let values = Arc::new(Int32Array::from(vec![0, 1, 2, 3])); - let arr_field = Arc::new(Field::new("item", values.data_type().clone(), true)); + let arr_field = Arc::new(Field::new_list_field(values.data_type().clone(), true)); let values = FixedSizeListArray::try_new(arr_field, 2, values, None).unwrap(); 
let r = take(&values, &indices, None).unwrap(); diff --git a/arrow-select/src/zip.rs b/arrow-select/src/zip.rs index acb31dfa3bc2..2efd2e749921 100644 --- a/arrow-select/src/zip.rs +++ b/arrow-select/src/zip.rs @@ -15,20 +15,72 @@ // specific language governing permissions and limitations // under the License. -//! Zip two arrays by some boolean mask. Where the mask evaluates `true` values of `truthy` +//! [`zip`]: Combine values from two arrays based on boolean mask use crate::filter::SlicesIterator; use arrow_array::*; use arrow_data::transform::MutableArrayData; use arrow_schema::ArrowError; -/// Zip two arrays by some boolean mask. Where the mask evaluates `true` values of `truthy` -/// are taken, where the mask evaluates `false` values of `falsy` are taken. +/// Zip two arrays by some boolean mask. /// -/// # Arguments -/// * `mask` - Boolean values used to determine from which array to take the values. -/// * `truthy` - Values of this array are taken if mask evaluates `true` -/// * `falsy` - Values of this array are taken if mask evaluates `false` +/// - Where `mask` is `true`, values of `truthy` are taken +/// - Where `mask` is `false` or `NULL`, values of `falsy` are taken +/// +/// # Example: `zip` two arrays +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{ArrayRef, BooleanArray, Int32Array}; +/// # use arrow_select::zip::zip; +/// // mask: [true, true, false, NULL, true] +/// let mask = BooleanArray::from(vec![ +/// Some(true), Some(true), Some(false), None, Some(true) +/// ]); +/// // truthy array: [1, NULL, 3, 4, 5] +/// let truthy = Int32Array::from(vec![ +/// Some(1), None, Some(3), Some(4), Some(5) +/// ]); +/// // falsy array: [10, 20, 30, 40, 50] +/// let falsy = Int32Array::from(vec![ +/// Some(10), Some(20), Some(30), Some(40), Some(50) +/// ]); +/// // zip with this mask select the first, second and last value from `truthy` +/// // and the third and fourth value from `falsy` +/// let result = zip(&mask, &truthy, &falsy).unwrap(); 
+/// // Expected: [1, NULL, 30, 40, 5] +/// let expected: ArrayRef = Arc::new(Int32Array::from(vec![ +/// Some(1), None, Some(30), Some(40), Some(5) +/// ])); +/// assert_eq!(&result, &expected); +/// ``` +/// +/// # Example: `zip` and array with a scalar +/// +/// Use `zip` to replace certain values in an array with a scalar +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{ArrayRef, BooleanArray, Int32Array}; +/// # use arrow_select::zip::zip; +/// // mask: [true, true, false, NULL, true] +/// let mask = BooleanArray::from(vec![ +/// Some(true), Some(true), Some(false), None, Some(true) +/// ]); +/// // array: [1, NULL, 3, 4, 5] +/// let arr = Int32Array::from(vec![ +/// Some(1), None, Some(3), Some(4), Some(5) +/// ]); +/// // scalar: 42 +/// let scalar = Int32Array::new_scalar(42); +/// // zip the array with the mask select the first, second and last value from `arr` +/// // and fill the third and fourth value with the scalar 42 +/// let result = zip(&mask, &arr, &scalar).unwrap(); +/// // Expected: [1, NULL, 42, 42, 5] +/// let expected: ArrayRef = Arc::new(Int32Array::from(vec![ +/// Some(1), None, Some(42), Some(42), Some(5) +/// ])); +/// assert_eq!(&result, &expected); +/// ``` pub fn zip( mask: &BooleanArray, truthy: &dyn Datum, diff --git a/arrow-string/LICENSE.txt b/arrow-string/LICENSE.txt new file mode 120000 index 000000000000..4ab43736a839 --- /dev/null +++ b/arrow-string/LICENSE.txt @@ -0,0 +1 @@ +../LICENSE.txt \ No newline at end of file diff --git a/arrow-string/NOTICE.txt b/arrow-string/NOTICE.txt new file mode 120000 index 000000000000..eb9f24e040b5 --- /dev/null +++ b/arrow-string/NOTICE.txt @@ -0,0 +1 @@ +../NOTICE.txt \ No newline at end of file diff --git a/arrow-string/src/length.rs b/arrow-string/src/length.rs index 6a28d44ea7aa..49fc244e72cc 100644 --- a/arrow-string/src/length.rs +++ b/arrow-string/src/length.rs @@ -710,7 +710,7 @@ mod tests { .build() .unwrap(); let list_data_type = - 
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 3); + DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int32, false)), 3); let nulls = NullBuffer::from(vec![true, false, true]); let list_data = ArrayData::builder(list_data_type) .len(3) diff --git a/arrow-string/src/like.rs b/arrow-string/src/like.rs index 0a5aa77dbb95..e30e09146c6d 100644 --- a/arrow-string/src/like.rs +++ b/arrow-string/src/like.rs @@ -18,13 +18,16 @@ //! Provide SQL's LIKE operators for Arrow's string arrays use crate::predicate::Predicate; + use arrow_array::cast::AsArray; use arrow_array::*; use arrow_schema::*; use arrow_select::take::take; -use iterator::ArrayIter; + use std::sync::Arc; +pub use arrow_array::StringArrayType; + #[derive(Debug)] enum Op { Like(bool), @@ -150,39 +153,6 @@ fn like_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result: ArrayAccessor + Sized { - /// Returns true if all data within this string array is ASCII - fn is_ascii(&self) -> bool; - /// Constructs a new iterator - fn iter(&self) -> ArrayIter; -} - -impl<'a, O: OffsetSizeTrait> StringArrayType<'a> for &'a GenericStringArray { - fn is_ascii(&self) -> bool { - GenericStringArray::::is_ascii(self) - } - - fn iter(&self) -> ArrayIter { - GenericStringArray::::iter(self) - } -} -impl<'a> StringArrayType<'a> for &'a StringViewArray { - fn is_ascii(&self) -> bool { - StringViewArray::is_ascii(self) - } - - fn iter(&self) -> ArrayIter { - StringViewArray::iter(self) - } -} - fn apply<'a, T: StringArrayType<'a> + 'a>( op: Op, l: T, diff --git a/arrow-string/src/regexp.rs b/arrow-string/src/regexp.rs index 5ad452a17b12..f3893cd5bd13 100644 --- a/arrow-string/src/regexp.rs +++ b/arrow-string/src/regexp.rs @@ -20,7 +20,9 @@ use crate::like::StringArrayType; -use arrow_array::builder::{BooleanBufferBuilder, GenericStringBuilder, ListBuilder}; +use arrow_array::builder::{ + BooleanBufferBuilder, GenericStringBuilder, ListBuilder, StringViewBuilder, +}; use 
arrow_array::cast::AsArray; use arrow_array::*; use arrow_buffer::NullBuffer; @@ -243,78 +245,96 @@ where Ok(BooleanArray::from(data)) } -fn regexp_array_match( - array: &GenericStringArray, - regex_array: &GenericStringArray, - flags_array: Option<&GenericStringArray>, -) -> Result { - let mut patterns: HashMap = HashMap::new(); - let builder: GenericStringBuilder = GenericStringBuilder::with_capacity(0, 0); - let mut list_builder = ListBuilder::new(builder); +macro_rules! process_regexp_array_match { + ($array:expr, $regex_array:expr, $flags_array:expr, $list_builder:expr) => { + let mut patterns: HashMap = HashMap::new(); - let complete_pattern = match flags_array { - Some(flags) => Box::new( - regex_array - .iter() - .zip(flags.iter()) - .map(|(pattern, flags)| { + let complete_pattern = match $flags_array { + Some(flags) => Box::new($regex_array.iter().zip(flags.iter()).map( + |(pattern, flags)| { pattern.map(|pattern| match flags { Some(value) => format!("(?{value}){pattern}"), None => pattern.to_string(), }) - }), - ) as Box>>, - None => Box::new( - regex_array - .iter() - .map(|pattern| pattern.map(|pattern| pattern.to_string())), - ), - }; + }, + )) as Box>>, + None => Box::new( + $regex_array + .iter() + .map(|pattern| pattern.map(|pattern| pattern.to_string())), + ), + }; - array - .iter() - .zip(complete_pattern) - .map(|(value, pattern)| { - match (value, pattern) { - // Required for Postgres compatibility: - // SELECT regexp_match('foobarbequebaz', ''); = {""} - (Some(_), Some(pattern)) if pattern == *"" => { - list_builder.values().append_value(""); - list_builder.append(true); - } - (Some(value), Some(pattern)) => { - let existing_pattern = patterns.get(&pattern); - let re = match existing_pattern { - Some(re) => re, - None => { - let re = Regex::new(pattern.as_str()).map_err(|e| { - ArrowError::ComputeError(format!( - "Regular expression did not compile: {e:?}" - )) - })?; - patterns.entry(pattern).or_insert(re) - } - }; - match re.captures(value) 
{ - Some(caps) => { - let mut iter = caps.iter(); - if caps.len() > 1 { - iter.next(); - } - for m in iter.flatten() { - list_builder.values().append_value(m.as_str()); + $array + .iter() + .zip(complete_pattern) + .map(|(value, pattern)| { + match (value, pattern) { + // Required for Postgres compatibility: + // SELECT regexp_match('foobarbequebaz', ''); = {""} + (Some(_), Some(pattern)) if pattern == *"" => { + $list_builder.values().append_value(""); + $list_builder.append(true); + } + (Some(value), Some(pattern)) => { + let existing_pattern = patterns.get(&pattern); + let re = match existing_pattern { + Some(re) => re, + None => { + let re = Regex::new(pattern.as_str()).map_err(|e| { + ArrowError::ComputeError(format!( + "Regular expression did not compile: {e:?}" + )) + })?; + patterns.entry(pattern).or_insert(re) } + }; + match re.captures(value) { + Some(caps) => { + let mut iter = caps.iter(); + if caps.len() > 1 { + iter.next(); + } + for m in iter.flatten() { + $list_builder.values().append_value(m.as_str()); + } - list_builder.append(true); + $list_builder.append(true); + } + None => $list_builder.append(false), } - None => list_builder.append(false), } + _ => $list_builder.append(false), } - _ => list_builder.append(false), - } - Ok(()) - }) - .collect::, ArrowError>>()?; + Ok(()) + }) + .collect::, ArrowError>>()?; + }; +} + +fn regexp_array_match( + array: &GenericStringArray, + regex_array: &GenericStringArray, + flags_array: Option<&GenericStringArray>, +) -> Result { + let builder: GenericStringBuilder = GenericStringBuilder::with_capacity(0, 0); + let mut list_builder = ListBuilder::new(builder); + + process_regexp_array_match!(array, regex_array, flags_array, list_builder); + + Ok(Arc::new(list_builder.finish())) +} + +fn regexp_array_match_utf8view( + array: &StringViewArray, + regex_array: &StringViewArray, + flags_array: Option<&StringViewArray>, +) -> Result { + let builder = StringViewBuilder::with_capacity(0); + let mut list_builder = 
ListBuilder::new(builder); + + process_regexp_array_match!(array, regex_array, flags_array, list_builder); + Ok(Arc::new(list_builder.finish())) } @@ -333,6 +353,54 @@ fn get_scalar_pattern_flag<'a, OffsetSize: OffsetSizeTrait>( } } +fn get_scalar_pattern_flag_utf8view<'a>( + regex_array: &'a dyn Array, + flag_array: Option<&'a dyn Array>, +) -> (Option<&'a str>, Option<&'a str>) { + let regex = regex_array.as_string_view(); + let regex = regex.is_valid(0).then(|| regex.value(0)); + + if let Some(flag_array) = flag_array { + let flag = flag_array.as_string_view(); + (regex, flag.is_valid(0).then(|| flag.value(0))) + } else { + (regex, None) + } +} + +macro_rules! process_regexp_match { + ($array:expr, $regex:expr, $list_builder:expr) => { + $array + .iter() + .map(|value| { + match value { + // Required for Postgres compatibility: + // SELECT regexp_match('foobarbequebaz', ''); = {""} + Some(_) if $regex.as_str().is_empty() => { + $list_builder.values().append_value(""); + $list_builder.append(true); + } + Some(value) => match $regex.captures(value) { + Some(caps) => { + let mut iter = caps.iter(); + if caps.len() > 1 { + iter.next(); + } + for m in iter.flatten() { + $list_builder.values().append_value(m.as_str()); + } + $list_builder.append(true); + } + None => $list_builder.append(false), + }, + None => $list_builder.append(false), + } + Ok(()) + }) + .collect::, ArrowError>>()? 
+ }; +} + fn regexp_scalar_match( array: &GenericStringArray, regex: &Regex, @@ -340,35 +408,19 @@ fn regexp_scalar_match( let builder: GenericStringBuilder = GenericStringBuilder::with_capacity(0, 0); let mut list_builder = ListBuilder::new(builder); - array - .iter() - .map(|value| { - match value { - // Required for Postgres compatibility: - // SELECT regexp_match('foobarbequebaz', ''); = {""} - Some(_) if regex.as_str() == "" => { - list_builder.values().append_value(""); - list_builder.append(true); - } - Some(value) => match regex.captures(value) { - Some(caps) => { - let mut iter = caps.iter(); - if caps.len() > 1 { - iter.next(); - } - for m in iter.flatten() { - list_builder.values().append_value(m.as_str()); - } + process_regexp_match!(array, regex, list_builder); - list_builder.append(true); - } - None => list_builder.append(false), - }, - _ => list_builder.append(false), - } - Ok(()) - }) - .collect::, ArrowError>>()?; + Ok(Arc::new(list_builder.finish())) +} + +fn regexp_scalar_match_utf8view( + array: &StringViewArray, + regex: &Regex, +) -> Result { + let builder = StringViewBuilder::with_capacity(0); + let mut list_builder = ListBuilder::new(builder); + + process_regexp_match!(array, regex, list_builder); Ok(Arc::new(list_builder.finish())) } @@ -406,7 +458,7 @@ pub fn regexp_match( if array.data_type() != rhs.data_type() { return Err(ArrowError::ComputeError( - "regexp_match() requires both array and pattern to be either Utf8 or LargeUtf8" + "regexp_match() requires both array and pattern to be either Utf8, Utf8View or LargeUtf8" .to_string(), )); } @@ -428,7 +480,7 @@ pub fn regexp_match( if flags_array.is_some() && rhs.data_type() != flags.unwrap().data_type() { return Err(ArrowError::ComputeError( - "regexp_match() requires both pattern and flags to be either string or largestring" + "regexp_match() requires both pattern and flags to be either Utf8, Utf8View or LargeUtf8" .to_string(), )); } @@ -436,19 +488,20 @@ pub fn regexp_match( if 
is_rhs_scalar { // Regex and flag is scalars let (regex, flag) = match rhs.data_type() { + DataType::Utf8View => get_scalar_pattern_flag_utf8view(rhs, flags), DataType::Utf8 => get_scalar_pattern_flag::(rhs, flags), DataType::LargeUtf8 => get_scalar_pattern_flag::(rhs, flags), _ => { return Err(ArrowError::ComputeError( - "regexp_match() requires pattern to be either Utf8 or LargeUtf8".to_string(), + "regexp_match() requires pattern to be either Utf8, Utf8View or LargeUtf8" + .to_string(), )); } }; if regex.is_none() { return Ok(new_null_array( - &DataType::List(Arc::new(Field::new( - "item", + &DataType::List(Arc::new(Field::new_list_field( array.data_type().clone(), true, ))), @@ -469,14 +522,21 @@ pub fn regexp_match( })?; match array.data_type() { + DataType::Utf8View => regexp_scalar_match_utf8view(array.as_string_view(), &re), DataType::Utf8 => regexp_scalar_match(array.as_string::(), &re), DataType::LargeUtf8 => regexp_scalar_match(array.as_string::(), &re), _ => Err(ArrowError::ComputeError( - "regexp_match() requires array to be either Utf8 or LargeUtf8".to_string(), + "regexp_match() requires array to be either Utf8, Utf8View or LargeUtf8" + .to_string(), )), } } else { match array.data_type() { + DataType::Utf8View => { + let regex_array = rhs.as_string_view(); + let flags_array = flags.map(|flags| flags.as_string_view()); + regexp_array_match_utf8view(array.as_string_view(), regex_array, flags_array) + } DataType::Utf8 => { let regex_array = rhs.as_string(); let flags_array = flags.map(|flags| flags.as_string()); @@ -488,7 +548,8 @@ pub fn regexp_match( regexp_array_match(array.as_string::(), regex_array, flags_array) } _ => Err(ArrowError::ComputeError( - "regexp_match() requires array to be either Utf8 or LargeUtf8".to_string(), + "regexp_match() requires array to be either Utf8, Utf8View or LargeUtf8" + .to_string(), )), } } @@ -498,114 +559,316 @@ pub fn regexp_match( mod tests { use super::*; - #[test] - fn match_single_group() { - let values = 
vec![ + macro_rules! test_match_single_group { + ($test_name:ident, $values:expr, $patterns:expr, $arr_type:ty, $builder_type:ty, $expected:expr) => { + #[test] + fn $test_name() { + let array: $arr_type = <$arr_type>::from($values); + let pattern: $arr_type = <$arr_type>::from($patterns); + + let actual = regexp_match(&array, &pattern, None).unwrap(); + + let elem_builder: $builder_type = <$builder_type>::new(); + let mut expected_builder = ListBuilder::new(elem_builder); + + for val in $expected { + match val { + Some(v) => { + expected_builder.values().append_value(v); + expected_builder.append(true); + } + None => expected_builder.append(false), + } + } + + let expected = expected_builder.finish(); + let result = actual.as_any().downcast_ref::().unwrap(); + assert_eq!(&expected, result); + } + }; + } + + test_match_single_group!( + match_single_group_string, + vec![ Some("abc-005-def"), Some("X-7-5"), Some("X545"), None, Some("foobarbequebaz"), Some("foobarbequebaz"), - ]; - let array = StringArray::from(values); - let mut pattern_values = vec![r".*-(\d*)-.*"; 4]; - pattern_values.push(r"(bar)(bequ1e)"); - pattern_values.push(""); - let pattern = GenericStringArray::::from(pattern_values); - let actual = regexp_match(&array, &pattern, None).unwrap(); - let elem_builder: GenericStringBuilder = GenericStringBuilder::new(); - let mut expected_builder = ListBuilder::new(elem_builder); - expected_builder.values().append_value("005"); - expected_builder.append(true); - expected_builder.values().append_value("7"); - expected_builder.append(true); - expected_builder.append(false); - expected_builder.append(false); - expected_builder.append(false); - expected_builder.values().append_value(""); - expected_builder.append(true); - let expected = expected_builder.finish(); - let result = actual.as_any().downcast_ref::().unwrap(); - assert_eq!(&expected, result); - } + ], + vec![ + r".*-(\d*)-.*", + r".*-(\d*)-.*", + r".*-(\d*)-.*", + r".*-(\d*)-.*", + r"(bar)(bequ1e)", + "" 
+ ], + StringArray, + GenericStringBuilder, + [Some("005"), Some("7"), None, None, None, Some("")] + ); + test_match_single_group!( + match_single_group_string_view, + vec![ + Some("abc-005-def"), + Some("X-7-5"), + Some("X545"), + None, + Some("foobarbequebaz"), + Some("foobarbequebaz"), + ], + vec![ + r".*-(\d*)-.*", + r".*-(\d*)-.*", + r".*-(\d*)-.*", + r".*-(\d*)-.*", + r"(bar)(bequ1e)", + "" + ], + StringViewArray, + StringViewBuilder, + [Some("005"), Some("7"), None, None, None, Some("")] + ); + + macro_rules! test_match_single_group_with_flags { + ($test_name:ident, $values:expr, $patterns:expr, $flags:expr, $array_type:ty, $builder_type:ty, $expected:expr) => { + #[test] + fn $test_name() { + let array: $array_type = <$array_type>::from($values); + let pattern: $array_type = <$array_type>::from($patterns); + let flags: $array_type = <$array_type>::from($flags); + + let actual = regexp_match(&array, &pattern, Some(&flags)).unwrap(); - #[test] - fn match_single_group_with_flags() { - let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None]; - let array = StringArray::from(values); - let pattern = StringArray::from(vec![r"x.*-(\d*)-.*"; 4]); - let flags = StringArray::from(vec!["i"; 4]); - let actual = regexp_match(&array, &pattern, Some(&flags)).unwrap(); - let elem_builder: GenericStringBuilder = GenericStringBuilder::with_capacity(0, 0); - let mut expected_builder = ListBuilder::new(elem_builder); - expected_builder.append(false); - expected_builder.values().append_value("7"); - expected_builder.append(true); - expected_builder.append(false); - expected_builder.append(false); - let expected = expected_builder.finish(); - let result = actual.as_any().downcast_ref::().unwrap(); - assert_eq!(&expected, result); + let elem_builder: $builder_type = <$builder_type>::new(); + let mut expected_builder = ListBuilder::new(elem_builder); + + for val in $expected { + match val { + Some(v) => { + expected_builder.values().append_value(v); + 
expected_builder.append(true); + } + None => { + expected_builder.append(false); + } + } + } + + let expected = expected_builder.finish(); + let result = actual.as_any().downcast_ref::().unwrap(); + assert_eq!(&expected, result); + } + }; } - #[test] - fn match_scalar_pattern() { - let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None]; - let array = StringArray::from(values); - let pattern = Scalar::new(StringArray::from(vec![r"x.*-(\d*)-.*"; 1])); - let flags = Scalar::new(StringArray::from(vec!["i"; 1])); - let actual = regexp_match(&array, &pattern, Some(&flags)).unwrap(); - let elem_builder: GenericStringBuilder = GenericStringBuilder::with_capacity(0, 0); - let mut expected_builder = ListBuilder::new(elem_builder); - expected_builder.append(false); - expected_builder.values().append_value("7"); - expected_builder.append(true); - expected_builder.append(false); - expected_builder.append(false); - let expected = expected_builder.finish(); - let result = actual.as_any().downcast_ref::().unwrap(); - assert_eq!(&expected, result); - - // No flag - let values = vec![Some("abc-005-def"), Some("x-7-5"), Some("X545"), None]; - let array = StringArray::from(values); - let actual = regexp_match(&array, &pattern, None).unwrap(); - let result = actual.as_any().downcast_ref::().unwrap(); - assert_eq!(&expected, result); + test_match_single_group_with_flags!( + match_single_group_with_flags_string, + vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None], + vec![r"x.*-(\d*)-.*"; 4], + vec!["i"; 4], + StringArray, + GenericStringBuilder, + [None, Some("7"), None, None] + ); + test_match_single_group_with_flags!( + match_single_group_with_flags_stringview, + vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None], + vec![r"x.*-(\d*)-.*"; 4], + vec!["i"; 4], + StringViewArray, + StringViewBuilder, + [None, Some("7"), None, None] + ); + + macro_rules! 
test_match_scalar_pattern { + ($test_name:ident, $values:expr, $pattern:expr, $flag:expr, $array_type:ty, $builder_type:ty, $expected:expr) => { + #[test] + fn $test_name() { + let array: $array_type = <$array_type>::from($values); + + let pattern_scalar = Scalar::new(<$array_type>::from(vec![$pattern; 1])); + let flag_scalar = Scalar::new(<$array_type>::from(vec![$flag; 1])); + + let actual = regexp_match(&array, &pattern_scalar, Some(&flag_scalar)).unwrap(); + + let elem_builder: $builder_type = <$builder_type>::new(); + let mut expected_builder = ListBuilder::new(elem_builder); + + for val in $expected { + match val { + Some(v) => { + expected_builder.values().append_value(v); + expected_builder.append(true); + } + None => expected_builder.append(false), + } + } + + let expected = expected_builder.finish(); + let result = actual.as_any().downcast_ref::().unwrap(); + assert_eq!(&expected, result); + } + }; } - #[test] - fn match_scalar_no_pattern() { - let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None]; - let array = StringArray::from(values); - let pattern = Scalar::new(new_null_array(&DataType::Utf8, 1)); - let actual = regexp_match(&array, &pattern, None).unwrap(); - let elem_builder: GenericStringBuilder = GenericStringBuilder::with_capacity(0, 0); - let mut expected_builder = ListBuilder::new(elem_builder); - expected_builder.append(false); - expected_builder.append(false); - expected_builder.append(false); - expected_builder.append(false); - let expected = expected_builder.finish(); - let result = actual.as_any().downcast_ref::().unwrap(); - assert_eq!(&expected, result); + test_match_scalar_pattern!( + match_scalar_pattern_string_with_flags, + vec![ + Some("abc-005-def"), + Some("x-7-5"), + Some("X-0-Y"), + Some("X545"), + None + ], + r"x.*-(\d*)-.*", + Some("i"), + StringArray, + GenericStringBuilder, + [None, Some("7"), Some("0"), None, None] + ); + test_match_scalar_pattern!( + match_scalar_pattern_stringview_with_flags, + vec![ + 
Some("abc-005-def"), + Some("x-7-5"), + Some("X-0-Y"), + Some("X545"), + None + ], + r"x.*-(\d*)-.*", + Some("i"), + StringViewArray, + StringViewBuilder, + [None, Some("7"), Some("0"), None, None] + ); + + test_match_scalar_pattern!( + match_scalar_pattern_string_no_flags, + vec![ + Some("abc-005-def"), + Some("x-7-5"), + Some("X-0-Y"), + Some("X545"), + None + ], + r"x.*-(\d*)-.*", + None::<&str>, + StringArray, + GenericStringBuilder, + [None, Some("7"), None, None, None] + ); + test_match_scalar_pattern!( + match_scalar_pattern_stringview_no_flags, + vec![ + Some("abc-005-def"), + Some("x-7-5"), + Some("X-0-Y"), + Some("X545"), + None + ], + r"x.*-(\d*)-.*", + None::<&str>, + StringViewArray, + StringViewBuilder, + [None, Some("7"), None, None, None] + ); + + macro_rules! test_match_scalar_no_pattern { + ($test_name:ident, $values:expr, $array_type:ty, $pattern_type:expr, $builder_type:ty, $expected:expr) => { + #[test] + fn $test_name() { + let array: $array_type = <$array_type>::from($values); + let pattern = Scalar::new(new_null_array(&$pattern_type, 1)); + + let actual = regexp_match(&array, &pattern, None).unwrap(); + + let elem_builder: $builder_type = <$builder_type>::new(); + let mut expected_builder = ListBuilder::new(elem_builder); + + for val in $expected { + match val { + Some(v) => { + expected_builder.values().append_value(v); + expected_builder.append(true); + } + None => expected_builder.append(false), + } + } + + let expected = expected_builder.finish(); + let result = actual.as_any().downcast_ref::().unwrap(); + assert_eq!(&expected, result); + } + }; } - #[test] - fn test_single_group_not_skip_match() { - let array = StringArray::from(vec![Some("foo"), Some("bar")]); - let pattern = GenericStringArray::::from(vec![r"foo"]); - let actual = regexp_match(&array, &pattern, None).unwrap(); - let result = actual.as_any().downcast_ref::().unwrap(); - let elem_builder: GenericStringBuilder = GenericStringBuilder::new(); - let mut expected_builder = 
ListBuilder::new(elem_builder); - expected_builder.values().append_value("foo"); - expected_builder.append(true); - let expected = expected_builder.finish(); - assert_eq!(&expected, result); + test_match_scalar_no_pattern!( + match_scalar_no_pattern_string, + vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None], + StringArray, + DataType::Utf8, + GenericStringBuilder, + [None::<&str>, None, None, None] + ); + test_match_scalar_no_pattern!( + match_scalar_no_pattern_stringview, + vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None], + StringViewArray, + DataType::Utf8View, + StringViewBuilder, + [None::<&str>, None, None, None] + ); + + macro_rules! test_match_single_group_not_skip { + ($test_name:ident, $values:expr, $pattern:expr, $array_type:ty, $builder_type:ty, $expected:expr) => { + #[test] + fn $test_name() { + let array: $array_type = <$array_type>::from($values); + let pattern: $array_type = <$array_type>::from(vec![$pattern]); + + let actual = regexp_match(&array, &pattern, None).unwrap(); + + let elem_builder: $builder_type = <$builder_type>::new(); + let mut expected_builder = ListBuilder::new(elem_builder); + + for val in $expected { + match val { + Some(v) => { + expected_builder.values().append_value(v); + expected_builder.append(true); + } + None => expected_builder.append(false), + } + } + + let expected = expected_builder.finish(); + let result = actual.as_any().downcast_ref::().unwrap(); + assert_eq!(&expected, result); + } + }; } + test_match_single_group_not_skip!( + match_single_group_not_skip_string, + vec![Some("foo"), Some("bar")], + r"foo", + StringArray, + GenericStringBuilder, + [Some("foo")] + ); + test_match_single_group_not_skip!( + match_single_group_not_skip_stringview, + vec![Some("foo"), Some("bar")], + r"foo", + StringViewArray, + StringViewBuilder, + [Some("foo")] + ); + macro_rules! 
test_flag_utf8 { ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => { #[test] diff --git a/arrow-string/src/substring.rs b/arrow-string/src/substring.rs index bfdafb790f39..fa6a47147521 100644 --- a/arrow-string/src/substring.rs +++ b/arrow-string/src/substring.rs @@ -636,7 +636,7 @@ mod tests { let data = ArrayData::builder(DataType::FixedSizeBinary(5)) .len(2) - .add_buffer(Buffer::from(&values[..])) + .add_buffer(Buffer::from(&values)) .offset(1) .null_bit_buffer(Some(Buffer::from(bits_v))) .build() diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml index a0fd96415a1d..8860cd61c5b3 100644 --- a/arrow/Cargo.toml +++ b/arrow/Cargo.toml @@ -54,9 +54,7 @@ arrow-select = { workspace = true } arrow-string = { workspace = true } rand = { version = "0.8", default-features = false, features = ["std", "std_rng"], optional = true } -pyo3 = { version = "0.22.2", default-features = false, optional = true } - -chrono = { workspace = true, optional = true } +pyo3 = { version = "0.23", default-features = false, optional = true } [package.metadata.docs.rs] features = ["prettyprint", "ipc_compression", "ffi", "pyarrow"] @@ -72,7 +70,7 @@ prettyprint = ["arrow-cast/prettyprint"] # not the core arrow code itself. Be aware that `rand` must be kept as # an optional dependency for supporting compile to wasm32-unknown-unknown # target without assuming an environment containing JavaScript. 
-test_utils = ["rand", "dep:chrono"] +test_utils = ["dep:rand"] pyarrow = ["pyo3", "ffi"] # force_validate runs full data validation for all arrays that are created # this is not enabled by default as it is too computationally expensive @@ -87,7 +85,6 @@ chrono = { workspace = true } criterion = { version = "0.5", default-features = false } half = { version = "2.1", default-features = false } rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } -tempfile = { version = "3", default-features = false } serde = { version = "1.0", default-features = false, features = ["derive"] } [build-dependencies] diff --git a/arrow/LICENSE.txt b/arrow/LICENSE.txt new file mode 120000 index 000000000000..4ab43736a839 --- /dev/null +++ b/arrow/LICENSE.txt @@ -0,0 +1 @@ +../LICENSE.txt \ No newline at end of file diff --git a/arrow/NOTICE.txt b/arrow/NOTICE.txt new file mode 120000 index 000000000000..eb9f24e040b5 --- /dev/null +++ b/arrow/NOTICE.txt @@ -0,0 +1 @@ +../NOTICE.txt \ No newline at end of file diff --git a/arrow/README.md b/arrow/README.md index 557a0b474e4b..79aefaae9053 100644 --- a/arrow/README.md +++ b/arrow/README.md @@ -25,7 +25,7 @@ This crate contains the official Native Rust implementation of [Apache Arrow][arrow] in memory format, governed by the Apache Software Foundation. The [API documentation](https://docs.rs/arrow/latest) contains examples and full API. -There are several [examples](https://github.com/apache/arrow-rs/tree/master/arrow/examples) to start from as well. +There are several [examples](https://github.com/apache/arrow-rs/tree/main/arrow/examples) to start from as well. The API documentation for most recent, unreleased code is available [here](https://arrow.apache.org/rust/arrow/index.html). @@ -37,7 +37,7 @@ This crate is tested with the latest stable version of Rust. We do not currently The `arrow` crate follows the [SemVer standard] defined by Cargo and works well within the Rust crate ecosystem. 
See the [repository README] for more details on -the release schedule and version. +the release schedule, version and deprecation policy. [SemVer standard]: https://doc.rust-lang.org/cargo/reference/semver.html [repository README]: https://github.com/apache/arrow-rs @@ -57,7 +57,7 @@ The `arrow` crate provides the following features which may be enabled in your ` - `ipc` (default) - support for reading [Arrow IPC Format](https://arrow.apache.org/docs/format/Columnar.html#serialization-and-interprocess-communication-ipc), also used as the wire protocol in [arrow-flight](https://crates.io/crates/arrow-flight) - `ipc_compression` - Enables reading and writing compressed IPC streams (also enables `ipc`) - `prettyprint` - support for formatting record batches as textual columns - implementations of some [compute](https://github.com/apache/arrow-rs/tree/master/arrow/src/compute/kernels) + implementations of some [compute](https://github.com/apache/arrow-rs/tree/main/arrow/src/compute/kernels) - `chrono-tz` - support of parsing timezone using [chrono-tz](https://docs.rs/chrono-tz/0.6.0/chrono_tz/) - `ffi` - bindings for the Arrow C [C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html) - `pyarrow` - bindings for pyo3 to call arrow-rs from python diff --git a/arrow/benches/cast_kernels.rs b/arrow/benches/cast_kernels.rs index ec7990d3d764..5c4fcff13dee 100644 --- a/arrow/benches/cast_kernels.rs +++ b/arrow/benches/cast_kernels.rs @@ -250,6 +250,9 @@ fn add_benchmark(c: &mut Criterion) { c.bench_function("cast decimal128 to decimal128 512", |b| { b.iter(|| cast_array(&decimal128_array, DataType::Decimal128(30, 5))) }); + c.bench_function("cast decimal128 to decimal128 512 lower precision", |b| { + b.iter(|| cast_array(&decimal128_array, DataType::Decimal128(6, 5))) + }); c.bench_function("cast decimal128 to decimal256 512", |b| { b.iter(|| cast_array(&decimal128_array, DataType::Decimal256(50, 5))) }); diff --git a/arrow/benches/concatenate_kernel.rs 
b/arrow/benches/concatenate_kernel.rs index 0c553f8b3f3c..034f5f2a305c 100644 --- a/arrow/benches/concatenate_kernel.rs +++ b/arrow/benches/concatenate_kernel.rs @@ -86,14 +86,14 @@ fn add_benchmark(c: &mut Criterion) { }); let v1 = FixedSizeListArray::try_new( - Arc::new(Field::new("item", DataType::Int32, true)), + Arc::new(Field::new_list_field(DataType::Int32, true)), 1024, Arc::new(create_primitive_array::(1024 * 1024, 0.0)), None, ) .unwrap(); let v2 = FixedSizeListArray::try_new( - Arc::new(Field::new("item", DataType::Int32, true)), + Arc::new(Field::new_list_field(DataType::Int32, true)), 1024, Arc::new(create_primitive_array::(1024 * 1024, 0.0)), None, diff --git a/arrow/benches/json_reader.rs b/arrow/benches/json_reader.rs index 8f3898c51f9d..c698a93fe869 100644 --- a/arrow/benches/json_reader.rs +++ b/arrow/benches/json_reader.rs @@ -102,22 +102,22 @@ fn small_bench_list(c: &mut Criterion) { let schema = Arc::new(Schema::new(vec![ Field::new( "c1", - DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), + DataType::List(Arc::new(Field::new_list_field(DataType::Utf8, true))), true, ), Field::new( "c2", - DataType::List(Arc::new(Field::new("item", DataType::Float64, true))), + DataType::List(Arc::new(Field::new_list_field(DataType::Float64, true))), true, ), Field::new( "c3", - DataType::List(Arc::new(Field::new("item", DataType::UInt32, true))), + DataType::List(Arc::new(Field::new_list_field(DataType::UInt32, true))), true, ), Field::new( "c4", - DataType::List(Arc::new(Field::new("item", DataType::Boolean, true))), + DataType::List(Arc::new(Field::new_list_field(DataType::Boolean, true))), true, ), ])); diff --git a/arrow/benches/lexsort.rs b/arrow/benches/lexsort.rs index cd952299df47..bb1c6081eaf9 100644 --- a/arrow/benches/lexsort.rs +++ b/arrow/benches/lexsort.rs @@ -83,7 +83,7 @@ impl Column { Column::RequiredI32List => { let field = Field::new( "_1", - DataType::List(Arc::new(Field::new("item", DataType::Int32, false))), + 
DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))), true, ); create_random_array(&field, size, 0., 1.).unwrap() @@ -91,7 +91,7 @@ impl Column { Column::OptionalI32List => { let field = Field::new( "_1", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))), true, ); create_random_array(&field, size, 0.2, 1.).unwrap() @@ -99,7 +99,7 @@ impl Column { Column::Required4CharStringList => { let field = Field::new( "_1", - DataType::List(Arc::new(Field::new("item", DataType::Utf8, false))), + DataType::List(Arc::new(Field::new_list_field(DataType::Utf8, false))), true, ); create_random_array(&field, size, 0., 1.).unwrap() @@ -107,7 +107,7 @@ impl Column { Column::Optional4CharStringList => { let field = Field::new( "_1", - DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), + DataType::List(Arc::new(Field::new_list_field(DataType::Utf8, true))), true, ); create_random_array(&field, size, 0.2, 1.).unwrap() diff --git a/arrow/examples/builders.rs b/arrow/examples/builders.rs index 5c8cd51c55a0..8043ad82fca6 100644 --- a/arrow/examples/builders.rs +++ b/arrow/examples/builders.rs @@ -76,7 +76,7 @@ fn main() { let array_data = ArrayData::builder(DataType::Utf8) .len(3) .add_buffer(Buffer::from(offsets.to_byte_slice())) - .add_buffer(Buffer::from(&values[..])) + .add_buffer(Buffer::from(&values)) .null_bit_buffer(Some(Buffer::from([0b00000101]))) .build() .unwrap(); @@ -97,7 +97,7 @@ fn main() { let value_offsets = Buffer::from([0, 3, 6, 8].to_byte_slice()); // Construct a list array from the above two - let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let list_data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))); let list_data = ArrayData::builder(list_data_type) .len(3) .add_buffer(value_offsets) diff --git a/arrow/src/lib.rs b/arrow/src/lib.rs index 5002e5bf181a..7fc5acdc1b19 
100644 --- a/arrow/src/lib.rs +++ b/arrow/src/lib.rs @@ -336,7 +336,7 @@ //! //! If you think you have found an instance where this is possible, please file //! a ticket in our [issue tracker] and it will be triaged and fixed. For more information on -//! arrow's use of unsafe, see [here](https://github.com/apache/arrow-rs/tree/master/arrow#safety). +//! arrow's use of unsafe, see [here](https://github.com/apache/arrow-rs/tree/main/arrow#safety). //! //! # Higher-level Processing //! diff --git a/arrow/src/pyarrow.rs b/arrow/src/pyarrow.rs index 6effe1c03e01..4ccbd0541d3f 100644 --- a/arrow/src/pyarrow.rs +++ b/arrow/src/pyarrow.rs @@ -111,7 +111,7 @@ impl IntoPyArrow for T { } fn validate_class(expected: &str, value: &Bound) -> PyResult<()> { - let pyarrow = PyModule::import_bound(value.py(), "pyarrow")?; + let pyarrow = PyModule::import(value.py(), "pyarrow")?; let class = pyarrow.getattr(expected)?; if !value.is_instance(&class)? { let expected_module = class.getattr("__module__")?.extract::()?; @@ -177,7 +177,7 @@ impl ToPyArrow for DataType { fn to_pyarrow(&self, py: Python) -> PyResult { let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?; let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; - let module = py.import_bound("pyarrow")?; + let module = py.import("pyarrow")?; let class = module.getattr("DataType")?; let dtype = class.call_method1("_import_from_c", (c_schema_ptr as Py_uintptr_t,))?; Ok(dtype.into()) @@ -213,7 +213,7 @@ impl ToPyArrow for Field { fn to_pyarrow(&self, py: Python) -> PyResult { let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?; let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; - let module = py.import_bound("pyarrow")?; + let module = py.import("pyarrow")?; let class = module.getattr("Field")?; let dtype = class.call_method1("_import_from_c", (c_schema_ptr as Py_uintptr_t,))?; Ok(dtype.into()) @@ -249,7 +249,7 @@ impl ToPyArrow for Schema { fn to_pyarrow(&self, py: Python) -> PyResult { let 
c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?; let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; - let module = py.import_bound("pyarrow")?; + let module = py.import("pyarrow")?; let class = module.getattr("Schema")?; let schema = class.call_method1("_import_from_c", (c_schema_ptr as Py_uintptr_t,))?; Ok(schema.into()) @@ -309,7 +309,7 @@ impl ToPyArrow for ArrayData { let array = FFI_ArrowArray::new(self); let schema = FFI_ArrowSchema::try_from(self.data_type()).map_err(to_py_err)?; - let module = py.import_bound("pyarrow")?; + let module = py.import("pyarrow")?; let class = module.getattr("Array")?; let array = class.call_method1( "_import_from_c", @@ -318,7 +318,7 @@ impl ToPyArrow for ArrayData { addr_of!(schema) as Py_uintptr_t, ), )?; - Ok(array.to_object(py)) + Ok(array.unbind()) } } @@ -335,7 +335,7 @@ impl ToPyArrow for Vec { .iter() .map(|v| v.to_pyarrow(py)) .collect::>>()?; - Ok(values.to_object(py)) + Ok(PyList::new(py, values)?.unbind().into()) } } @@ -451,7 +451,7 @@ impl FromPyArrow for ArrowArrayStreamReader { // make the conversion through PyArrow's private API // this changes the pointer's memory and is thus unsafe. 
// In particular, `_export_to_c` can go out of bounds - let args = PyTuple::new_bound(value.py(), [stream_ptr as Py_uintptr_t]); + let args = PyTuple::new(value.py(), [stream_ptr as Py_uintptr_t])?; value.call_method1("_export_to_c", args)?; let stream_reader = ArrowArrayStreamReader::try_new(stream) @@ -469,9 +469,9 @@ impl IntoPyArrow for Box { let mut stream = FFI_ArrowArrayStream::new(self); let stream_ptr = (&mut stream) as *mut FFI_ArrowArrayStream; - let module = py.import_bound("pyarrow")?; + let module = py.import("pyarrow")?; let class = module.getattr("RecordBatchReader")?; - let args = PyTuple::new_bound(py, [stream_ptr as Py_uintptr_t]); + let args = PyTuple::new(py, [stream_ptr as Py_uintptr_t])?; let reader = class.call_method1("_import_from_c", args)?; Ok(PyObject::from(reader)) @@ -500,11 +500,17 @@ impl<'source, T: FromPyArrow> FromPyObject<'source> for PyArrowType { } } -impl IntoPy for PyArrowType { - fn into_py(self, py: Python) -> PyObject { +impl<'py, T: IntoPyArrow> IntoPyObject<'py> for PyArrowType { + type Target = PyAny; + + type Output = Bound<'py, Self::Target>; + + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { match self.0.into_pyarrow(py) { - Ok(obj) => obj, - Err(err) => err.to_object(py), + Ok(obj) => Result::Ok(obj.into_bound(py)), + Err(err) => Result::Err(err), } } } diff --git a/arrow/src/util/data_gen.rs b/arrow/src/util/data_gen.rs index 56bbdefd522d..5f63812e51c0 100644 --- a/arrow/src/util/data_gen.rs +++ b/arrow/src/util/data_gen.rs @@ -538,7 +538,7 @@ mod tests { Field::new("a", DataType::Int32, false), Field::new( "b", - DataType::List(Arc::new(Field::new("item", DataType::LargeUtf8, false))), + DataType::List(Arc::new(Field::new_list_field(DataType::LargeUtf8, false))), false, ), Field::new("a", DataType::Int32, false), @@ -569,10 +569,8 @@ mod tests { Field::new("b", DataType::Boolean, true), Field::new( "c", - DataType::LargeList(Arc::new(Field::new( - "item", - 
DataType::List(Arc::new(Field::new( - "item", + DataType::LargeList(Arc::new(Field::new_list_field( + DataType::List(Arc::new(Field::new_list_field( DataType::FixedSizeBinary(6), true, ))), diff --git a/arrow/tests/array_cast.rs b/arrow/tests/array_cast.rs index 8f86cbeab717..ef5ca6041700 100644 --- a/arrow/tests/array_cast.rs +++ b/arrow/tests/array_cast.rs @@ -315,7 +315,7 @@ fn make_fixed_size_list_array() -> FixedSizeListArray { // Construct a fixed size list array from the above two let list_data_type = - DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 2); + DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int32, true)), 2); let list_data = ArrayData::builder(list_data_type) .len(5) .add_child_data(value_data) @@ -325,11 +325,11 @@ fn make_fixed_size_list_array() -> FixedSizeListArray { } fn make_fixed_size_binary_array() -> FixedSizeBinaryArray { - let values: [u8; 15] = *b"hellotherearrow"; + let values: &[u8; 15] = b"hellotherearrow"; let array_data = ArrayData::builder(DataType::FixedSizeBinary(5)) .len(3) - .add_buffer(Buffer::from(&values[..])) + .add_buffer(Buffer::from(values)) .build() .unwrap(); FixedSizeBinaryArray::from(array_data) @@ -348,7 +348,7 @@ fn make_list_array() -> ListArray { let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]); // Construct a list array from the above two - let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let list_data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); let list_data = ArrayData::builder(list_data_type) .len(3) .add_buffer(value_offsets) @@ -371,7 +371,8 @@ fn make_large_list_array() -> LargeListArray { let value_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8]); // Construct a list array from the above two - let list_data_type = DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, true))); + let list_data_type = + 
DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, true))); let list_data = ArrayData::builder(list_data_type) .len(3) .add_buffer(value_offsets) @@ -466,12 +467,12 @@ fn get_all_types() -> Vec { LargeBinary, Utf8, LargeUtf8, - List(Arc::new(Field::new("item", DataType::Int8, true))), - List(Arc::new(Field::new("item", DataType::Utf8, true))), - FixedSizeList(Arc::new(Field::new("item", DataType::Int8, true)), 10), - FixedSizeList(Arc::new(Field::new("item", DataType::Utf8, false)), 10), - LargeList(Arc::new(Field::new("item", DataType::Int8, true))), - LargeList(Arc::new(Field::new("item", DataType::Utf8, false))), + List(Arc::new(Field::new_list_field(DataType::Int8, true))), + List(Arc::new(Field::new_list_field(DataType::Utf8, true))), + FixedSizeList(Arc::new(Field::new_list_field(DataType::Int8, true)), 10), + FixedSizeList(Arc::new(Field::new_list_field(DataType::Utf8, false)), 10), + LargeList(Arc::new(Field::new_list_field(DataType::Int8, true))), + LargeList(Arc::new(Field::new_list_field(DataType::Utf8, false))), Struct(Fields::from(vec![ Field::new("f1", DataType::Int32, true), Field::new("f2", DataType::Utf8, true), diff --git a/arrow/tests/array_equal.rs b/arrow/tests/array_equal.rs index 7ed4dae1ed08..94fb85030bf3 100644 --- a/arrow/tests/array_equal.rs +++ b/arrow/tests/array_equal.rs @@ -409,8 +409,7 @@ fn test_empty_offsets_list_equal() { let values = Int32Array::from(empty); let empty_offsets: [u8; 0] = []; - let a: ListArray = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new( - "item", + let a: ListArray = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new_list_field( DataType::Int32, true, )))) @@ -422,8 +421,7 @@ fn test_empty_offsets_list_equal() { .unwrap() .into(); - let b: ListArray = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new( - "item", + let b: ListArray = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new_list_field( DataType::Int32, true, )))) @@ -437,8 +435,7 @@ fn 
test_empty_offsets_list_equal() { test_equal(&a, &b, true); - let c: ListArray = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new( - "item", + let c: ListArray = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new_list_field( DataType::Int32, true, )))) @@ -475,8 +472,7 @@ fn test_list_null() { // a list where the nullness of values is determined by the list's bitmap let c_values = Int32Array::from(vec![1, 2, -1, -2, 3, 4, -3, -4]); - let c: ListArray = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new( - "item", + let c: ListArray = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new_list_field( DataType::Int32, true, )))) @@ -498,8 +494,7 @@ fn test_list_null() { None, None, ]); - let d: ListArray = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new( - "item", + let d: ListArray = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new_list_field( DataType::Int32, true, )))) diff --git a/arrow/tests/array_transform.rs b/arrow/tests/array_transform.rs index 08f23c200d52..c6de9f4a3417 100644 --- a/arrow/tests/array_transform.rs +++ b/arrow/tests/array_transform.rs @@ -600,7 +600,7 @@ fn test_list_append() { ]); let list_value_offsets = Buffer::from_slice_ref([0i32, 3, 5, 11, 13, 13, 15, 15, 17]); let expected_list_data = ArrayData::try_new( - DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))), 8, None, 0, @@ -677,7 +677,7 @@ fn test_list_nulls_append() { let list_value_offsets = Buffer::from_slice_ref([0, 3, 5, 5, 13, 15, 15, 15, 19, 19, 19, 19, 23]); let expected_list_data = ArrayData::try_new( - DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))), 12, Some(Buffer::from(&[0b11011011, 0b1110])), 0, @@ -940,7 +940,7 @@ fn test_list_of_strings_append() { ]); let list_value_offsets = Buffer::from_slice_ref([0, 3, 5, 6, 9, 10, 13]); let expected_list_data = 
ArrayData::try_new( - DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), + DataType::List(Arc::new(Field::new_list_field(DataType::Utf8, true))), 6, None, 0, @@ -1141,7 +1141,7 @@ fn test_fixed_size_list_append() { Some(12), ]); let expected_fixed_size_list_data = ArrayData::try_new( - DataType::FixedSizeList(Arc::new(Field::new("item", DataType::UInt16, true)), 2), + DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::UInt16, true)), 2), 12, Some(Buffer::from(&[0b11011101, 0b101])), 0, diff --git a/arrow/tests/shrink_to_fit.rs b/arrow/tests/shrink_to_fit.rs new file mode 100644 index 000000000000..5d7c2cf98bc9 --- /dev/null +++ b/arrow/tests/shrink_to_fit.rs @@ -0,0 +1,159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{ + array::{Array, ArrayRef, ListArray, PrimitiveArray}, + buffer::OffsetBuffer, + datatypes::{Field, UInt8Type}, +}; + +/// Test that `shrink_to_fit` frees memory after concatenating a large number of arrays. 
+#[test] +fn test_shrink_to_fit_after_concat() { + let array_len = 6_000; + let num_concats = 100; + + let primitive_array: PrimitiveArray = (0..array_len) + .map(|v| (v % 255) as u8) + .collect::>() + .into(); + let primitive_array: ArrayRef = Arc::new(primitive_array); + + let list_array: ArrayRef = Arc::new(ListArray::new( + Field::new_list_field(primitive_array.data_type().clone(), false).into(), + OffsetBuffer::from_lengths([primitive_array.len()]), + primitive_array.clone(), + None, + )); + + // Num bytes allocated globally and by this thread, respectively. + let (concatenated, _bytes_allocated_globally, bytes_allocated_by_this_thread) = + memory_use(|| { + let mut concatenated = concatenate(num_concats, list_array.clone()); + concatenated.shrink_to_fit(); // This is what we're testing! + dbg!(concatenated.data_type()); + concatenated + }); + let expected_len = num_concats * array_len; + assert_eq!(bytes_used(concatenated.clone()), expected_len); + eprintln!("The concatenated array is {expected_len} B long. Amount of memory used by this thread: {bytes_allocated_by_this_thread} B"); + + assert!( + expected_len <= bytes_allocated_by_this_thread, + "We must allocate at least as much space as the concatenated array" + ); + assert!( + bytes_allocated_by_this_thread <= expected_len + expected_len / 100, + "We shouldn't have more than 1% memory overhead. 
In fact, we are using {bytes_allocated_by_this_thread} B of memory for {expected_len} B of data" + ); +} + +fn concatenate(num_times: usize, array: ArrayRef) -> ArrayRef { + let mut concatenated = array.clone(); + for _ in 0..num_times - 1 { + concatenated = arrow::compute::kernels::concat::concat(&[&*concatenated, &*array]).unwrap(); + } + concatenated +} + +fn bytes_used(array: ArrayRef) -> usize { + let mut array = array; + loop { + match array.data_type() { + arrow::datatypes::DataType::UInt8 => break, + arrow::datatypes::DataType::List(_) => { + let list = array.as_any().downcast_ref::().unwrap(); + array = list.values().clone(); + } + _ => unreachable!(), + } + } + + array.len() +} + +// --- Memory tracking --- + +use std::{ + alloc::Layout, + sync::{ + atomic::{AtomicUsize, Ordering::Relaxed}, + Arc, + }, +}; + +static LIVE_BYTES_GLOBAL: AtomicUsize = AtomicUsize::new(0); + +thread_local! { + static LIVE_BYTES_IN_THREAD: AtomicUsize = const { AtomicUsize::new(0) } ; +} + +pub struct TrackingAllocator { + allocator: std::alloc::System, +} + +#[global_allocator] +pub static GLOBAL_ALLOCATOR: TrackingAllocator = TrackingAllocator { + allocator: std::alloc::System, +}; + +#[allow(unsafe_code)] +// SAFETY: +// We just do book-keeping and then let another allocator do all the actual work. 
+unsafe impl std::alloc::GlobalAlloc for TrackingAllocator { + #[allow(clippy::let_and_return)] + unsafe fn alloc(&self, layout: Layout) -> *mut u8 { + // SAFETY: + // Just deferring + let ptr = unsafe { self.allocator.alloc(layout) }; + if !ptr.is_null() { + LIVE_BYTES_IN_THREAD.with(|bytes| bytes.fetch_add(layout.size(), Relaxed)); + LIVE_BYTES_GLOBAL.fetch_add(layout.size(), Relaxed); + } + ptr + } + + unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) { + LIVE_BYTES_IN_THREAD.with(|bytes| bytes.fetch_sub(layout.size(), Relaxed)); + LIVE_BYTES_GLOBAL.fetch_sub(layout.size(), Relaxed); + + // SAFETY: + // Just deferring + unsafe { self.allocator.dealloc(ptr, layout) }; + } + + // No need to override `alloc_zeroed` or `realloc`, + // since they both by default just defer to `alloc` and `dealloc`. +} + +fn live_bytes_local() -> usize { + LIVE_BYTES_IN_THREAD.with(|bytes| bytes.load(Relaxed)) +} + +fn live_bytes_global() -> usize { + LIVE_BYTES_GLOBAL.load(Relaxed) +} + +/// Returns `(num_bytes_allocated, num_bytes_allocated_by_this_thread)`. +fn memory_use(run: impl Fn() -> R) -> (R, usize, usize) { + let used_bytes_start_local = live_bytes_local(); + let used_bytes_start_global = live_bytes_global(); + let ret = run(); + let bytes_used_local = live_bytes_local() - used_bytes_start_local; + let bytes_used_global = live_bytes_global() - used_bytes_start_global; + (ret, bytes_used_global, bytes_used_local) +} diff --git a/dev/release/README.md b/dev/release/README.md index d2d9e48bbb6b..6e6817bffb12 100644 --- a/dev/release/README.md +++ b/dev/release/README.md @@ -27,7 +27,7 @@ This file documents the release process for the "Rust Arrow Crates": `arrow`, `a The Rust Arrow Crates are interconnected (e.g. `parquet` has an optional dependency on `arrow`) so we increment and release all of them together. 
-If any code has been merged to master that has a breaking API change, as defined +If any code has been merged to main that has a breaking API change, as defined in [Rust RFC 1105] he major version number is incremented (e.g. `9.0.2` to `10.0.2`). Otherwise the new minor version incremented (e.g. `9.0.2` to `9.1.0`). @@ -46,19 +46,19 @@ crates.io, the Rust ecosystem's package manager. We create a `CHANGELOG.md` so our users know what has been changed between releases. The CHANGELOG is created automatically using -[update_change_log.sh](https://github.com/apache/arrow-rs/blob/master/dev/release/update_change_log.sh) +[update_change_log.sh](https://github.com/apache/arrow-rs/blob/main/dev/release/update_change_log.sh) This script creates a changelog using github issues and the labels associated with them. ## Prepare CHANGELOG and version: -Now prepare a PR to update `CHANGELOG.md` and versions on `master` to reflect the planned release. +Now prepare a PR to update `CHANGELOG.md` and versions on `main` to reflect the planned release. Do this in the root of this repository. 
For example [#2323](https://github.com/apache/arrow-rs/pull/2323) ```bash -git checkout master +git checkout main git pull git checkout -b @@ -72,6 +72,8 @@ export ARROW_GITHUB_API_TOKEN= # manually edit ./dev/release/update_change_log.sh to reflect the release version # create the changelog ./dev/release/update_change_log.sh +# commit the initial changes +git commit -a -m 'Create changelog' # run automated script to copy labels to issues based on referenced PRs # (NOTE 1: this must be done by a committer / other who has @@ -80,14 +82,12 @@ export ARROW_GITHUB_API_TOKEN= # NOTE 2: this must be done after creating the initial CHANGELOG file python dev/release/label_issues.py -# review change log / edit issues and labels if needed, rerun -git commit -a -m 'Create changelog' - -# Manually edit ./dev/release/update_change_log.sh to reflect the release version -# Create the changelog +# review change log / edit issues and labels if needed, rerun, repeat as necessary +# note you need to revert changes to CHANGELOG-old.md if you want to rerun the script CHANGELOG_GITHUB_TOKEN= ./dev/release/update_change_log.sh -# Review change log / edit issues and labels if needed, rerun -git commit -a -m 'Create changelog' + +# Commit the changes +git commit -a -m 'Update changelog' git push ``` @@ -96,7 +96,7 @@ Note that when reviewing the change log, rather than editing the `CHANGELOG.md`, it is preferred to update the issues and their labels (e.g. add `invalid` label to exclude them from release notes) -Merge this PR to `master` prior to the next step. +Merge this PR to `main` prior to the next step. 
## Prepare release candidate tarball @@ -115,7 +115,7 @@ Create and push the tag thusly: ```shell git fetch apache -git tag apache/master +git tag apache/main # push tag to apache git push apache ``` diff --git a/dev/release/create-tarball.sh b/dev/release/create-tarball.sh index a77ddbe75701..8b92509104c8 100755 --- a/dev/release/create-tarball.sh +++ b/dev/release/create-tarball.sh @@ -109,7 +109,7 @@ The vote will be open for at least 72 hours. [1]: https://github.com/apache/arrow-rs/tree/${release_hash} [2]: ${url} [3]: https://github.com/apache/arrow-rs/blob/${release_hash}/CHANGELOG.md -[4]: https://github.com/apache/arrow-rs/blob/master/dev/release/verify-release-candidate.sh +[4]: https://github.com/apache/arrow-rs/blob/main/dev/release/verify-release-candidate.sh MAIL echo "---------------------------------------------------------" diff --git a/dev/release/update_change_log.sh b/dev/release/update_change_log.sh index ab6460659d73..4a2f5e3f1987 100755 --- a/dev/release/update_change_log.sh +++ b/dev/release/update_change_log.sh @@ -29,8 +29,8 @@ set -e -SINCE_TAG="53.1.0" -FUTURE_RELEASE="53.2.0" +SINCE_TAG="53.3.0" +FUTURE_RELEASE="54.0.0" SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" SOURCE_TOP_DIR="$(cd "${SOURCE_DIR}/../../" && pwd)" diff --git a/object_store/CHANGELOG-old.md b/object_store/CHANGELOG-old.md index 28dbde4e7b7f..c42689240dd9 100644 --- a/object_store/CHANGELOG-old.md +++ b/object_store/CHANGELOG-old.md @@ -19,6 +19,45 @@ # Historical Changelog + +## [object_store_0.11.1](https://github.com/apache/arrow-rs/tree/object_store_0.11.1) (2024-10-15) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.11.0...object_store_0.11.1) + +**Implemented enhancements:** + +- There is no way to pass object store client options as environment variables [\#6333](https://github.com/apache/arrow-rs/issues/6333) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Better Document Backoff Algorithm 
[\#6324](https://github.com/apache/arrow-rs/issues/6324) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Add direction to `list_with_offset` [\#6274](https://github.com/apache/arrow-rs/issues/6274) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Support server-side encryption with customer-provided keys \(SSE-C\) [\#6229](https://github.com/apache/arrow-rs/issues/6229) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Fixed bugs:** + +- \[object-store\] Requested tokio version is too old - does not compile [\#6458](https://github.com/apache/arrow-rs/issues/6458) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Azure SAS tokens are visible when retry errors are logged via object\_store [\#6322](https://github.com/apache/arrow-rs/issues/6322) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Merged pull requests:** + +- object\_store: fix typo in with\_connect\_timeout\_disabled that actually disabled non-connect timeouts [\#6563](https://github.com/apache/arrow-rs/pull/6563) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([adriangb](https://github.com/adriangb)) +- object\_store: Clarify what is a prefix in list\(\) documentation [\#6520](https://github.com/apache/arrow-rs/pull/6520) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([progval](https://github.com/progval)) +- object\_store: enable lint `unreachable_pub` [\#6512](https://github.com/apache/arrow-rs/pull/6512) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([ByteBaker](https://github.com/ByteBaker)) +- \[object\_store\] Retry S3 requests with 200 response with "Error" in body [\#6508](https://github.com/apache/arrow-rs/pull/6508) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([PeterKeDer](https://github.com/PeterKeDer)) +- \[object-store\] Require tokio 1.29.0. 
[\#6459](https://github.com/apache/arrow-rs/pull/6459) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([ashtuchkin](https://github.com/ashtuchkin)) +- feat: expose HTTP/2 max frame size in `object_store` [\#6442](https://github.com/apache/arrow-rs/pull/6442) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([crepererum](https://github.com/crepererum)) +- Derive `Clone` for `object_store::aws::AmazonS3` [\#6414](https://github.com/apache/arrow-rs/pull/6414) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([ethe](https://github.com/ethe)) +- object\_score: Support Azure Fabric OAuth Provider [\#6382](https://github.com/apache/arrow-rs/pull/6382) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([RobinLin666](https://github.com/RobinLin666)) +- `object_store::GetOptions` derive `Clone` [\#6361](https://github.com/apache/arrow-rs/pull/6361) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([samuelcolvin](https://github.com/samuelcolvin)) +- \[object\_store\] Propagate env vars as object store client options [\#6334](https://github.com/apache/arrow-rs/pull/6334) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([ccciudatu](https://github.com/ccciudatu)) +- docs\[object\_store\]: clarify the backoff strategy that is actually implemented [\#6325](https://github.com/apache/arrow-rs/pull/6325) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([westonpace](https://github.com/westonpace)) +- fix: azure sas token visible in logs [\#6323](https://github.com/apache/arrow-rs/pull/6323) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([alexwilcoxson-rel](https://github.com/alexwilcoxson-rel)) +- object\_store/delimited: Fix `TrailingEscape` condition [\#6265](https://github.com/apache/arrow-rs/pull/6265) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] 
([Turbo87](https://github.com/Turbo87)) +- fix\(object\_store\): only add encryption headers for SSE-C in get request [\#6260](https://github.com/apache/arrow-rs/pull/6260) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([jiachengdb](https://github.com/jiachengdb)) +- docs: Add parquet\_opendal in related projects [\#6236](https://github.com/apache/arrow-rs/pull/6236) ([Xuanwo](https://github.com/Xuanwo)) +- feat\(object\_store\): add support for server-side encryption with customer-provided keys \(SSE-C\) [\#6230](https://github.com/apache/arrow-rs/pull/6230) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([jiachengdb](https://github.com/jiachengdb)) +- feat: further TLS options on ClientOptions: \#5034 [\#6148](https://github.com/apache/arrow-rs/pull/6148) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([ByteBaker](https://github.com/ByteBaker)) + + + ## [object_store_0.11.0](https://github.com/apache/arrow-rs/tree/object_store_0.11.0) (2024-08-12) [Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.10.2...object_store_0.11.0) diff --git a/object_store/CHANGELOG.md b/object_store/CHANGELOG.md index 95585983572c..0e834c5e2ef2 100644 --- a/object_store/CHANGELOG.md +++ b/object_store/CHANGELOG.md @@ -19,41 +19,42 @@ # Changelog -## [object_store_0.11.1](https://github.com/apache/arrow-rs/tree/object_store_0.11.1) (2024-10-15) +## [object_store_0.11.2](https://github.com/apache/arrow-rs/tree/object_store_0.11.2) (2024-12-20) -[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.11.0...object_store_0.11.1) +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.11.1...object_store_0.11.2) **Implemented enhancements:** -- There is no way to pass object store client options as environment variables [\#6333](https://github.com/apache/arrow-rs/issues/6333) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] -- 
Better Document Backoff Algorithm [\#6324](https://github.com/apache/arrow-rs/issues/6324) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] -- Add direction to `list_with_offset` [\#6274](https://github.com/apache/arrow-rs/issues/6274) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] -- Support server-side encryption with customer-provided keys \(SSE-C\) [\#6229](https://github.com/apache/arrow-rs/issues/6229) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object-store's AzureClient should protect against multiple streams performing put\_block in parallel for the same BLOB path [\#6868](https://github.com/apache/arrow-rs/issues/6868) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Support S3 Put IfMatch [\#6799](https://github.com/apache/arrow-rs/issues/6799) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store Azure Government using OAuth [\#6759](https://github.com/apache/arrow-rs/issues/6759) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Support for AWS Requester Pays buckets [\#6716](https://github.com/apache/arrow-rs/issues/6716) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- \[object-store\]: Implement credential\_process support for S3 [\#6422](https://github.com/apache/arrow-rs/issues/6422) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store: Conditional put and rename\_if\_not\_exist on S3 [\#6285](https://github.com/apache/arrow-rs/issues/6285) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] **Fixed bugs:** -- \[object-store\] Requested tokio version is too old - does not compile [\#6458](https://github.com/apache/arrow-rs/issues/6458) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] -- Azure SAS tokens are visible when retry errors are logged via object\_store 
[\#6322](https://github.com/apache/arrow-rs/issues/6322) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- `object_store` errors when `reqwest` `gzip` feature is enabled [\#6842](https://github.com/apache/arrow-rs/issues/6842) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Multi-part s3 uploads fail when using checksum [\#6793](https://github.com/apache/arrow-rs/issues/6793) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- `with_unsigned_payload` shouldn't generate payload hash [\#6697](https://github.com/apache/arrow-rs/issues/6697) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- \[Object\_store\] min\_ttl is too high for GKE tokens [\#6625](https://github.com/apache/arrow-rs/issues/6625) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store `test_private_bucket` fails - store: "S3", source: BucketNotFound { bucket: "bloxbender" } [\#6600](https://github.com/apache/arrow-rs/issues/6600) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- S3 endpoint and trailing slash result in weird/invalid requests [\#6580](https://github.com/apache/arrow-rs/issues/6580) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] **Merged pull requests:** -- object\_store: fix typo in with\_connect\_timeout\_disabled that actually disabled non-connect timeouts [\#6563](https://github.com/apache/arrow-rs/pull/6563) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([adriangb](https://github.com/adriangb)) -- object\_store: Clarify what is a prefix in list\(\) documentation [\#6520](https://github.com/apache/arrow-rs/pull/6520) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([progval](https://github.com/progval)) -- object\_store: enable lint `unreachable_pub` [\#6512](https://github.com/apache/arrow-rs/pull/6512) 
[[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([ByteBaker](https://github.com/ByteBaker)) -- \[object\_store\] Retry S3 requests with 200 response with "Error" in body [\#6508](https://github.com/apache/arrow-rs/pull/6508) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([PeterKeDer](https://github.com/PeterKeDer)) -- \[object-store\] Require tokio 1.29.0. [\#6459](https://github.com/apache/arrow-rs/pull/6459) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([ashtuchkin](https://github.com/ashtuchkin)) -- feat: expose HTTP/2 max frame size in `object_store` [\#6442](https://github.com/apache/arrow-rs/pull/6442) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([crepererum](https://github.com/crepererum)) -- Derive `Clone` for `object_store::aws::AmazonS3` [\#6414](https://github.com/apache/arrow-rs/pull/6414) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([ethe](https://github.com/ethe)) -- object\_score: Support Azure Fabric OAuth Provider [\#6382](https://github.com/apache/arrow-rs/pull/6382) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([RobinLin666](https://github.com/RobinLin666)) -- `object_store::GetOptions` derive `Clone` [\#6361](https://github.com/apache/arrow-rs/pull/6361) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([samuelcolvin](https://github.com/samuelcolvin)) -- \[object\_store\] Propagate env vars as object store client options [\#6334](https://github.com/apache/arrow-rs/pull/6334) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([ccciudatu](https://github.com/ccciudatu)) -- docs\[object\_store\]: clarify the backoff strategy that is actually implemented [\#6325](https://github.com/apache/arrow-rs/pull/6325) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([westonpace](https://github.com/westonpace)) -- fix: azure 
sas token visible in logs [\#6323](https://github.com/apache/arrow-rs/pull/6323) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([alexwilcoxson-rel](https://github.com/alexwilcoxson-rel)) -- object\_store/delimited: Fix `TrailingEscape` condition [\#6265](https://github.com/apache/arrow-rs/pull/6265) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([Turbo87](https://github.com/Turbo87)) -- fix\(object\_store\): only add encryption headers for SSE-C in get request [\#6260](https://github.com/apache/arrow-rs/pull/6260) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([jiachengdb](https://github.com/jiachengdb)) -- docs: Add parquet\_opendal in related projects [\#6236](https://github.com/apache/arrow-rs/pull/6236) ([Xuanwo](https://github.com/Xuanwo)) -- feat\(object\_store\): add support for server-side encryption with customer-provided keys \(SSE-C\) [\#6230](https://github.com/apache/arrow-rs/pull/6230) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([jiachengdb](https://github.com/jiachengdb)) -- feat: further TLS options on ClientOptions: \#5034 [\#6148](https://github.com/apache/arrow-rs/pull/6148) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([ByteBaker](https://github.com/ByteBaker)) +- Use randomized content ID for Azure multipart uploads [\#6869](https://github.com/apache/arrow-rs/pull/6869) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([avarnon](https://github.com/avarnon)) +- Always explicitly disable `gzip` automatic decompression on reqwest client used by object\_store [\#6843](https://github.com/apache/arrow-rs/pull/6843) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([phillipleblanc](https://github.com/phillipleblanc)) +- object-store: remove S3ConditionalPut::ETagPutIfNotExists [\#6802](https://github.com/apache/arrow-rs/pull/6802) 
[[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([benesch](https://github.com/benesch)) +- Fix multipart uploads with checksums on object locked buckets [\#6794](https://github.com/apache/arrow-rs/pull/6794) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([avantgardnerio](https://github.com/avantgardnerio)) +- Add AuthorityHost to AzureConfigKey [\#6773](https://github.com/apache/arrow-rs/pull/6773) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([zadeluca](https://github.com/zadeluca)) +- object\_store: Add support for requester pays buckets [\#6768](https://github.com/apache/arrow-rs/pull/6768) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([kylebarron](https://github.com/kylebarron)) +- check sign\_payload instead of skip\_signature before computing checksum [\#6698](https://github.com/apache/arrow-rs/pull/6698) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([mherrerarendon](https://github.com/mherrerarendon)) +- Update quick-xml requirement from 0.36.0 to 0.37.0 in /object\_store [\#6687](https://github.com/apache/arrow-rs/pull/6687) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([crepererum](https://github.com/crepererum)) +- Support native S3 conditional writes [\#6682](https://github.com/apache/arrow-rs/pull/6682) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([benesch](https://github.com/benesch)) +- \[object\_store\] fix S3 endpoint and trailing slash result in invalid requests [\#6641](https://github.com/apache/arrow-rs/pull/6641) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([adbmal](https://github.com/adbmal)) +- Lower GCP token min\_ttl to 4 minutes and add backoff to token refresh logic [\#6638](https://github.com/apache/arrow-rs/pull/6638) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] 
([mwylde](https://github.com/mwylde)) +- Remove `test_private_bucket` object\_store test [\#6601](https://github.com/apache/arrow-rs/pull/6601) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([alamb](https://github.com/alamb)) diff --git a/object_store/Cargo.toml b/object_store/Cargo.toml index 86d1392ebf61..992ae6662cdb 100644 --- a/object_store/Cargo.toml +++ b/object_store/Cargo.toml @@ -17,13 +17,13 @@ [package] name = "object_store" -version = "0.11.1" +version = "0.11.2" edition = "2021" license = "MIT/Apache-2.0" readme = "README.md" description = "A generic object store interface for uniformly interacting with AWS S3, Google Cloud Storage, Azure Blob Storage and local files." keywords = ["object", "storage", "cloud"] -repository = "https://github.com/apache/arrow-rs/tree/master/object_store" +repository = "https://github.com/apache/arrow-rs/tree/main/object_store" rust-version = "1.64.0" [package.metadata.docs.rs] @@ -35,13 +35,13 @@ bytes = "1.0" chrono = { version = "0.4.34", default-features = false, features = ["clock"] } futures = "0.3" humantime = "2.1" -itertools = "0.13.0" +itertools = "0.14.0" parking_lot = { version = "0.12" } percent-encoding = "2.1" -snafu = { version = "0.8", default-features = false, features = ["std", "rust_1_61"] } +thiserror = "2.0.2" tracing = { version = "0.1" } url = "2.2" -walkdir = "2" +walkdir = { version = "2", optional = true } # Cloud storage support base64 = { version = "0.22", default-features = false, features = ["std"], optional = true } @@ -55,13 +55,16 @@ ring = { version = "0.17", default-features = false, features = ["std"], optiona rustls-pemfile = { version = "2.0", default-features = false, features = ["std"], optional = true } tokio = { version = "1.29.0", features = ["sync", "macros", "rt", "time", "io-util"] } md-5 = { version = "0.10.6", default-features = false, optional = true } +httparse = { version = "1.8.0", default-features = false, features = ["std"], optional = true 
} [target.'cfg(target_family="unix")'.dev-dependencies] nix = { version = "0.29.0", features = ["fs"] } [features] +default = ["fs"] cloud = ["serde", "serde_json", "quick-xml", "hyper", "reqwest", "reqwest/json", "reqwest/stream", "chrono/serde", "base64", "rand", "ring"] -azure = ["cloud"] +azure = ["cloud", "httparse"] +fs = ["walkdir"] gcp = ["cloud", "rustls-pemfile"] aws = ["cloud", "md-5"] http = ["cloud"] @@ -75,6 +78,10 @@ hyper-util = "0.1" http-body-util = "0.1" rand = "0.8" tempfile = "3.1.0" +regex = "1.11.1" +# The "gzip" feature for reqwest is enabled for an integration test. +reqwest = { version = "0.12", features = ["gzip"] } +http = "1.1.0" [[test]] name = "get_range_file" diff --git a/object_store/dev/release/README.md b/object_store/dev/release/README.md index 4077dcad9653..2dd1f6243c09 100644 --- a/object_store/dev/release/README.md +++ b/object_store/dev/release/README.md @@ -24,10 +24,13 @@ This file documents the release process for the `object_store` crate. -At the time of writing, we release a new version of `object_store` on demand rather than on a regular schedule. +We release a new version of `object_store` according to the schedule listed in +the [main README.md] + +[main README.md]: https://github.com/apache/arrow-rs?tab=readme-ov-file#object_store-crate As we are still in an early phase, we use the 0.x version scheme. If any code has -been merged to master that has a breaking API change, as defined in [Rust RFC 1105] +been merged to main that has a breaking API change, as defined in [Rust RFC 1105] the minor version number is incremented changed (e.g. `0.3.0` to `0.4.0`). Otherwise the patch version is incremented (e.g. `0.3.0` to `0.3.1`). @@ -45,14 +48,14 @@ crates.io, the Rust ecosystem's package manager. We create a `CHANGELOG.md` so our users know what has been changed between releases. 
The CHANGELOG is created automatically using -[update_change_log.sh](https://github.com/apache/arrow-rs/blob/master/object_store/dev/release/update_change_log.sh) +[update_change_log.sh](https://github.com/apache/arrow-rs/blob/main/object_store/dev/release/update_change_log.sh) This script creates a changelog using github issues and the labels associated with them. ## Prepare CHANGELOG and version: -Now prepare a PR to update `CHANGELOG.md` and versions on `master` to reflect the planned release. +Now prepare a PR to update `CHANGELOG.md` and versions on `main` to reflect the planned release. Note this process is done in the `object_store` directory. See [#6227] for an example @@ -62,7 +65,7 @@ Note this process is done in the `object_store` directory. See [#6227] for an e # NOTE: Run commands in object_store sub directory (not main repo checkout) # cd object_store -git checkout master +git checkout main git pull git checkout -b @@ -82,7 +85,7 @@ export CHANGELOG_GITHUB_TOKEN= # Commit changes git commit -a -m 'Create changelog' -# push changes to fork and create a PR to master +# push changes to fork and create a PR to main git push ``` @@ -90,7 +93,7 @@ Note that when reviewing the change log, rather than editing the `CHANGELOG.md`, it is preferred to update the issues and their labels (e.g. add `invalid` label to exclude them from release notes) -Merge this PR to `master` prior to the next step. +Merge this PR to `main` prior to the next step. ## Prepare release candidate tarball @@ -109,7 +112,7 @@ Create and push the tag thusly: ```shell git fetch apache -git tag apache/master +git tag apache/main # push tag to apache git push apache ``` @@ -170,7 +173,7 @@ The vote will be open for at least 72 hours. 
[1]: https://github.com/apache/arrow-rs/tree/b945b15de9085f5961a478d4f35b0c5c3427e248 [2]: https://dist.apache.org/repos/dist/dev/arrow/apache-arrow-object-store-rs-0.11.1-rc1 [3]: https://github.com/apache/arrow-rs/blob/b945b15de9085f5961a478d4f35b0c5c3427e248/object_store/CHANGELOG.md -[4]: https://github.com/apache/arrow-rs/blob/master/object_store/dev/release/verify-release-candidate.sh +[4]: https://github.com/apache/arrow-rs/blob/main/object_store/dev/release/verify-release-candidate.sh ``` For the release to become "official" it needs at least three Apache Arrow PMC members to vote +1 on it. diff --git a/object_store/dev/release/create-tarball.sh b/object_store/dev/release/create-tarball.sh index bbffde89b043..efc26fd0ef0f 100755 --- a/object_store/dev/release/create-tarball.sh +++ b/object_store/dev/release/create-tarball.sh @@ -101,7 +101,7 @@ The vote will be open for at least 72 hours. [1]: https://github.com/apache/arrow-rs/tree/${release_hash} [2]: ${url} [3]: https://github.com/apache/arrow-rs/blob/${release_hash}/object_store/CHANGELOG.md -[4]: https://github.com/apache/arrow-rs/blob/master/object_store/dev/release/verify-release-candidate.sh +[4]: https://github.com/apache/arrow-rs/blob/main/object_store/dev/release/verify-release-candidate.sh MAIL echo "---------------------------------------------------------" diff --git a/object_store/dev/release/update_change_log.sh b/object_store/dev/release/update_change_log.sh index 30724478ae1e..2797b62c0010 100755 --- a/object_store/dev/release/update_change_log.sh +++ b/object_store/dev/release/update_change_log.sh @@ -29,8 +29,8 @@ set -e -SINCE_TAG="object_store_0.11.0" -FUTURE_RELEASE="object_store_0.11.1" +SINCE_TAG="object_store_0.11.1" +FUTURE_RELEASE="object_store_0.11.2" SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" SOURCE_TOP_DIR="$(cd "${SOURCE_DIR}/../../" && pwd)" diff --git a/object_store/src/aws/builder.rs b/object_store/src/aws/builder.rs index eb79f5e6dc28..d29fa782e8ff 100644 
--- a/object_store/src/aws/builder.rs +++ b/object_store/src/aws/builder.rs @@ -32,7 +32,6 @@ use itertools::Itertools; use md5::{Digest, Md5}; use reqwest::header::{HeaderMap, HeaderValue}; use serde::{Deserialize, Serialize}; -use snafu::{OptionExt, ResultExt, Snafu}; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; @@ -43,46 +42,46 @@ use url::Url; static DEFAULT_METADATA_ENDPOINT: &str = "http://169.254.169.254"; /// A specialized `Error` for object store-related errors -#[derive(Debug, Snafu)] +#[derive(Debug, thiserror::Error)] enum Error { - #[snafu(display("Missing bucket name"))] + #[error("Missing bucket name")] MissingBucketName, - #[snafu(display("Missing AccessKeyId"))] + #[error("Missing AccessKeyId")] MissingAccessKeyId, - #[snafu(display("Missing SecretAccessKey"))] + #[error("Missing SecretAccessKey")] MissingSecretAccessKey, - #[snafu(display("Unable parse source url. Url: {}, Error: {}", url, source))] + #[error("Unable parse source url. Url: {}, Error: {}", url, source)] UnableToParseUrl { source: url::ParseError, url: String, }, - #[snafu(display( + #[error( "Unknown url scheme cannot be parsed into storage location: {}", scheme - ))] + )] UnknownUrlScheme { scheme: String }, - #[snafu(display("URL did not match any known pattern for scheme: {}", url))] + #[error("URL did not match any known pattern for scheme: {}", url)] UrlNotRecognised { url: String }, - #[snafu(display("Configuration key: '{}' is not known.", key))] + #[error("Configuration key: '{}' is not known.", key)] UnknownConfigurationKey { key: String }, - #[snafu(display("Invalid Zone suffix for bucket '{bucket}'"))] + #[error("Invalid Zone suffix for bucket '{bucket}'")] ZoneSuffix { bucket: String }, - #[snafu(display("Invalid encryption type: {}. Valid values are \"AES256\", \"sse:kms\", \"sse:kms:dsse\" and \"sse-c\".", passed))] + #[error("Invalid encryption type: {}. 
Valid values are \"AES256\", \"sse:kms\", \"sse:kms:dsse\" and \"sse-c\".", passed)] InvalidEncryptionType { passed: String }, - #[snafu(display( + #[error( "Invalid encryption header values. Header: {}, source: {}", header, source - ))] + )] InvalidEncryptionHeader { header: &'static str, source: Box, @@ -170,6 +169,8 @@ pub struct AmazonS3Builder { encryption_bucket_key_enabled: Option>, /// base64-encoded 256-bit customer encryption key for SSE-C. encryption_customer_key_base64: Option, + /// When set to true, charge requester for bucket operations + request_payer: ConfigValue, } /// Configuration keys for [`AmazonS3Builder`] @@ -330,6 +331,13 @@ pub enum AmazonS3ConfigKey { /// - `s3_express` S3Express, + /// Enable Support for S3 Requester Pays + /// + /// Supported keys: + /// - `aws_request_payer` + /// - `request_payer` + RequestPayer, + /// Client options Client(ClientConfigKey), @@ -358,6 +366,7 @@ impl AsRef for AmazonS3ConfigKey { Self::CopyIfNotExists => "aws_copy_if_not_exists", Self::ConditionalPut => "aws_conditional_put", Self::DisableTagging => "aws_disable_tagging", + Self::RequestPayer => "aws_request_payer", Self::Client(opt) => opt.as_ref(), Self::Encryption(opt) => opt.as_ref(), } @@ -389,6 +398,7 @@ impl FromStr for AmazonS3ConfigKey { "aws_copy_if_not_exists" | "copy_if_not_exists" => Ok(Self::CopyIfNotExists), "aws_conditional_put" | "conditional_put" => Ok(Self::ConditionalPut), "aws_disable_tagging" | "disable_tagging" => Ok(Self::DisableTagging), + "aws_request_payer" | "request_payer" => Ok(Self::RequestPayer), // Backwards compatibility "aws_allow_http" => Ok(Self::Client(ClientConfigKey::AllowHttp)), "aws_server_side_encryption" => Ok(Self::Encryption( @@ -510,6 +520,9 @@ impl AmazonS3Builder { AmazonS3ConfigKey::ConditionalPut => { self.conditional_put = Some(ConfigValue::Deferred(value.into())) } + AmazonS3ConfigKey::RequestPayer => { + self.request_payer = ConfigValue::Deferred(value.into()) + } AmazonS3ConfigKey::Encryption(key) 
=> match key { S3EncryptionConfigKey::ServerSideEncryption => { self.encryption_type = Some(ConfigValue::Deferred(value.into())) @@ -567,6 +580,7 @@ impl AmazonS3Builder { self.conditional_put.as_ref().map(ToString::to_string) } AmazonS3ConfigKey::DisableTagging => Some(self.disable_tagging.to_string()), + AmazonS3ConfigKey::RequestPayer => Some(self.request_payer.to_string()), AmazonS3ConfigKey::Encryption(key) => match key { S3EncryptionConfigKey::ServerSideEncryption => { self.encryption_type.as_ref().map(ToString::to_string) @@ -588,8 +602,15 @@ impl AmazonS3Builder { /// This is a separate member function to allow fallible computation to /// be deferred until [`Self::build`] which in turn allows deriving [`Clone`] fn parse_url(&mut self, url: &str) -> Result<()> { - let parsed = Url::parse(url).context(UnableToParseUrlSnafu { url })?; - let host = parsed.host_str().context(UrlNotRecognisedSnafu { url })?; + let parsed = Url::parse(url).map_err(|source| { + let url = url.into(); + Error::UnableToParseUrl { url, source } + })?; + + let host = parsed + .host_str() + .ok_or_else(|| Error::UrlNotRecognised { url: url.into() })?; + match parsed.scheme() { "s3" | "s3a" => self.bucket_name = Some(host.to_string()), "https" => match host.splitn(4, '.').collect_tuple() { @@ -615,9 +636,12 @@ impl AmazonS3Builder { self.bucket_name = Some(bucket.into()); } } - _ => return Err(UrlNotRecognisedSnafu { url }.build().into()), + _ => return Err(Error::UrlNotRecognised { url: url.into() }.into()), }, - scheme => return Err(UnknownUrlSchemeSnafu { scheme }.build().into()), + scheme => { + let scheme = scheme.into(); + return Err(Error::UnknownUrlScheme { scheme }.into()); + } }; Ok(()) } @@ -845,6 +869,14 @@ impl AmazonS3Builder { self } + /// Set whether to charge requester for bucket operations. 
+ /// + /// + pub fn with_request_payer(mut self, enabled: bool) -> Self { + self.request_payer = ConfigValue::Parsed(enabled); + self + } + /// Create a [`AmazonS3`] instance from the provided values, /// consuming `self`. pub fn build(mut self) -> Result { @@ -852,7 +884,7 @@ impl AmazonS3Builder { self.parse_url(&url)?; } - let bucket = self.bucket_name.context(MissingBucketNameSnafu)?; + let bucket = self.bucket_name.ok_or(Error::MissingBucketName)?; let region = self.region.unwrap_or_else(|| "us-east-1".to_string()); let checksum = self.checksum_algorithm.map(|x| x.get()).transpose()?; let copy_if_not_exists = self.copy_if_not_exists.map(|x| x.get()).transpose()?; @@ -934,7 +966,10 @@ impl AmazonS3Builder { let (session_provider, zonal_endpoint) = match self.s3_express.get()? { true => { - let zone = parse_bucket_az(&bucket).context(ZoneSuffixSnafu { bucket: &bucket })?; + let zone = parse_bucket_az(&bucket).ok_or_else(|| { + let bucket = bucket.clone(); + Error::ZoneSuffix { bucket } + })?; // https://docs.aws.amazon.com/AmazonS3/latest/userguide/s3-express-Regions-and-Zones.html let endpoint = format!("https://{bucket}.s3express-{zone}.{region}.amazonaws.com"); @@ -996,6 +1031,7 @@ impl AmazonS3Builder { copy_if_not_exists, conditional_put: put_precondition, encryption_headers, + request_payer: self.request_payer.get()?, }; let client = Arc::new(S3Client::new(config)?); diff --git a/object_store/src/aws/client.rs b/object_store/src/aws/client.rs index 5e80efd3388c..246f2779dd07 100644 --- a/object_store/src/aws/client.rs +++ b/object_store/src/aws/client.rs @@ -29,7 +29,7 @@ use crate::client::list::ListClient; use crate::client::retry::RetryExt; use crate::client::s3::{ CompleteMultipartUpload, CompleteMultipartUploadResult, CopyPartResult, - InitiateMultipartUploadResult, ListResponse, + InitiateMultipartUploadResult, ListResponse, PartMetadata, }; use crate::client::GetOptionsExt; use crate::multipart::PartId; @@ -56,64 +56,64 @@ use reqwest::{Client as 
ReqwestClient, Method, RequestBuilder, Response}; use ring::digest; use ring::digest::Context; use serde::{Deserialize, Serialize}; -use snafu::{ResultExt, Snafu}; use std::sync::Arc; const VERSION_HEADER: &str = "x-amz-version-id"; const SHA256_CHECKSUM: &str = "x-amz-checksum-sha256"; const USER_DEFINED_METADATA_HEADER_PREFIX: &str = "x-amz-meta-"; +const ALGORITHM: &str = "x-amz-checksum-algorithm"; /// A specialized `Error` for object store-related errors -#[derive(Debug, Snafu)] +#[derive(Debug, thiserror::Error)] pub(crate) enum Error { - #[snafu(display("Error performing DeleteObjects request: {}", source))] + #[error("Error performing DeleteObjects request: {}", source)] DeleteObjectsRequest { source: crate::client::retry::Error }, - #[snafu(display( + #[error( "DeleteObjects request failed for key {}: {} (code: {})", path, message, code - ))] + )] DeleteFailed { path: String, code: String, message: String, }, - #[snafu(display("Error getting DeleteObjects response body: {}", source))] + #[error("Error getting DeleteObjects response body: {}", source)] DeleteObjectsResponse { source: reqwest::Error }, - #[snafu(display("Got invalid DeleteObjects response: {}", source))] + #[error("Got invalid DeleteObjects response: {}", source)] InvalidDeleteObjectsResponse { source: Box, }, - #[snafu(display("Error performing list request: {}", source))] + #[error("Error performing list request: {}", source)] ListRequest { source: crate::client::retry::Error }, - #[snafu(display("Error getting list response body: {}", source))] + #[error("Error getting list response body: {}", source)] ListResponseBody { source: reqwest::Error }, - #[snafu(display("Error getting create multipart response body: {}", source))] + #[error("Error getting create multipart response body: {}", source)] CreateMultipartResponseBody { source: reqwest::Error }, - #[snafu(display("Error performing complete multipart request: {}: {}", path, source))] + #[error("Error performing complete multipart 
request: {}: {}", path, source)] CompleteMultipartRequest { source: crate::client::retry::Error, path: String, }, - #[snafu(display("Error getting complete multipart response body: {}", source))] + #[error("Error getting complete multipart response body: {}", source)] CompleteMultipartResponseBody { source: reqwest::Error }, - #[snafu(display("Got invalid list response: {}", source))] + #[error("Got invalid list response: {}", source)] InvalidListResponse { source: quick_xml::de::DeError }, - #[snafu(display("Got invalid multipart response: {}", source))] + #[error("Got invalid multipart response: {}", source)] InvalidMultipartResponse { source: quick_xml::de::DeError }, - #[snafu(display("Unable to extract metadata from headers: {}", source))] + #[error("Unable to extract metadata from headers: {}", source)] Metadata { source: crate::client::header::Error, }, @@ -202,6 +202,7 @@ pub(crate) struct S3Config { pub checksum: Option, pub copy_if_not_exists: Option, pub conditional_put: Option, + pub request_payer: bool, pub(super) encryption_headers: S3EncryptionHeaders, } @@ -245,11 +246,12 @@ struct SessionCredential<'a> { config: &'a S3Config, } -impl<'a> SessionCredential<'a> { +impl SessionCredential<'_> { fn authorizer(&self) -> Option> { let mut authorizer = AwsAuthorizer::new(self.credential.as_deref()?, "s3", &self.config.region) - .with_sign_payload(self.config.sign_payload); + .with_sign_payload(self.config.sign_payload) + .with_request_payer(self.config.request_payer); if self.session_token { let token = HeaderName::from_static("x-amz-s3session-token"); @@ -260,10 +262,15 @@ impl<'a> SessionCredential<'a> { } } -#[derive(Debug, Snafu)] +#[derive(Debug, thiserror::Error)] pub enum RequestError { - #[snafu(context(false))] - Generic { source: crate::Error }, + #[error(transparent)] + Generic { + #[from] + source: crate::Error, + }, + + #[error("Retry")] Retry { source: crate::client::retry::Error, path: String, @@ -288,10 +295,11 @@ pub(crate) struct 
Request<'a> { payload: Option, use_session_creds: bool, idempotent: bool, + retry_on_conflict: bool, retry_error_body: bool, } -impl<'a> Request<'a> { +impl Request<'_> { pub(crate) fn query(self, query: &T) -> Self { let builder = self.builder.query(query); Self { builder, ..self } @@ -315,6 +323,13 @@ impl<'a> Request<'a> { Self { idempotent, ..self } } + pub(crate) fn retry_on_conflict(self, retry_on_conflict: bool) -> Self { + Self { + retry_on_conflict, + ..self + } + } + pub(crate) fn retry_error_body(self, retry_error_body: bool) -> Self { Self { retry_error_body, @@ -380,10 +395,9 @@ impl<'a> Request<'a> { let payload_sha256 = sha256.finish(); if let Some(Checksum::SHA256) = self.config.checksum { - self.builder = self.builder.header( - "x-amz-checksum-sha256", - BASE64_STANDARD.encode(payload_sha256), - ); + self.builder = self + .builder + .header(SHA256_CHECKSUM, BASE64_STANDARD.encode(payload_sha256)); } self.payload_sha256 = Some(payload_sha256); } @@ -410,17 +424,22 @@ impl<'a> Request<'a> { self.builder .with_aws_sigv4(credential.authorizer(), sha) .retryable(&self.config.retry_config) + .retry_on_conflict(self.retry_on_conflict) .idempotent(self.idempotent) .retry_error_body(self.retry_error_body) .payload(self.payload) .send() .await - .context(RetrySnafu { path }) + .map_err(|source| { + let path = path.into(); + RequestError::Retry { source, path } + }) } pub(crate) async fn do_put(self) -> Result { let response = self.send().await?; - Ok(get_put_result(response.headers(), VERSION_HEADER).context(MetadataSnafu)?) + Ok(get_put_result(response.headers(), VERSION_HEADER) + .map_err(|source| Error::Metadata { source })?) 
} } @@ -446,6 +465,7 @@ impl S3Client { config: &self.config, use_session_creds: true, idempotent: false, + retry_on_conflict: false, retry_error_body: false, } } @@ -523,10 +543,10 @@ impl S3Client { .with_aws_sigv4(credential.authorizer(), Some(digest.as_ref())) .send_retry(&self.config.retry_config) .await - .context(DeleteObjectsRequestSnafu {})? + .map_err(|source| Error::DeleteObjectsRequest { source })? .bytes() .await - .context(DeleteObjectsResponseSnafu {})?; + .map_err(|source| Error::DeleteObjectsResponse { source })?; let response: BatchDeleteResponse = quick_xml::de::from_reader(response.reader()).map_err(|err| { @@ -605,8 +625,15 @@ impl S3Client { location: &Path, opts: PutMultipartOpts, ) -> Result { - let response = self - .request(Method::POST, location) + let mut request = self.request(Method::POST, location); + if let Some(algorithm) = self.config.checksum { + match algorithm { + Checksum::SHA256 => { + request = request.header(ALGORITHM, "SHA256"); + } + } + } + let response = request .query(&[("uploads", "")]) .with_encryption_headers() .with_attributes(opts.attributes) @@ -616,10 +643,10 @@ impl S3Client { .await? 
.bytes() .await - .context(CreateMultipartResponseBodySnafu)?; + .map_err(|source| Error::CreateMultipartResponseBody { source })?; - let response: InitiateMultipartUploadResult = - quick_xml::de::from_reader(response.reader()).context(InvalidMultipartResponseSnafu)?; + let response: InitiateMultipartUploadResult = quick_xml::de::from_reader(response.reader()) + .map_err(|source| Error::InvalidMultipartResponse { source })?; Ok(response.upload_id) } @@ -657,19 +684,35 @@ impl S3Client { request = request.with_encryption_headers(); } let response = request.send().await?; - - let content_id = match is_copy { - false => get_etag(response.headers()).context(MetadataSnafu)?, + let checksum_sha256 = response + .headers() + .get(SHA256_CHECKSUM) + .and_then(|v| v.to_str().ok()) + .map(|v| v.to_string()); + + let e_tag = match is_copy { + false => get_etag(response.headers()).map_err(|source| Error::Metadata { source })?, true => { let response = response .bytes() .await - .context(CreateMultipartResponseBodySnafu)?; + .map_err(|source| Error::CreateMultipartResponseBody { source })?; let response: CopyPartResult = quick_xml::de::from_reader(response.reader()) - .context(InvalidMultipartResponseSnafu)?; + .map_err(|source| Error::InvalidMultipartResponse { source })?; response.e_tag } }; + + let content_id = if self.config.checksum == Some(Checksum::SHA256) { + let meta = PartMetadata { + e_tag, + checksum_sha256, + }; + quick_xml::se::to_string(&meta).unwrap() + } else { + e_tag + }; + Ok(PartId { content_id }) } @@ -729,19 +772,21 @@ impl S3Client { .retry_error_body(true) .send() .await - .context(CompleteMultipartRequestSnafu { - path: location.as_ref(), + .map_err(|source| Error::CompleteMultipartRequest { + source, + path: location.as_ref().to_string(), })?; - let version = get_version(response.headers(), VERSION_HEADER).context(MetadataSnafu)?; + let version = get_version(response.headers(), VERSION_HEADER) + .map_err(|source| Error::Metadata { source })?; let data 
= response .bytes() .await - .context(CompleteMultipartResponseBodySnafu)?; + .map_err(|source| Error::CompleteMultipartResponseBody { source })?; - let response: CompleteMultipartUploadResult = - quick_xml::de::from_reader(data.reader()).context(InvalidMultipartResponseSnafu)?; + let response: CompleteMultipartUploadResult = quick_xml::de::from_reader(data.reader()) + .map_err(|source| Error::InvalidMultipartResponse { source })?; Ok(PutResult { e_tag: Some(response.e_tag), @@ -849,13 +894,14 @@ impl ListClient for Arc { .with_aws_sigv4(credential.authorizer(), None) .send_retry(&self.config.retry_config) .await - .context(ListRequestSnafu)? + .map_err(|source| Error::ListRequest { source })? .bytes() .await - .context(ListResponseBodySnafu)?; + .map_err(|source| Error::ListResponseBody { source })?; + + let mut response: ListResponse = quick_xml::de::from_reader(response.reader()) + .map_err(|source| Error::InvalidListResponse { source })?; - let mut response: ListResponse = - quick_xml::de::from_reader(response.reader()).context(InvalidListResponseSnafu)?; let token = response.next_continuation_token.take(); Ok((response.try_into()?, token)) diff --git a/object_store/src/aws/credential.rs b/object_store/src/aws/credential.rs index 33972c6fa14a..9c74e1c6526a 100644 --- a/object_store/src/aws/credential.rs +++ b/object_store/src/aws/credential.rs @@ -29,23 +29,22 @@ use percent_encoding::utf8_percent_encode; use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION}; use reqwest::{Client, Method, Request, RequestBuilder, StatusCode}; use serde::Deserialize; -use snafu::{ResultExt, Snafu}; use std::collections::BTreeMap; use std::sync::Arc; use std::time::{Duration, Instant}; use tracing::warn; use url::Url; -#[derive(Debug, Snafu)] +#[derive(Debug, thiserror::Error)] #[allow(clippy::enum_variant_names)] enum Error { - #[snafu(display("Error performing CreateSession request: {source}"))] + #[error("Error performing CreateSession request: {source}")] 
CreateSessionRequest { source: crate::client::retry::Error }, - #[snafu(display("Error getting CreateSession response: {source}"))] + #[error("Error getting CreateSession response: {source}")] CreateSessionResponse { source: reqwest::Error }, - #[snafu(display("Invalid CreateSessionOutput response: {source}"))] + #[error("Invalid CreateSessionOutput response: {source}")] CreateSessionOutput { source: quick_xml::DeError }, } @@ -101,11 +100,14 @@ pub struct AwsAuthorizer<'a> { region: &'a str, token_header: Option, sign_payload: bool, + request_payer: bool, } static DATE_HEADER: HeaderName = HeaderName::from_static("x-amz-date"); static HASH_HEADER: HeaderName = HeaderName::from_static("x-amz-content-sha256"); static TOKEN_HEADER: HeaderName = HeaderName::from_static("x-amz-security-token"); +static REQUEST_PAYER_HEADER: HeaderName = HeaderName::from_static("x-amz-request-payer"); +static REQUEST_PAYER_HEADER_VALUE: HeaderValue = HeaderValue::from_static("requester"); const ALGORITHM: &str = "AWS4-HMAC-SHA256"; impl<'a> AwsAuthorizer<'a> { @@ -118,6 +120,7 @@ impl<'a> AwsAuthorizer<'a> { date: None, sign_payload: true, token_header: None, + request_payer: false, } } @@ -134,6 +137,14 @@ impl<'a> AwsAuthorizer<'a> { self } + /// Set whether to include requester pays headers + /// + /// + pub fn with_request_payer(mut self, request_payer: bool) -> Self { + self.request_payer = request_payer; + self + } + /// Authorize `request` with an optional pre-calculated SHA256 digest by attaching /// the relevant [AWS SigV4] headers /// @@ -180,6 +191,15 @@ impl<'a> AwsAuthorizer<'a> { let header_digest = HeaderValue::from_str(&digest).unwrap(); request.headers_mut().insert(&HASH_HEADER, header_digest); + if self.request_payer { + // For DELETE, GET, HEAD, POST, and PUT requests, include x-amz-request-payer : + // requester in the header + // https://docs.aws.amazon.com/AmazonS3/latest/userguide/ObjectsinRequesterPaysBuckets.html + request + .headers_mut() + 
.insert(&REQUEST_PAYER_HEADER, REQUEST_PAYER_HEADER_VALUE.clone()); + } + let (signed_headers, canonical_headers) = canonicalize_headers(request.headers()); let scope = self.scope(date); @@ -226,6 +246,13 @@ impl<'a> AwsAuthorizer<'a> { .append_pair("X-Amz-Expires", &expires_in.as_secs().to_string()) .append_pair("X-Amz-SignedHeaders", "host"); + if self.request_payer { + // For signed URLs, include x-amz-request-payer=requester in the request + // https://docs.aws.amazon.com/AmazonS3/latest/userguide/ObjectsinRequesterPaysBuckets.html + url.query_pairs_mut() + .append_pair("x-amz-request-payer", "requester"); + } + // For S3, you must include the X-Amz-Security-Token query parameter in the URL if // using credentials sourced from the STS service. if let Some(ref token) = self.credential.token { @@ -698,13 +725,13 @@ impl TokenProvider for SessionProvider { .with_aws_sigv4(Some(authorizer), None) .send_retry(retry) .await - .context(CreateSessionRequestSnafu)? + .map_err(|source| Error::CreateSessionRequest { source })? 
.bytes() .await - .context(CreateSessionResponseSnafu)?; + .map_err(|source| Error::CreateSessionResponse { source })?; - let resp: CreateSessionOutput = - quick_xml::de::from_reader(bytes.reader()).context(CreateSessionOutputSnafu)?; + let resp: CreateSessionOutput = quick_xml::de::from_reader(bytes.reader()) + .map_err(|source| Error::CreateSessionOutput { source })?; let creds = resp.credentials; Ok(TemporaryToken { @@ -763,12 +790,53 @@ mod tests { region: "us-east-1", sign_payload: true, token_header: None, + request_payer: false, }; signer.authorize(&mut request, None); assert_eq!(request.headers().get(&AUTHORIZATION).unwrap(), "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=a3c787a7ed37f7fdfbfd2d7056a3d7c9d85e6d52a2bfbec73793c0be6e7862d4") } + #[test] + fn test_sign_with_signed_payload_request_payer() { + let client = Client::new(); + + // Test credentials from https://docs.aws.amazon.com/AmazonS3/latest/userguide/RESTAuthentication.html + let credential = AwsCredential { + key_id: "AKIAIOSFODNN7EXAMPLE".to_string(), + secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(), + token: None, + }; + + // method = 'GET' + // service = 'ec2' + // host = 'ec2.amazonaws.com' + // region = 'us-east-1' + // endpoint = 'https://ec2.amazonaws.com' + // request_parameters = '' + let date = DateTime::parse_from_rfc3339("2022-08-06T18:01:34Z") + .unwrap() + .with_timezone(&Utc); + + let mut request = client + .request(Method::GET, "https://ec2.amazon.com/") + .build() + .unwrap(); + + let signer = AwsAuthorizer { + date: Some(date), + credential: &credential, + service: "ec2", + region: "us-east-1", + sign_payload: true, + token_header: None, + request_payer: true, + }; + + signer.authorize(&mut request, None); + assert_eq!(request.headers().get(&AUTHORIZATION).unwrap(), "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, 
SignedHeaders=host;x-amz-content-sha256;x-amz-date;x-amz-request-payer, Signature=7030625a9e9b57ed2a40e63d749f4a4b7714b6e15004cab026152f870dd8565d") + } + #[test] fn test_sign_with_unsigned_payload() { let client = Client::new(); @@ -802,6 +870,7 @@ mod tests { region: "us-east-1", token_header: None, sign_payload: false, + request_payer: false, }; authorizer.authorize(&mut request, None); @@ -828,6 +897,7 @@ mod tests { region: "us-east-1", token_header: None, sign_payload: false, + request_payer: false, }; let mut url = Url::parse("https://examplebucket.s3.amazonaws.com/test.txt").unwrap(); @@ -848,6 +918,48 @@ mod tests { ); } + #[test] + fn signed_get_url_request_payer() { + // Values from https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html + let credential = AwsCredential { + key_id: "AKIAIOSFODNN7EXAMPLE".to_string(), + secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(), + token: None, + }; + + let date = DateTime::parse_from_rfc3339("2013-05-24T00:00:00Z") + .unwrap() + .with_timezone(&Utc); + + let authorizer = AwsAuthorizer { + date: Some(date), + credential: &credential, + service: "s3", + region: "us-east-1", + token_header: None, + sign_payload: false, + request_payer: true, + }; + + let mut url = Url::parse("https://examplebucket.s3.amazonaws.com/test.txt").unwrap(); + authorizer.sign(Method::GET, &mut url, Duration::from_secs(86400)); + + assert_eq!( + url, + Url::parse( + "https://examplebucket.s3.amazonaws.com/test.txt?\ + X-Amz-Algorithm=AWS4-HMAC-SHA256&\ + X-Amz-Credential=AKIAIOSFODNN7EXAMPLE%2F20130524%2Fus-east-1%2Fs3%2Faws4_request&\ + X-Amz-Date=20130524T000000Z&\ + X-Amz-Expires=86400&\ + X-Amz-SignedHeaders=host&\ + x-amz-request-payer=requester&\ + X-Amz-Signature=9ad7c781cc30121f199b47d35ed3528473e4375b63c5d91cd87c927803e4e00a" + ) + .unwrap() + ); + } + #[test] fn test_sign_port() { let client = Client::new(); @@ -880,6 +992,7 @@ mod tests { region: "us-east-1", token_header: None, sign_payload: 
true, + request_payer: false, }; authorizer.authorize(&mut request, None); diff --git a/object_store/src/aws/dynamo.rs b/object_store/src/aws/dynamo.rs index ece3b8a357c6..6283e76c1f87 100644 --- a/object_store/src/aws/dynamo.rs +++ b/object_store/src/aws/dynamo.rs @@ -471,7 +471,7 @@ enum ReturnValues { /// This provides cheap, ordered serialization of maps struct Map<'a, K, V>(&'a [(K, V)]); -impl<'a, K: Serialize, V: Serialize> Serialize for Map<'a, K, V> { +impl Serialize for Map<'_, K, V> { fn serialize(&self, serializer: S) -> Result where S: Serializer, diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs index 9330e5389138..82ef909de984 100644 --- a/object_store/src/aws/mod.rs +++ b/object_store/src/aws/mod.rs @@ -136,7 +136,8 @@ impl Signer for AmazonS3 { /// ``` async fn signed_url(&self, method: Method, path: &Path, expires_in: Duration) -> Result { let credential = self.credentials().get_credential().await?; - let authorizer = AwsAuthorizer::new(&credential, "s3", &self.client.config.region); + let authorizer = AwsAuthorizer::new(&credential, "s3", &self.client.config.region) + .with_request_payer(self.client.config.request_payer); let path_url = self.path_url(path); let mut url = Url::parse(&path_url).map_err(|e| crate::Error::Generic { @@ -169,10 +170,7 @@ impl ObjectStore for AmazonS3 { match (opts.mode, &self.client.config.conditional_put) { (PutMode::Overwrite, _) => request.idempotent(true).do_put().await, (PutMode::Create | PutMode::Update(_), None) => Err(Error::NotImplemented), - ( - PutMode::Create, - Some(S3ConditionalPut::ETagMatch | S3ConditionalPut::ETagPutIfNotExists), - ) => { + (PutMode::Create, Some(S3ConditionalPut::ETagMatch)) => { match request.header(&IF_NONE_MATCH, "*").do_put().await { // Technically If-None-Match should return NotModified but some stores, // such as R2, instead return PreconditionFailed @@ -196,9 +194,26 @@ impl ObjectStore for AmazonS3 { source: "ETag required for conditional 
put".to_string().into(), })?; match put { - S3ConditionalPut::ETagPutIfNotExists => Err(Error::NotImplemented), S3ConditionalPut::ETagMatch => { - request.header(&IF_MATCH, etag.as_str()).do_put().await + match request + .header(&IF_MATCH, etag.as_str()) + // Real S3 will occasionally report 409 Conflict + // if there are concurrent `If-Match` requests + // in flight, so we need to be prepared to retry + // 409 responses. + .retry_on_conflict(true) + .do_put() + .await + { + // Real S3 reports NotFound rather than PreconditionFailed when the + // object doesn't exist. Convert to PreconditionFailed for + // consistency with R2. This also matches what the HTTP spec + // says the behavior should be. + Err(Error::NotFound { path, source }) => { + Err(Error::Precondition { path, source }) + } + r => r, + } } S3ConditionalPut::Dynamo(d) => { d.conditional_op(&self.client, location, Some(&etag), move || { @@ -478,6 +493,66 @@ mod tests { const NON_EXISTENT_NAME: &str = "nonexistentname"; + #[tokio::test] + async fn write_multipart_file_with_signature() { + maybe_skip_integration!(); + + let store = AmazonS3Builder::from_env() + .with_checksum_algorithm(Checksum::SHA256) + .build() + .unwrap(); + + let str = "test.bin"; + let path = Path::parse(str).unwrap(); + let opts = PutMultipartOpts::default(); + let mut upload = store.put_multipart_opts(&path, opts).await.unwrap(); + + upload + .put_part(PutPayload::from(vec![0u8; 10_000_000])) + .await + .unwrap(); + upload + .put_part(PutPayload::from(vec![0u8; 5_000_000])) + .await + .unwrap(); + + let res = upload.complete().await.unwrap(); + assert!(res.e_tag.is_some(), "Should have valid etag"); + + store.delete(&path).await.unwrap(); + } + + #[tokio::test] + async fn write_multipart_file_with_signature_object_lock() { + maybe_skip_integration!(); + + let bucket = "test-object-lock"; + let store = AmazonS3Builder::from_env() + .with_bucket_name(bucket) + .with_checksum_algorithm(Checksum::SHA256) + .build() + .unwrap(); + + 
let str = "test.bin"; + let path = Path::parse(str).unwrap(); + let opts = PutMultipartOpts::default(); + let mut upload = store.put_multipart_opts(&path, opts).await.unwrap(); + + upload + .put_part(PutPayload::from(vec![0u8; 10_000_000])) + .await + .unwrap(); + upload + .put_part(PutPayload::from(vec![0u8; 5_000_000])) + .await + .unwrap(); + + let res = upload.complete().await.unwrap(); + assert!(res.e_tag.is_some(), "Should have valid etag"); + + store.delete(&path).await.unwrap(); + } + #[tokio::test] async fn s3_test() { maybe_skip_integration!(); @@ -486,6 +561,7 @@ mod tests { let integration = config.build().unwrap(); let config = &integration.client.config; let test_not_exists = config.copy_if_not_exists.is_some(); + let test_conditional_put = config.conditional_put.is_some(); put_get_delete_list(&integration).await; get_opts(&integration).await; @@ -494,6 +570,7 @@ mod tests { rename_and_copy(&integration).await; stream_get(&integration).await; multipart(&integration, &integration).await; + multipart_race_condition(&integration, true).await; signing(&integration).await; s3_encryption(&integration).await; put_get_attributes(&integration).await; @@ -516,9 +593,8 @@ mod tests { if test_not_exists { copy_if_not_exists(&integration).await; } - if let Some(conditional_put) = &config.conditional_put { - let supports_update = !matches!(conditional_put, S3ConditionalPut::ETagPutIfNotExists); - put_opts(&integration, supports_update).await; + if test_conditional_put { + put_opts(&integration, true).await; } // run integration test with unsigned payload enabled diff --git a/object_store/src/aws/precondition.rs b/object_store/src/aws/precondition.rs index e5058052790d..b261ad0dbfb1 100644 --- a/object_store/src/aws/precondition.rs +++ b/object_store/src/aws/precondition.rs @@ -138,17 +138,6 @@ pub enum S3ConditionalPut { /// [HTTP precondition]: https://datatracker.ietf.org/doc/html/rfc9110#name-preconditions ETagMatch, - /// Like `ETagMatch`, but with support for 
`PutMode::Create` and not - /// `PutMode::Option`. - /// - /// This is the limited form of conditional put supported by Amazon S3 - /// as of August 2024 ([announcement]). - /// - /// Encoded as `etag-put-if-not-exists` ignoring whitespace. - /// - /// [announcement]: https://aws.amazon.com/about-aws/whats-new/2024/08/amazon-s3-conditional-writes/ - ETagPutIfNotExists, - /// The name of a DynamoDB table to use for coordination /// /// Encoded as either `dynamo:` or `dynamo::` @@ -164,7 +153,6 @@ impl std::fmt::Display for S3ConditionalPut { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::ETagMatch => write!(f, "etag"), - Self::ETagPutIfNotExists => write!(f, "etag-put-if-not-exists"), Self::Dynamo(lock) => write!(f, "dynamo: {}", lock.table_name()), } } @@ -174,7 +162,6 @@ impl S3ConditionalPut { fn from_str(s: &str) -> Option { match s.trim() { "etag" => Some(Self::ETagMatch), - "etag-put-if-not-exists" => Some(Self::ETagPutIfNotExists), trimmed => match trimmed.split_once(':')? 
{ ("dynamo", s) => Some(Self::Dynamo(DynamoCommit::from_str(s)?)), _ => None, diff --git a/object_store/src/aws/resolve.rs b/object_store/src/aws/resolve.rs index 25bc74f32f29..db899ea989e3 100644 --- a/object_store/src/aws/resolve.rs +++ b/object_store/src/aws/resolve.rs @@ -17,21 +17,20 @@ use crate::aws::STORE; use crate::{ClientOptions, Result}; -use snafu::{ensure, OptionExt, ResultExt, Snafu}; /// A specialized `Error` for object store-related errors -#[derive(Debug, Snafu)] +#[derive(Debug, thiserror::Error)] enum Error { - #[snafu(display("Bucket '{}' not found", bucket))] + #[error("Bucket '{}' not found", bucket)] BucketNotFound { bucket: String }, - #[snafu(display("Failed to resolve region for bucket '{}'", bucket))] + #[error("Failed to resolve region for bucket '{}'", bucket)] ResolveRegion { bucket: String, source: reqwest::Error, }, - #[snafu(display("Failed to parse the region for bucket '{}'", bucket))] + #[error("Failed to parse the region for bucket '{}'", bucket)] RegionParse { bucket: String }, } @@ -54,22 +53,23 @@ pub async fn resolve_bucket_region(bucket: &str, client_options: &ClientOptions) let client = client_options.client()?; - let response = client - .head(&endpoint) - .send() - .await - .context(ResolveRegionSnafu { bucket })?; + let response = client.head(&endpoint).send().await.map_err(|source| { + let bucket = bucket.into(); + Error::ResolveRegion { bucket, source } + })?; - ensure!( - response.status() != StatusCode::NOT_FOUND, - BucketNotFoundSnafu { bucket } - ); + if response.status() == StatusCode::NOT_FOUND { + let bucket = bucket.into(); + return Err(Error::BucketNotFound { bucket }.into()); + } let region = response .headers() .get("x-amz-bucket-region") .and_then(|x| x.to_str().ok()) - .context(RegionParseSnafu { bucket })?; + .ok_or_else(|| Error::RegionParse { + bucket: bucket.into(), + })?; Ok(region.to_string()) } diff --git a/object_store/src/azure/builder.rs b/object_store/src/azure/builder.rs index 
1c4589ba1ec6..f0572ebe6358 100644 --- a/object_store/src/azure/builder.rs +++ b/object_store/src/azure/builder.rs @@ -26,7 +26,6 @@ use crate::config::ConfigValue; use crate::{ClientConfigKey, ClientOptions, Result, RetryConfig, StaticCredentialProvider}; use percent_encoding::percent_decode_str; use serde::{Deserialize, Serialize}; -use snafu::{OptionExt, ResultExt, Snafu}; use std::str::FromStr; use std::sync::Arc; use url::Url; @@ -45,48 +44,48 @@ const EMULATOR_ACCOUNT_KEY: &str = const MSI_ENDPOINT_ENV_KEY: &str = "IDENTITY_ENDPOINT"; /// A specialized `Error` for Azure builder-related errors -#[derive(Debug, Snafu)] +#[derive(Debug, thiserror::Error)] enum Error { - #[snafu(display("Unable parse source url. Url: {}, Error: {}", url, source))] + #[error("Unable parse source url. Url: {}, Error: {}", url, source)] UnableToParseUrl { source: url::ParseError, url: String, }, - #[snafu(display( + #[error( "Unable parse emulator url {}={}, Error: {}", env_name, env_value, source - ))] + )] UnableToParseEmulatorUrl { env_name: String, env_value: String, source: url::ParseError, }, - #[snafu(display("Account must be specified"))] + #[error("Account must be specified")] MissingAccount {}, - #[snafu(display("Container name must be specified"))] + #[error("Container name must be specified")] MissingContainerName {}, - #[snafu(display( + #[error( "Unknown url scheme cannot be parsed into storage location: {}", scheme - ))] + )] UnknownUrlScheme { scheme: String }, - #[snafu(display("URL did not match any known pattern for scheme: {}", url))] + #[error("URL did not match any known pattern for scheme: {}", url)] UrlNotRecognised { url: String }, - #[snafu(display("Failed parsing an SAS key"))] + #[error("Failed parsing an SAS key")] DecodeSasKey { source: std::str::Utf8Error }, - #[snafu(display("Missing component in SAS query pair"))] + #[error("Missing component in SAS query pair")] MissingSasComponent {}, - #[snafu(display("Configuration key: '{}' is not known.", key))] 
+ #[error("Configuration key: '{}' is not known.", key)] UnknownConfigurationKey { key: String }, } @@ -240,6 +239,14 @@ pub enum AzureConfigKey { /// - `authority_id` AuthorityId, + /// Authority host used in oauth flows + /// + /// Supported keys: + /// - `azure_storage_authority_host` + /// - `azure_authority_host` + /// - `authority_host` + AuthorityHost, + /// Shared access signature. /// /// The signature is expected to be percent-encoded, much like they are provided @@ -383,6 +390,7 @@ impl AsRef for AzureConfigKey { Self::ClientId => "azure_storage_client_id", Self::ClientSecret => "azure_storage_client_secret", Self::AuthorityId => "azure_storage_tenant_id", + Self::AuthorityHost => "azure_storage_authority_host", Self::SasKey => "azure_storage_sas_key", Self::Token => "azure_storage_token", Self::UseEmulator => "azure_storage_use_emulator", @@ -427,6 +435,9 @@ impl FromStr for AzureConfigKey { | "azure_authority_id" | "tenant_id" | "authority_id" => Ok(Self::AuthorityId), + "azure_storage_authority_host" | "azure_authority_host" | "authority_host" => { + Ok(Self::AuthorityHost) + } "azure_storage_sas_key" | "azure_storage_sas_token" | "sas_key" | "sas_token" => { Ok(Self::SasKey) } @@ -556,6 +567,7 @@ impl MicrosoftAzureBuilder { AzureConfigKey::ClientId => self.client_id = Some(value.into()), AzureConfigKey::ClientSecret => self.client_secret = Some(value.into()), AzureConfigKey::AuthorityId => self.tenant_id = Some(value.into()), + AzureConfigKey::AuthorityHost => self.authority_host = Some(value.into()), AzureConfigKey::SasKey => self.sas_key = Some(value.into()), AzureConfigKey::Token => self.bearer_token = Some(value.into()), AzureConfigKey::MsiEndpoint => self.msi_endpoint = Some(value.into()), @@ -602,6 +614,7 @@ impl MicrosoftAzureBuilder { AzureConfigKey::ClientId => self.client_id.clone(), AzureConfigKey::ClientSecret => self.client_secret.clone(), AzureConfigKey::AuthorityId => self.tenant_id.clone(), + AzureConfigKey::AuthorityHost => 
self.authority_host.clone(), AzureConfigKey::SasKey => self.sas_key.clone(), AzureConfigKey::Token => self.bearer_token.clone(), AzureConfigKey::UseEmulator => Some(self.use_emulator.to_string()), @@ -628,11 +641,17 @@ impl MicrosoftAzureBuilder { /// This is a separate member function to allow fallible computation to /// be deferred until [`Self::build`] which in turn allows deriving [`Clone`] fn parse_url(&mut self, url: &str) -> Result<()> { - let parsed = Url::parse(url).context(UnableToParseUrlSnafu { url })?; - let host = parsed.host_str().context(UrlNotRecognisedSnafu { url })?; + let parsed = Url::parse(url).map_err(|source| { + let url = url.into(); + Error::UnableToParseUrl { url, source } + })?; + + let host = parsed + .host_str() + .ok_or_else(|| Error::UrlNotRecognised { url: url.into() })?; let validate = |s: &str| match s.contains('.') { - true => Err(UrlNotRecognisedSnafu { url }.build()), + true => Err(Error::UrlNotRecognised { url: url.into() }), false => Ok(s.to_string()), }; @@ -651,7 +670,7 @@ impl MicrosoftAzureBuilder { self.account_name = Some(validate(a)?); self.use_fabric_endpoint = true.into(); } else { - return Err(UrlNotRecognisedSnafu { url }.build().into()); + return Err(Error::UrlNotRecognised { url: url.into() }.into()); } } "https" => match host.split_once('.') { @@ -675,9 +694,12 @@ impl MicrosoftAzureBuilder { } self.use_fabric_endpoint = true.into(); } - _ => return Err(UrlNotRecognisedSnafu { url }.build().into()), + _ => return Err(Error::UrlNotRecognised { url: url.into() }.into()), }, - scheme => return Err(UnknownUrlSchemeSnafu { scheme }.build().into()), + scheme => { + let scheme = scheme.into(); + return Err(Error::UnknownUrlScheme { scheme }.into()); + } } Ok(()) } @@ -910,8 +932,10 @@ impl MicrosoftAzureBuilder { }, }; - let url = - Url::parse(&account_url).context(UnableToParseUrlSnafu { url: account_url })?; + let url = Url::parse(&account_url).map_err(|source| { + let url = account_url.clone(); + 
Error::UnableToParseUrl { url, source } + })?; let credential = if let Some(credential) = self.credentials { credential @@ -1016,10 +1040,13 @@ impl MicrosoftAzureBuilder { /// if present, otherwise falls back to default_url fn url_from_env(env_name: &str, default_url: &str) -> Result { let url = match std::env::var(env_name) { - Ok(env_value) => Url::parse(&env_value).context(UnableToParseEmulatorUrlSnafu { - env_name, - env_value, - })?, + Ok(env_value) => { + Url::parse(&env_value).map_err(|source| Error::UnableToParseEmulatorUrl { + env_name: env_name.into(), + env_value, + source, + })? + } Err(_) => Url::parse(default_url).expect("Failed to parse default URL"), }; Ok(url) @@ -1028,7 +1055,7 @@ fn url_from_env(env_name: &str, default_url: &str) -> Result { fn split_sas(sas: &str) -> Result, Error> { let sas = percent_decode_str(sas) .decode_utf8() - .context(DecodeSasKeySnafu {})?; + .map_err(|source| Error::DecodeSasKey { source })?; let kv_str_pairs = sas .trim_start_matches('?') .split('&') diff --git a/object_store/src/azure/client.rs b/object_store/src/azure/client.rs index ba2632a5e503..fa5412c455fc 100644 --- a/object_store/src/azure/client.rs +++ b/object_store/src/azure/client.rs @@ -31,17 +31,17 @@ use crate::{ PutMultipartOpts, PutOptions, PutPayload, PutResult, Result, RetryConfig, TagSet, }; use async_trait::async_trait; -use base64::prelude::BASE64_STANDARD; +use base64::prelude::{BASE64_STANDARD, BASE64_STANDARD_NO_PAD}; use base64::Engine; use bytes::{Buf, Bytes}; use chrono::{DateTime, Utc}; use hyper::http::HeaderName; +use rand::Rng as _; use reqwest::{ - header::{HeaderValue, CONTENT_LENGTH, IF_MATCH, IF_NONE_MATCH}, + header::{HeaderMap, HeaderValue, CONTENT_LENGTH, CONTENT_TYPE, IF_MATCH, IF_NONE_MATCH}, Client as ReqwestClient, Method, RequestBuilder, Response, }; use serde::{Deserialize, Serialize}; -use snafu::{OptionExt, ResultExt, Snafu}; use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; @@ -59,56 +59,84 @@ 
static MS_CONTENT_LANGUAGE: HeaderName = HeaderName::from_static("x-ms-blob-cont static TAGS_HEADER: HeaderName = HeaderName::from_static("x-ms-tags"); /// A specialized `Error` for object store-related errors -#[derive(Debug, Snafu)] +#[derive(Debug, thiserror::Error)] pub(crate) enum Error { - #[snafu(display("Error performing get request {}: {}", path, source))] + #[error("Error performing get request {}: {}", path, source)] GetRequest { source: crate::client::retry::Error, path: String, }, - #[snafu(display("Error performing put request {}: {}", path, source))] + #[error("Error performing put request {}: {}", path, source)] PutRequest { source: crate::client::retry::Error, path: String, }, - #[snafu(display("Error performing delete request {}: {}", path, source))] + #[error("Error performing delete request {}: {}", path, source)] DeleteRequest { source: crate::client::retry::Error, path: String, }, - #[snafu(display("Error performing list request: {}", source))] + #[error("Error performing bulk delete request: {}", source)] + BulkDeleteRequest { source: crate::client::retry::Error }, + + #[error("Error receiving bulk delete request body: {}", source)] + BulkDeleteRequestBody { source: reqwest::Error }, + + #[error( + "Bulk delete request failed due to invalid input: {} (code: {})", + reason, + code + )] + BulkDeleteRequestInvalidInput { code: String, reason: String }, + + #[error("Got invalid bulk delete response: {}", reason)] + InvalidBulkDeleteResponse { reason: String }, + + #[error( + "Bulk delete request failed for key {}: {} (code: {})", + path, + reason, + code + )] + DeleteFailed { + path: String, + code: String, + reason: String, + }, + + #[error("Error performing list request: {}", source)] ListRequest { source: crate::client::retry::Error }, - #[snafu(display("Error getting list response body: {}", source))] + #[error("Error getting list response body: {}", source)] ListResponseBody { source: reqwest::Error }, - #[snafu(display("Got invalid list 
response: {}", source))] + #[error("Got invalid list response: {}", source)] InvalidListResponse { source: quick_xml::de::DeError }, - #[snafu(display("Unable to extract metadata from headers: {}", source))] + #[error("Unable to extract metadata from headers: {}", source)] Metadata { source: crate::client::header::Error, }, - #[snafu(display("ETag required for conditional update"))] + #[error("ETag required for conditional update")] MissingETag, - #[snafu(display("Error requesting user delegation key: {}", source))] + #[error("Error requesting user delegation key: {}", source)] DelegationKeyRequest { source: crate::client::retry::Error }, - #[snafu(display("Error getting user delegation key response body: {}", source))] + #[error("Error getting user delegation key response body: {}", source)] DelegationKeyResponseBody { source: reqwest::Error }, - #[snafu(display("Got invalid user delegation key response: {}", source))] + #[error("Got invalid user delegation key response: {}", source)] DelegationKeyResponse { source: quick_xml::de::DeError }, - #[snafu(display("Generating SAS keys with SAS tokens auth is not supported"))] + #[error("Generating SAS keys with SAS tokens auth is not supported")] SASforSASNotSupported, - #[snafu(display("Generating SAS keys while skipping signatures is not supported"))] + #[error("Generating SAS keys while skipping signatures is not supported")] SASwithSkipSignature, } @@ -170,7 +198,7 @@ struct PutRequest<'a> { idempotent: bool, } -impl<'a> PutRequest<'a> { +impl PutRequest<'_> { fn header(self, k: &HeaderName, v: &str) -> Self { let builder = self.builder.header(k, v); Self { builder, ..self } @@ -239,14 +267,232 @@ impl<'a> PutRequest<'a> { .payload(Some(self.payload)) .send() .await - .context(PutRequestSnafu { - path: self.path.as_ref(), + .map_err(|source| { + let path = self.path.as_ref().into(); + Error::PutRequest { path, source } })?; Ok(response) } } +#[inline] +fn extend(dst: &mut Vec, data: &[u8]) { + 
dst.extend_from_slice(data); +} + +// Write header names as title case. The header name is assumed to be ASCII. +// We need it because Azure is not always treating headers as case insensitive. +fn title_case(dst: &mut Vec, name: &[u8]) { + dst.reserve(name.len()); + + // Ensure first character is uppercased + let mut prev = b'-'; + for &(mut c) in name { + if prev == b'-' { + c.make_ascii_uppercase(); + } + dst.push(c); + prev = c; + } +} + +fn write_headers(headers: &HeaderMap, dst: &mut Vec) { + for (name, value) in headers { + // We need special case handling here otherwise Azure returns 400 + // due to `Content-Id` instead of `Content-ID` + if name == "content-id" { + extend(dst, b"Content-ID"); + } else { + title_case(dst, name.as_str().as_bytes()); + } + extend(dst, b": "); + extend(dst, value.as_bytes()); + extend(dst, b"\r\n"); + } +} + +// https://docs.oasis-open.org/odata/odata/v4.0/errata02/os/complete/part1-protocol/odata-v4.0-errata02-os-part1-protocol-complete.html#_Toc406398359 +fn serialize_part_delete_request( + dst: &mut Vec, + boundary: &str, + idx: usize, + request: reqwest::Request, + relative_url: String, +) { + // Encode start marker for part + extend(dst, b"--"); + extend(dst, boundary.as_bytes()); + extend(dst, b"\r\n"); + + // Encode part headers + let mut part_headers = HeaderMap::new(); + part_headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/http")); + part_headers.insert( + "Content-Transfer-Encoding", + HeaderValue::from_static("binary"), + ); + // Azure returns 400 if we send `Content-Id` instead of `Content-ID` + part_headers.insert("Content-ID", HeaderValue::from(idx)); + write_headers(&part_headers, dst); + extend(dst, b"\r\n"); + + // Encode the subrequest request-line + extend(dst, b"DELETE "); + extend(dst, format!("/{} ", relative_url).as_bytes()); + extend(dst, b"HTTP/1.1"); + extend(dst, b"\r\n"); + + // Encode subrequest headers + write_headers(request.headers(), dst); + extend(dst, b"\r\n"); + extend(dst, 
b"\r\n"); +} + +fn parse_multipart_response_boundary(response: &Response) -> Result { + let invalid_response = |msg: &str| Error::InvalidBulkDeleteResponse { + reason: msg.to_string(), + }; + + let content_type = response + .headers() + .get(CONTENT_TYPE) + .ok_or_else(|| invalid_response("missing Content-Type"))?; + + let boundary = content_type + .as_ref() + .strip_prefix(b"multipart/mixed; boundary=") + .ok_or_else(|| invalid_response("invalid Content-Type value"))? + .to_vec(); + + let boundary = + String::from_utf8(boundary).map_err(|_| invalid_response("invalid multipart boundary"))?; + + Ok(boundary) +} + +fn invalid_response(msg: &str) -> Error { + Error::InvalidBulkDeleteResponse { + reason: msg.to_string(), + } +} + +#[derive(Debug)] +struct MultipartField { + headers: HeaderMap, + content: Bytes, +} + +fn parse_multipart_body_fields(body: Bytes, boundary: &[u8]) -> Result> { + let start_marker = [b"--", boundary, b"\r\n"].concat(); + let next_marker = &start_marker[..start_marker.len() - 2]; + let end_marker = [b"--", boundary, b"--\r\n"].concat(); + + // There should be at most 256 responses per batch + let mut fields = Vec::with_capacity(256); + let mut remaining: &[u8] = body.as_ref(); + loop { + remaining = remaining + .strip_prefix(start_marker.as_slice()) + .ok_or_else(|| invalid_response("missing start marker for field"))?; + + // The documentation only mentions two headers for fields, we leave some extra margin + let mut scratch = [httparse::EMPTY_HEADER; 10]; + let mut headers = HeaderMap::new(); + match httparse::parse_headers(remaining, &mut scratch) { + Ok(httparse::Status::Complete((pos, headers_slice))) => { + remaining = &remaining[pos..]; + for header in headers_slice { + headers.insert( + HeaderName::from_bytes(header.name.as_bytes()).expect("valid"), + HeaderValue::from_bytes(header.value).expect("valid"), + ); + } + } + _ => return Err(invalid_response("unable to parse field headers").into()), + }; + + let next_pos = remaining + 
.windows(next_marker.len()) + .position(|window| window == next_marker) + .ok_or_else(|| invalid_response("early EOF while seeking to next boundary"))?; + + fields.push(MultipartField { + headers, + content: body.slice_ref(&remaining[..next_pos]), + }); + + remaining = &remaining[next_pos..]; + + // Support missing final CRLF + if remaining == end_marker || remaining == &end_marker[..end_marker.len() - 2] { + break; + } + } + Ok(fields) +} + +async fn parse_blob_batch_delete_body( + batch_body: Bytes, + boundary: String, + paths: &[Path], +) -> Result>> { + let mut results: Vec> = paths.iter().cloned().map(Ok).collect(); + + for field in parse_multipart_body_fields(batch_body, boundary.as_bytes())? { + let id = field + .headers + .get("content-id") + .and_then(|v| std::str::from_utf8(v.as_bytes()).ok()) + .and_then(|v| v.parse::().ok()); + + // Parse part response headers + // Documentation mentions 5 headers and states that other standard HTTP headers + // may be provided, in order to not incurr in more complexity to support an arbitrary + // amount of headers we chose a conservative amount and error otherwise + // https://learn.microsoft.com/en-us/rest/api/storageservices/delete-blob?tabs=microsoft-entra-id#response-headers + let mut headers = [httparse::EMPTY_HEADER; 48]; + let mut part_response = httparse::Response::new(&mut headers); + match part_response.parse(&field.content) { + Ok(httparse::Status::Complete(_)) => {} + _ => return Err(invalid_response("unable to parse response").into()), + }; + + match (id, part_response.code) { + (Some(_id), Some(code)) if (200..300).contains(&code) => {} + (Some(id), Some(404)) => { + results[id] = Err(crate::Error::NotFound { + path: paths[id].as_ref().to_string(), + source: Error::DeleteFailed { + path: paths[id].as_ref().to_string(), + code: 404.to_string(), + reason: part_response.reason.unwrap_or_default().to_string(), + } + .into(), + }); + } + (Some(id), Some(code)) => { + results[id] = Err(Error::DeleteFailed { + 
path: paths[id].as_ref().to_string(), + code: code.to_string(), + reason: part_response.reason.unwrap_or_default().to_string(), + } + .into()); + } + (None, Some(code)) => { + return Err(Error::BulkDeleteRequestInvalidInput { + code: code.to_string(), + reason: part_response.reason.unwrap_or_default().to_string(), + } + .into()) + } + _ => return Err(invalid_response("missing part response status code").into()), + } + } + + Ok(results) +} + #[derive(Debug)] pub(crate) struct AzureClient { config: AzureConfig, @@ -298,23 +544,25 @@ impl AzureClient { PutMode::Overwrite => builder.idempotent(true), PutMode::Create => builder.header(&IF_NONE_MATCH, "*"), PutMode::Update(v) => { - let etag = v.e_tag.as_ref().context(MissingETagSnafu)?; + let etag = v.e_tag.as_ref().ok_or(Error::MissingETag)?; builder.header(&IF_MATCH, etag) } }; let response = builder.header(&BLOB_TYPE, "BlockBlob").send().await?; - Ok(get_put_result(response.headers(), VERSION_HEADER).context(MetadataSnafu)?) + Ok(get_put_result(response.headers(), VERSION_HEADER) + .map_err(|source| Error::Metadata { source })?) } /// PUT a block pub(crate) async fn put_block( &self, path: &Path, - part_idx: usize, + _part_idx: usize, payload: PutPayload, ) -> Result { - let content_id = format!("{part_idx:20}"); + let part_idx = u128::from_be_bytes(rand::thread_rng().gen()); + let content_id = format!("{part_idx:032x}"); let block_id = BASE64_STANDARD.encode(&content_id); self.put_request(path, payload) @@ -348,7 +596,8 @@ impl AzureClient { .send() .await?; - Ok(get_put_result(response.headers(), VERSION_HEADER).context(MetadataSnafu)?) + Ok(get_put_result(response.headers(), VERSION_HEADER) + .map_err(|source| Error::Metadata { source })?) 
} /// Make an Azure Delete request @@ -373,13 +622,94 @@ impl AzureClient { .sensitive(sensitive) .send() .await - .context(DeleteRequestSnafu { - path: path.as_ref(), + .map_err(|source| { + let path = path.as_ref().into(); + Error::DeleteRequest { source, path } })?; Ok(()) } + fn build_bulk_delete_body( + &self, + boundary: &str, + paths: &[Path], + credential: &Option>, + ) -> Vec { + let mut body_bytes = Vec::with_capacity(paths.len() * 2048); + + for (idx, path) in paths.iter().enumerate() { + let url = self.config.path_url(path); + + // Build subrequest with proper authorization + let request = self + .client + .request(Method::DELETE, url) + .header(CONTENT_LENGTH, HeaderValue::from(0)) + // Each subrequest must be authorized individually [1] and we use + // the CredentialExt for this. + // [1]: https://learn.microsoft.com/en-us/rest/api/storageservices/blob-batch?tabs=microsoft-entra-id#request-body + .with_azure_authorization(credential, &self.config.account) + .build() + .unwrap(); + + // Url for part requests must be relative and without base + let relative_url = self.config.service.make_relative(request.url()).unwrap(); + + serialize_part_delete_request(&mut body_bytes, boundary, idx, request, relative_url) + } + + // Encode end marker + extend(&mut body_bytes, b"--"); + extend(&mut body_bytes, boundary.as_bytes()); + extend(&mut body_bytes, b"--"); + extend(&mut body_bytes, b"\r\n"); + body_bytes + } + + pub(crate) async fn bulk_delete_request(&self, paths: Vec) -> Result>> { + if paths.is_empty() { + return Ok(Vec::new()); + } + + let credential = self.get_credential().await?; + + // https://www.ietf.org/rfc/rfc2046 + let random_bytes = rand::random::<[u8; 16]>(); // 128 bits + let boundary = format!("batch_{}", BASE64_STANDARD_NO_PAD.encode(random_bytes)); + + let body_bytes = self.build_bulk_delete_body(&boundary, &paths, &credential); + + // Send multipart request + let url = self.config.path_url(&Path::from("/")); + let batch_response = self + 
.client + .request(Method::POST, url) + .query(&[("restype", "container"), ("comp", "batch")]) + .header( + CONTENT_TYPE, + HeaderValue::from_str(format!("multipart/mixed; boundary={}", boundary).as_str()) + .unwrap(), + ) + .header(CONTENT_LENGTH, HeaderValue::from(body_bytes.len())) + .body(body_bytes) + .with_azure_authorization(&credential, &self.config.account) + .send_retry(&self.config.retry_config) + .await + .map_err(|source| Error::BulkDeleteRequest { source })?; + + let boundary = parse_multipart_response_boundary(&batch_response)?; + + let batch_body = batch_response + .bytes() + .await + .map_err(|source| Error::BulkDeleteRequestBody { source })?; + + let results = parse_blob_batch_delete_body(batch_body, boundary, &paths).await?; + + Ok(results) + } + /// Make an Azure Copy request pub(crate) async fn copy_request(&self, from: &Path, to: &Path, overwrite: bool) -> Result<()> { let credential = self.get_credential().await?; @@ -453,13 +783,13 @@ impl AzureClient { .idempotent(true) .send() .await - .context(DelegationKeyRequestSnafu)? + .map_err(|source| Error::DelegationKeyRequest { source })? 
.bytes() .await - .context(DelegationKeyResponseBodySnafu)?; + .map_err(|source| Error::DelegationKeyResponseBody { source })?; - let response: UserDelegationKey = - quick_xml::de::from_reader(response.reader()).context(DelegationKeyResponseSnafu)?; + let response: UserDelegationKey = quick_xml::de::from_reader(response.reader()) + .map_err(|source| Error::DelegationKeyResponse { source })?; Ok(response) } @@ -515,9 +845,11 @@ impl AzureClient { .sensitive(sensitive) .send() .await - .context(GetRequestSnafu { - path: path.as_ref(), + .map_err(|source| { + let path = path.as_ref().into(); + Error::GetRequest { source, path } })?; + Ok(response) } } @@ -573,8 +905,9 @@ impl GetClient for AzureClient { .sensitive(sensitive) .send() .await - .context(GetRequestSnafu { - path: path.as_ref(), + .map_err(|source| { + let path = path.as_ref().into(); + Error::GetRequest { source, path } })?; match response.headers().get("x-ms-resource-type") { @@ -635,13 +968,14 @@ impl ListClient for Arc { .sensitive(sensitive) .send() .await - .context(ListRequestSnafu)? + .map_err(|source| Error::ListRequest { source })? 
.bytes() .await - .context(ListResponseBodySnafu)?; + .map_err(|source| Error::ListResponseBody { source })?; + + let mut response: ListResultInternal = quick_xml::de::from_reader(response.reader()) + .map_err(|source| Error::InvalidListResponse { source })?; - let mut response: ListResultInternal = - quick_xml::de::from_reader(response.reader()).context(InvalidListResponseSnafu)?; let token = response.next_marker.take(); Ok((to_list_result(response, prefix)?, token)) @@ -814,8 +1148,10 @@ pub(crate) struct UserDelegationKey { #[cfg(test)] mod tests { use bytes::Bytes; + use regex::bytes::Regex; use super::*; + use crate::StaticCredentialProvider; #[test] fn deserde_azure() { @@ -1005,4 +1341,159 @@ mod tests { let _delegated_key_response_internal: UserDelegationKey = quick_xml::de::from_str(S).unwrap(); } + + #[tokio::test] + async fn test_build_bulk_delete_body() { + let credential_provider = Arc::new(StaticCredentialProvider::new( + AzureCredential::BearerToken("static-token".to_string()), + )); + + let config = AzureConfig { + account: "testaccount".to_string(), + container: "testcontainer".to_string(), + credentials: credential_provider, + service: "http://example.com".try_into().unwrap(), + retry_config: Default::default(), + is_emulator: false, + skip_signature: false, + disable_tagging: false, + client_options: Default::default(), + }; + + let client = AzureClient::new(config).unwrap(); + + let credential = client.get_credential().await.unwrap(); + let paths = &[Path::from("a"), Path::from("b"), Path::from("c")]; + + let boundary = "batch_statictestboundary".to_string(); + + let body_bytes = client.build_bulk_delete_body(&boundary, paths, &credential); + + // Replace Date header value with a static date + let re = Regex::new("Date:[^\r]+").unwrap(); + let body_bytes = re + .replace_all(&body_bytes, b"Date: Tue, 05 Nov 2024 15:01:15 GMT") + .to_vec(); + + let expected_body = b"--batch_statictestboundary\r +Content-Type: application/http\r 
+Content-Transfer-Encoding: binary\r +Content-ID: 0\r +\r +DELETE /testcontainer/a HTTP/1.1\r +Content-Length: 0\r +Date: Tue, 05 Nov 2024 15:01:15 GMT\r +X-Ms-Version: 2023-11-03\r +Authorization: Bearer static-token\r +\r +\r +--batch_statictestboundary\r +Content-Type: application/http\r +Content-Transfer-Encoding: binary\r +Content-ID: 1\r +\r +DELETE /testcontainer/b HTTP/1.1\r +Content-Length: 0\r +Date: Tue, 05 Nov 2024 15:01:15 GMT\r +X-Ms-Version: 2023-11-03\r +Authorization: Bearer static-token\r +\r +\r +--batch_statictestboundary\r +Content-Type: application/http\r +Content-Transfer-Encoding: binary\r +Content-ID: 2\r +\r +DELETE /testcontainer/c HTTP/1.1\r +Content-Length: 0\r +Date: Tue, 05 Nov 2024 15:01:15 GMT\r +X-Ms-Version: 2023-11-03\r +Authorization: Bearer static-token\r +\r +\r +--batch_statictestboundary--\r\n" + .to_vec(); + + assert_eq!(expected_body, body_bytes); + } + + #[tokio::test] + async fn test_parse_blob_batch_delete_body() { + let response_body = b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r +Content-Type: application/http\r +Content-ID: 0\r +\r +HTTP/1.1 202 Accepted\r +x-ms-delete-type-permanent: true\r +x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r +x-ms-version: 2018-11-09\r +\r +--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r +Content-Type: application/http\r +Content-ID: 1\r +\r +HTTP/1.1 202 Accepted\r +x-ms-delete-type-permanent: true\r +x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e2851\r +x-ms-version: 2018-11-09\r +\r +--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r +Content-Type: application/http\r +Content-ID: 2\r +\r +HTTP/1.1 404 The specified blob does not exist.\r +x-ms-error-code: BlobNotFound\r +x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e2852\r +x-ms-version: 2018-11-09\r +Content-Length: 216\r +Content-Type: application/xml\r +\r + +BlobNotFoundThe specified blob does not exist. 
+RequestId:778fdc83-801e-0000-62ff-0334671e2852 +Time:2018-06-14T16:46:54.6040685Z\r +--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n"; + + let response: reqwest::Response = http::Response::builder() + .status(202) + .header("Transfer-Encoding", "chunked") + .header( + "Content-Type", + "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed", + ) + .header("x-ms-request-id", "778fdc83-801e-0000-62ff-033467000000") + .header("x-ms-version", "2018-11-09") + .body(Bytes::from(response_body.as_slice())) + .unwrap() + .into(); + + let boundary = parse_multipart_response_boundary(&response).unwrap(); + let body = response.bytes().await.unwrap(); + + let paths = &[Path::from("a"), Path::from("b"), Path::from("c")]; + + let results = parse_blob_batch_delete_body(body, boundary, paths) + .await + .unwrap(); + + assert!(results[0].is_ok()); + assert_eq!(&paths[0], results[0].as_ref().unwrap()); + + assert!(results[1].is_ok()); + assert_eq!(&paths[1], results[1].as_ref().unwrap()); + + assert!(results[2].is_err()); + let err = results[2].as_ref().unwrap_err(); + let crate::Error::NotFound { source, .. 
} = err else { + unreachable!("must be not found") + }; + let Some(Error::DeleteFailed { path, code, reason }) = source.downcast_ref::() + else { + unreachable!("must be client error") + }; + + assert_eq!(paths[2].as_ref(), path); + assert_eq!("404", code); + assert_eq!("The specified blob does not exist.", reason); + } } diff --git a/object_store/src/azure/credential.rs b/object_store/src/azure/credential.rs index 2832eed72256..c9e6ac640b4a 100644 --- a/object_store/src/azure/credential.rs +++ b/object_store/src/azure/credential.rs @@ -32,7 +32,6 @@ use reqwest::header::{ }; use reqwest::{Client, Method, Request, RequestBuilder}; use serde::Deserialize; -use snafu::{ResultExt, Snafu}; use std::borrow::Cow; use std::collections::HashMap; use std::fmt::Debug; @@ -71,27 +70,27 @@ const AZURE_STORAGE_SCOPE: &str = "https://storage.azure.com/.default"; /// const AZURE_STORAGE_RESOURCE: &str = "https://storage.azure.com"; -#[derive(Debug, Snafu)] +#[derive(Debug, thiserror::Error)] pub enum Error { - #[snafu(display("Error performing token request: {}", source))] + #[error("Error performing token request: {}", source)] TokenRequest { source: crate::client::retry::Error }, - #[snafu(display("Error getting token response body: {}", source))] + #[error("Error getting token response body: {}", source)] TokenResponseBody { source: reqwest::Error }, - #[snafu(display("Error reading federated token file "))] + #[error("Error reading federated token file ")] FederatedTokenFile, - #[snafu(display("Invalid Access Key: {}", source))] + #[error("Invalid Access Key: {}", source)] InvalidAccessKey { source: base64::DecodeError }, - #[snafu(display("'az account get-access-token' command failed: {message}"))] + #[error("'az account get-access-token' command failed: {message}")] AzureCli { message: String }, - #[snafu(display("Failed to parse azure cli response: {source}"))] + #[error("Failed to parse azure cli response: {source}")] AzureCliResponse { source: serde_json::Error }, - 
#[snafu(display("Generating SAS keys with SAS tokens auth is not supported"))] + #[error("Generating SAS keys with SAS tokens auth is not supported")] SASforSASNotSupported, } @@ -113,7 +112,10 @@ pub struct AzureAccessKey(Vec); impl AzureAccessKey { /// Create a new [`AzureAccessKey`], checking it for validity pub fn try_new(key: &str) -> Result { - let key = BASE64_STANDARD.decode(key).context(InvalidAccessKeySnafu)?; + let key = BASE64_STANDARD + .decode(key) + .map_err(|source| Error::InvalidAccessKey { source })?; + Ok(Self(key)) } } @@ -636,10 +638,10 @@ impl TokenProvider for ClientSecretOAuthProvider { .idempotent(true) .send() .await - .context(TokenRequestSnafu)? + .map_err(|source| Error::TokenRequest { source })? .json() .await - .context(TokenResponseBodySnafu)?; + .map_err(|source| Error::TokenResponseBody { source })?; Ok(TemporaryToken { token: Arc::new(AzureCredential::BearerToken(response.access_token)), @@ -744,10 +746,10 @@ impl TokenProvider for ImdsManagedIdentityProvider { let response: ImdsTokenResponse = builder .send_retry(retry) .await - .context(TokenRequestSnafu)? + .map_err(|source| Error::TokenRequest { source })? .json() .await - .context(TokenResponseBodySnafu)?; + .map_err(|source| Error::TokenResponseBody { source })?; Ok(TemporaryToken { token: Arc::new(AzureCredential::BearerToken(response.access_token)), @@ -820,10 +822,10 @@ impl TokenProvider for WorkloadIdentityOAuthProvider { .idempotent(true) .send() .await - .context(TokenRequestSnafu)? + .map_err(|source| Error::TokenRequest { source })? 
.json() .await - .context(TokenResponseBodySnafu)?; + .map_err(|source| Error::TokenResponseBody { source })?; Ok(TemporaryToken { token: Arc::new(AzureCredential::BearerToken(response.access_token)), @@ -900,7 +902,8 @@ impl AzureCliCredential { })?; let token_response = serde_json::from_str::(output) - .context(AzureCliResponseSnafu)?; + .map_err(|source| Error::AzureCliResponse { source })?; + if !token_response.token_type.eq_ignore_ascii_case("bearer") { return Err(Error::AzureCli { message: format!( @@ -1033,10 +1036,10 @@ impl TokenProvider for FabricTokenOAuthProvider { .idempotent(true) .send() .await - .context(TokenRequestSnafu)? + .map_err(|source| Error::TokenRequest { source })? .text() .await - .context(TokenResponseBodySnafu)?; + .map_err(|source| Error::TokenResponseBody { source })?; let exp_in = Self::validate_and_get_expiry(&access_token) .map_or(3600, |expiry| expiry - Self::get_current_timestamp()); Ok(TemporaryToken { diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs index 9e7fc1738dc6..ea4dd8f567a9 100644 --- a/object_store/src/azure/mod.rs +++ b/object_store/src/azure/mod.rs @@ -30,7 +30,7 @@ use crate::{ PutMultipartOpts, PutOptions, PutPayload, PutResult, Result, UploadPart, }; use async_trait::async_trait; -use futures::stream::BoxStream; +use futures::stream::{BoxStream, StreamExt, TryStreamExt}; use reqwest::Method; use std::fmt::Debug; use std::sync::Arc; @@ -122,6 +122,25 @@ impl ObjectStore for MicrosoftAzure { fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result> { self.client.list(prefix) } + fn delete_stream<'a>( + &'a self, + locations: BoxStream<'a, Result>, + ) -> BoxStream<'a, Result> { + locations + .try_chunks(256) + .map(move |locations| async { + // Early return the error. We ignore the paths that have already been + // collected into the chunk. 
+ let locations = locations.map_err(|e| e.1)?; + self.client + .bulk_delete_request(locations) + .await + .map(futures::stream::iter) + }) + .buffered(20) + .try_flatten() + .boxed() + } async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { self.client.list_with_delimiter(prefix).await @@ -294,6 +313,7 @@ mod tests { stream_get(&integration).await; put_opts(&integration, true).await; multipart(&integration, &integration).await; + multipart_race_condition(&integration, false).await; signing(&integration).await; let validate = !integration.client.config().disable_tagging; diff --git a/object_store/src/chunked.rs b/object_store/src/chunked.rs index f1ec591470f7..4998e9f2a04d 100644 --- a/object_store/src/chunked.rs +++ b/object_store/src/chunked.rs @@ -86,6 +86,7 @@ impl ObjectStore for ChunkedStore { async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { let r = self.inner.get_opts(location, options).await?; let stream = match r.payload { + #[cfg(all(feature = "fs", not(target_arch = "wasm32")))] GetResultPayload::File(file, path) => { crate::local::chunked_stream(file, path, r.range.clone(), self.chunk_size) } @@ -178,7 +179,9 @@ impl ObjectStore for ChunkedStore { mod tests { use futures::StreamExt; + #[cfg(feature = "fs")] use crate::integration::*; + #[cfg(feature = "fs")] use crate::local::LocalFileSystem; use crate::memory::InMemory; use crate::path::Path; @@ -209,6 +212,7 @@ mod tests { } } + #[cfg(feature = "fs")] #[tokio::test] async fn test_chunked() { let temporary = tempfile::tempdir().unwrap(); diff --git a/object_store/src/client/get.rs b/object_store/src/client/get.rs index 5dd62cbece5a..57aca8956452 100644 --- a/object_store/src/client/get.rs +++ b/object_store/src/client/get.rs @@ -29,7 +29,6 @@ use hyper::header::{ use hyper::StatusCode; use reqwest::header::ToStrError; use reqwest::Response; -use snafu::{ensure, OptionExt, ResultExt, Snafu}; /// A client that can perform a get request #[async_trait] @@ -95,49 
+94,51 @@ impl ContentRange { } /// A specialized `Error` for get-related errors -#[derive(Debug, Snafu)] +#[derive(Debug, thiserror::Error)] enum GetResultError { - #[snafu(context(false))] + #[error(transparent)] Header { + #[from] source: crate::client::header::Error, }, - #[snafu(transparent)] + #[error(transparent)] InvalidRangeRequest { + #[from] source: crate::util::InvalidGetRange, }, - #[snafu(display("Received non-partial response when range requested"))] + #[error("Received non-partial response when range requested")] NotPartial, - #[snafu(display("Content-Range header not present in partial response"))] + #[error("Content-Range header not present in partial response")] NoContentRange, - #[snafu(display("Failed to parse value for CONTENT_RANGE header: \"{value}\""))] + #[error("Failed to parse value for CONTENT_RANGE header: \"{value}\"")] ParseContentRange { value: String }, - #[snafu(display("Content-Range header contained non UTF-8 characters"))] + #[error("Content-Range header contained non UTF-8 characters")] InvalidContentRange { source: ToStrError }, - #[snafu(display("Cache-Control header contained non UTF-8 characters"))] + #[error("Cache-Control header contained non UTF-8 characters")] InvalidCacheControl { source: ToStrError }, - #[snafu(display("Content-Disposition header contained non UTF-8 characters"))] + #[error("Content-Disposition header contained non UTF-8 characters")] InvalidContentDisposition { source: ToStrError }, - #[snafu(display("Content-Encoding header contained non UTF-8 characters"))] + #[error("Content-Encoding header contained non UTF-8 characters")] InvalidContentEncoding { source: ToStrError }, - #[snafu(display("Content-Language header contained non UTF-8 characters"))] + #[error("Content-Language header contained non UTF-8 characters")] InvalidContentLanguage { source: ToStrError }, - #[snafu(display("Content-Type header contained non UTF-8 characters"))] + #[error("Content-Type header contained non UTF-8 characters")] 
InvalidContentType { source: ToStrError }, - #[snafu(display("Metadata value for \"{key:?}\" contained non UTF-8 characters"))] + #[error("Metadata value for \"{key:?}\" contained non UTF-8 characters")] InvalidMetadata { key: String }, - #[snafu(display("Requested {expected:?}, got {actual:?}"))] + #[error("Requested {expected:?}, got {actual:?}")] UnexpectedRange { expected: Range, actual: Range, @@ -153,17 +154,24 @@ fn get_result( // ensure that we receive the range we asked for let range = if let Some(expected) = range { - ensure!( - response.status() == StatusCode::PARTIAL_CONTENT, - NotPartialSnafu - ); + if response.status() != StatusCode::PARTIAL_CONTENT { + return Err(GetResultError::NotPartial); + } + let val = response .headers() .get(CONTENT_RANGE) - .context(NoContentRangeSnafu)?; + .ok_or(GetResultError::NoContentRange)?; + + let value = val + .to_str() + .map_err(|source| GetResultError::InvalidContentRange { source })?; + + let value = ContentRange::from_str(value).ok_or_else(|| { + let value = value.into(); + GetResultError::ParseContentRange { value } + })?; - let value = val.to_str().context(InvalidContentRangeSnafu)?; - let value = ContentRange::from_str(value).context(ParseContentRangeSnafu { value })?; let actual = value.range; // Update size to reflect full size of object (#5272) @@ -171,10 +179,9 @@ fn get_result( let expected = expected.as_range(meta.size)?; - ensure!( - actual == expected, - UnexpectedRangeSnafu { expected, actual } - ); + if actual != expected { + return Err(GetResultError::UnexpectedRange { expected, actual }); + } actual } else { @@ -182,11 +189,11 @@ fn get_result( }; macro_rules! 
parse_attributes { - ($headers:expr, $(($header:expr, $attr:expr, $err:expr)),*) => {{ + ($headers:expr, $(($header:expr, $attr:expr, $map_err:expr)),*) => {{ let mut attributes = Attributes::new(); $( if let Some(x) = $headers.get($header) { - let x = x.to_str().context($err)?; + let x = x.to_str().map_err($map_err)?; attributes.insert($attr, x.to_string().into()); } )* @@ -196,31 +203,23 @@ fn get_result( let mut attributes = parse_attributes!( response.headers(), - ( - CACHE_CONTROL, - Attribute::CacheControl, - InvalidCacheControlSnafu - ), + (CACHE_CONTROL, Attribute::CacheControl, |source| { + GetResultError::InvalidCacheControl { source } + }), ( CONTENT_DISPOSITION, Attribute::ContentDisposition, - InvalidContentDispositionSnafu - ), - ( - CONTENT_ENCODING, - Attribute::ContentEncoding, - InvalidContentEncodingSnafu + |source| GetResultError::InvalidContentDisposition { source } ), - ( - CONTENT_LANGUAGE, - Attribute::ContentLanguage, - InvalidContentLanguageSnafu - ), - ( - CONTENT_TYPE, - Attribute::ContentType, - InvalidContentTypeSnafu - ) + (CONTENT_ENCODING, Attribute::ContentEncoding, |source| { + GetResultError::InvalidContentEncoding { source } + }), + (CONTENT_LANGUAGE, Attribute::ContentLanguage, |source| { + GetResultError::InvalidContentLanguage { source } + }), + (CONTENT_TYPE, Attribute::ContentType, |source| { + GetResultError::InvalidContentType { source } + }) ); // Add attributes that match the user-defined metadata prefix (e.g. 
x-amz-meta-) diff --git a/object_store/src/client/header.rs b/object_store/src/client/header.rs index 07c04c11945a..db06da6345d5 100644 --- a/object_store/src/client/header.rs +++ b/object_store/src/client/header.rs @@ -22,7 +22,6 @@ use crate::ObjectMeta; use chrono::{DateTime, TimeZone, Utc}; use hyper::header::{CONTENT_LENGTH, ETAG, LAST_MODIFIED}; use hyper::HeaderMap; -use snafu::{OptionExt, ResultExt, Snafu}; #[derive(Debug, Copy, Clone)] /// Configuration for header extraction @@ -44,27 +43,27 @@ pub(crate) struct HeaderConfig { pub user_defined_metadata_prefix: Option<&'static str>, } -#[derive(Debug, Snafu)] +#[derive(Debug, thiserror::Error)] pub(crate) enum Error { - #[snafu(display("ETag Header missing from response"))] + #[error("ETag Header missing from response")] MissingEtag, - #[snafu(display("Received header containing non-ASCII data"))] + #[error("Received header containing non-ASCII data")] BadHeader { source: reqwest::header::ToStrError }, - #[snafu(display("Last-Modified Header missing from response"))] + #[error("Last-Modified Header missing from response")] MissingLastModified, - #[snafu(display("Content-Length Header missing from response"))] + #[error("Content-Length Header missing from response")] MissingContentLength, - #[snafu(display("Invalid last modified '{}': {}", last_modified, source))] + #[error("Invalid last modified '{}': {}", last_modified, source)] InvalidLastModified { last_modified: String, source: chrono::ParseError, }, - #[snafu(display("Invalid content length '{}': {}", content_length, source))] + #[error("Invalid content length '{}': {}", content_length, source)] InvalidContentLength { content_length: String, source: std::num::ParseIntError, @@ -86,7 +85,11 @@ pub(crate) fn get_put_result( #[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] pub(crate) fn get_version(headers: &HeaderMap, version: &str) -> Result, Error> { Ok(match headers.get(version) { - Some(x) => 
Some(x.to_str().context(BadHeaderSnafu)?.to_string()), + Some(x) => Some( + x.to_str() + .map_err(|source| Error::BadHeader { source })? + .to_string(), + ), None => None, }) } @@ -94,7 +97,10 @@ pub(crate) fn get_version(headers: &HeaderMap, version: &str) -> Result Result { let e_tag = headers.get(ETAG).ok_or(Error::MissingEtag)?; - Ok(e_tag.to_str().context(BadHeaderSnafu)?.to_string()) + Ok(e_tag + .to_str() + .map_err(|source| Error::BadHeader { source })? + .to_string()) } /// Extracts [`ObjectMeta`] from the provided [`HeaderMap`] @@ -105,9 +111,15 @@ pub(crate) fn header_meta( ) -> Result { let last_modified = match headers.get(LAST_MODIFIED) { Some(last_modified) => { - let last_modified = last_modified.to_str().context(BadHeaderSnafu)?; + let last_modified = last_modified + .to_str() + .map_err(|source| Error::BadHeader { source })?; + DateTime::parse_from_rfc2822(last_modified) - .context(InvalidLastModifiedSnafu { last_modified })? + .map_err(|source| Error::InvalidLastModified { + last_modified: last_modified.into(), + source, + })? .with_timezone(&Utc) } None if cfg.last_modified_required => return Err(Error::MissingLastModified), @@ -122,15 +134,25 @@ pub(crate) fn header_meta( let content_length = headers .get(CONTENT_LENGTH) - .context(MissingContentLengthSnafu)?; + .ok_or(Error::MissingContentLength)?; + + let content_length = content_length + .to_str() + .map_err(|source| Error::BadHeader { source })?; - let content_length = content_length.to_str().context(BadHeaderSnafu)?; let size = content_length .parse() - .context(InvalidContentLengthSnafu { content_length })?; + .map_err(|source| Error::InvalidContentLength { + content_length: content_length.into(), + source, + })?; let version = match cfg.version_header.and_then(|h| headers.get(h)) { - Some(v) => Some(v.to_str().context(BadHeaderSnafu)?.to_string()), + Some(v) => Some( + v.to_str() + .map_err(|source| Error::BadHeader { source })? 
+ .to_string(), + ), None => None, }; diff --git a/object_store/src/client/mod.rs b/object_store/src/client/mod.rs index 76d1c1f22f58..1b7ce5aa7a78 100644 --- a/object_store/src/client/mod.rs +++ b/object_store/src/client/mod.rs @@ -671,6 +671,10 @@ impl ClientOptions { builder = builder.danger_accept_invalid_certs(true) } + // Reqwest will remove the `Content-Length` header if it is configured to + // transparently decompress the body via the non-default `gzip` feature. + builder = builder.no_gzip(); + builder .https_only(!self.allow_http.get()?) .build() diff --git a/object_store/src/client/retry.rs b/object_store/src/client/retry.rs index 601bffdec158..8938b0861cca 100644 --- a/object_store/src/client/retry.rs +++ b/object_store/src/client/retry.rs @@ -22,30 +22,29 @@ use crate::PutPayload; use futures::future::BoxFuture; use reqwest::header::LOCATION; use reqwest::{Client, Request, Response, StatusCode}; -use snafu::Error as SnafuError; -use snafu::Snafu; +use std::error::Error as StdError; use std::time::{Duration, Instant}; use tracing::info; /// Retry request error -#[derive(Debug, Snafu)] +#[derive(Debug, thiserror::Error)] pub enum Error { - #[snafu(display("Received redirect without LOCATION, this normally indicates an incorrectly configured region"))] + #[error("Received redirect without LOCATION, this normally indicates an incorrectly configured region")] BareRedirect, - #[snafu(display("Server error, body contains Error, with status {status}: {}", body.as_deref().unwrap_or("No Body")))] + #[error("Server error, body contains Error, with status {status}: {}", body.as_deref().unwrap_or("No Body"))] Server { status: StatusCode, body: Option, }, - #[snafu(display("Client error with status {status}: {}", body.as_deref().unwrap_or("No Body")))] + #[error("Client error with status {status}: {}", body.as_deref().unwrap_or("No Body"))] Client { status: StatusCode, body: Option, }, - #[snafu(display("Error after {retries} retries in {elapsed:?}, 
max_retries:{max_retries}, retry_timeout:{retry_timeout:?}, source:{source}"))] + #[error("Error after {retries} retries in {elapsed:?}, max_retries:{max_retries}, retry_timeout:{retry_timeout:?}, source:{source}")] Reqwest { retries: usize, max_retries: usize, @@ -200,6 +199,7 @@ pub(crate) struct RetryableRequest { sensitive: bool, idempotent: Option, + retry_on_conflict: bool, payload: Option, retry_error_body: bool, @@ -217,6 +217,15 @@ impl RetryableRequest { } } + /// Set whether this request should be retried on a 409 Conflict response. + #[cfg(feature = "aws")] + pub(crate) fn retry_on_conflict(self, retry_on_conflict: bool) -> Self { + Self { + retry_on_conflict, + ..self + } + } + /// Set whether this request contains sensitive data /// /// This will avoid printing out the URL in error messages @@ -340,7 +349,8 @@ impl RetryableRequest { let status = r.status(); if retries == max_retries || now.elapsed() > retry_timeout - || !status.is_server_error() + || !(status.is_server_error() + || (self.retry_on_conflict && status == StatusCode::CONFLICT)) { return Err(match status.is_client_error() { true => match r.text().await { @@ -467,6 +477,7 @@ impl RetryExt for reqwest::RequestBuilder { idempotent: None, payload: None, sensitive: false, + retry_on_conflict: false, retry_error_body: false, } } diff --git a/object_store/src/client/s3.rs b/object_store/src/client/s3.rs index dba752cb1251..7fe956b2376e 100644 --- a/object_store/src/client/s3.rs +++ b/object_store/src/client/s3.rs @@ -106,14 +106,32 @@ pub(crate) struct CompleteMultipartUpload { pub part: Vec, } +#[derive(Serialize, Deserialize)] +pub(crate) struct PartMetadata { + pub e_tag: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub checksum_sha256: Option, +} + impl From> for CompleteMultipartUpload { fn from(value: Vec) -> Self { let part = value .into_iter() .enumerate() - .map(|(part_number, part)| MultipartPart { - e_tag: part.content_id, - part_number: part_number + 1, + 
.map(|(part_idx, part)| { + let md = match quick_xml::de::from_str::(&part.content_id) { + Ok(md) => md, + // fallback to old way + Err(_) => PartMetadata { + e_tag: part.content_id.clone(), + checksum_sha256: None, + }, + }; + MultipartPart { + e_tag: md.e_tag, + part_number: part_idx + 1, + checksum_sha256: md.checksum_sha256, + } }) .collect(); Self { part } @@ -126,6 +144,9 @@ pub(crate) struct MultipartPart { pub e_tag: String, #[serde(rename = "PartNumber")] pub part_number: usize, + #[serde(rename = "ChecksumSHA256")] + #[serde(skip_serializing_if = "Option::is_none")] + pub checksum_sha256: Option, } #[derive(Debug, Deserialize)] diff --git a/object_store/src/delimited.rs b/object_store/src/delimited.rs index 96f88bf41ff7..5b11a0bf7eb1 100644 --- a/object_store/src/delimited.rs +++ b/object_store/src/delimited.rs @@ -21,16 +21,15 @@ use std::collections::VecDeque; use bytes::Bytes; use futures::{Stream, StreamExt}; -use snafu::{ensure, Snafu}; use super::Result; -#[derive(Debug, Snafu)] +#[derive(Debug, thiserror::Error)] enum Error { - #[snafu(display("encountered unterminated string"))] + #[error("encountered unterminated string")] UnterminatedString, - #[snafu(display("encountered trailing escape character"))] + #[error("encountered trailing escape character")] TrailingEscape, } @@ -125,8 +124,12 @@ impl LineDelimiter { /// Returns `true` if there is no remaining data to be read fn finish(&mut self) -> Result { if !self.remainder.is_empty() { - ensure!(!self.is_quote, UnterminatedStringSnafu); - ensure!(!self.is_escape, TrailingEscapeSnafu); + if self.is_quote { + Err(Error::UnterminatedString)?; + } + if self.is_escape { + Err(Error::TrailingEscape)?; + } self.complete .push_back(Bytes::from(std::mem::take(&mut self.remainder))) diff --git a/object_store/src/gcp/builder.rs b/object_store/src/gcp/builder.rs index fac923c4b9a0..cc5c1e1a0745 100644 --- a/object_store/src/gcp/builder.rs +++ b/object_store/src/gcp/builder.rs @@ -27,7 +27,6 @@ use 
crate::gcp::{ }; use crate::{ClientConfigKey, ClientOptions, Result, RetryConfig, StaticCredentialProvider}; use serde::{Deserialize, Serialize}; -use snafu::{OptionExt, ResultExt, Snafu}; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; @@ -37,33 +36,33 @@ use super::credential::{AuthorizedUserSigningCredentials, InstanceSigningCredent const TOKEN_MIN_TTL: Duration = Duration::from_secs(4 * 60); -#[derive(Debug, Snafu)] +#[derive(Debug, thiserror::Error)] enum Error { - #[snafu(display("Missing bucket name"))] + #[error("Missing bucket name")] MissingBucketName {}, - #[snafu(display("One of service account path or service account key may be provided."))] + #[error("One of service account path or service account key may be provided.")] ServiceAccountPathAndKeyProvided, - #[snafu(display("Unable parse source url. Url: {}, Error: {}", url, source))] + #[error("Unable parse source url. Url: {}, Error: {}", url, source)] UnableToParseUrl { source: url::ParseError, url: String, }, - #[snafu(display( + #[error( "Unknown url scheme cannot be parsed into storage location: {}", scheme - ))] + )] UnknownUrlScheme { scheme: String }, - #[snafu(display("URL did not match any known pattern for scheme: {}", url))] + #[error("URL did not match any known pattern for scheme: {}", url)] UrlNotRecognised { url: String }, - #[snafu(display("Configuration key: '{}' is not known.", key))] + #[error("Configuration key: '{}' is not known.", key)] UnknownConfigurationKey { key: String }, - #[snafu(display("GCP credential error: {}", source))] + #[error("GCP credential error: {}", source)] Credential { source: credential::Error }, } @@ -319,12 +318,21 @@ impl GoogleCloudStorageBuilder { /// This is a separate member function to allow fallible computation to /// be deferred until [`Self::build`] which in turn allows deriving [`Clone`] fn parse_url(&mut self, url: &str) -> Result<()> { - let parsed = Url::parse(url).context(UnableToParseUrlSnafu { url })?; - let host = 
parsed.host_str().context(UrlNotRecognisedSnafu { url })?; + let parsed = Url::parse(url).map_err(|source| Error::UnableToParseUrl { + source, + url: url.to_string(), + })?; + + let host = parsed.host_str().ok_or_else(|| Error::UrlNotRecognised { + url: url.to_string(), + })?; match parsed.scheme() { "gs" => self.bucket_name = Some(host.to_string()), - scheme => return Err(UnknownUrlSchemeSnafu { scheme }.build().into()), + scheme => { + let scheme = scheme.to_string(); + return Err(Error::UnknownUrlScheme { scheme }.into()); + } } Ok(()) } @@ -428,12 +436,14 @@ impl GoogleCloudStorageBuilder { // First try to initialize from the service account information. let service_account_credentials = match (self.service_account_path, self.service_account_key) { - (Some(path), None) => { - Some(ServiceAccountCredentials::from_file(path).context(CredentialSnafu)?) - } - (None, Some(key)) => { - Some(ServiceAccountCredentials::from_key(&key).context(CredentialSnafu)?) - } + (Some(path), None) => Some( + ServiceAccountCredentials::from_file(path) + .map_err(|source| Error::Credential { source })?, + ), + (None, Some(key)) => Some( + ServiceAccountCredentials::from_key(&key) + .map_err(|source| Error::Credential { source })?, + ), (None, None) => None, (Some(_), Some(_)) => return Err(Error::ServiceAccountPathAndKeyProvided.into()), }; diff --git a/object_store/src/gcp/client.rs b/object_store/src/gcp/client.rs index b6c19a306ead..8dd1c69802a8 100644 --- a/object_store/src/gcp/client.rs +++ b/object_store/src/gcp/client.rs @@ -44,7 +44,6 @@ use percent_encoding::{percent_encode, utf8_percent_encode, NON_ALPHANUMERIC}; use reqwest::header::HeaderName; use reqwest::{Client, Method, RequestBuilder, Response, StatusCode}; use serde::{Deserialize, Serialize}; -use snafu::{OptionExt, ResultExt, Snafu}; use std::sync::Arc; const VERSION_HEADER: &str = "x-goog-generation"; @@ -53,62 +52,62 @@ const USER_DEFINED_METADATA_HEADER_PREFIX: &str = "x-goog-meta-"; static VERSION_MATCH: 
HeaderName = HeaderName::from_static("x-goog-if-generation-match"); -#[derive(Debug, Snafu)] +#[derive(Debug, thiserror::Error)] enum Error { - #[snafu(display("Error performing list request: {}", source))] + #[error("Error performing list request: {}", source)] ListRequest { source: crate::client::retry::Error }, - #[snafu(display("Error getting list response body: {}", source))] + #[error("Error getting list response body: {}", source)] ListResponseBody { source: reqwest::Error }, - #[snafu(display("Got invalid list response: {}", source))] + #[error("Got invalid list response: {}", source)] InvalidListResponse { source: quick_xml::de::DeError }, - #[snafu(display("Error performing get request {}: {}", path, source))] + #[error("Error performing get request {}: {}", path, source)] GetRequest { source: crate::client::retry::Error, path: String, }, - #[snafu(display("Error performing request {}: {}", path, source))] + #[error("Error performing request {}: {}", path, source)] Request { source: crate::client::retry::Error, path: String, }, - #[snafu(display("Error getting put response body: {}", source))] + #[error("Error getting put response body: {}", source)] PutResponseBody { source: reqwest::Error }, - #[snafu(display("Got invalid put request: {}", source))] + #[error("Got invalid put request: {}", source)] InvalidPutRequest { source: quick_xml::se::SeError }, - #[snafu(display("Got invalid put response: {}", source))] + #[error("Got invalid put response: {}", source)] InvalidPutResponse { source: quick_xml::de::DeError }, - #[snafu(display("Unable to extract metadata from headers: {}", source))] + #[error("Unable to extract metadata from headers: {}", source)] Metadata { source: crate::client::header::Error, }, - #[snafu(display("Version required for conditional update"))] + #[error("Version required for conditional update")] MissingVersion, - #[snafu(display("Error performing complete multipart request: {}", source))] + #[error("Error performing complete 
multipart request: {}", source)] CompleteMultipartRequest { source: crate::client::retry::Error }, - #[snafu(display("Error getting complete multipart response body: {}", source))] + #[error("Error getting complete multipart response body: {}", source)] CompleteMultipartResponseBody { source: reqwest::Error }, - #[snafu(display("Got invalid multipart response: {}", source))] + #[error("Got invalid multipart response: {}", source)] InvalidMultipartResponse { source: quick_xml::de::DeError }, - #[snafu(display("Error signing blob: {}", source))] + #[error("Error signing blob: {}", source)] SignBlobRequest { source: crate::client::retry::Error }, - #[snafu(display("Got invalid signing blob response: {}", source))] + #[error("Got invalid signing blob response: {}", source)] InvalidSignBlobResponse { source: reqwest::Error }, - #[snafu(display("Got invalid signing blob signature: {}", source))] + #[error("Got invalid signing blob signature: {}", source)] InvalidSignBlobSignature { source: base64::DecodeError }, } @@ -174,7 +173,7 @@ pub(crate) struct Request<'a> { idempotent: bool, } -impl<'a> Request<'a> { +impl Request<'_> { fn header(self, k: &HeaderName, v: &str) -> Self { let builder = self.builder.header(k, v); Self { builder, ..self } @@ -236,15 +235,17 @@ impl<'a> Request<'a> { .payload(self.payload) .send() .await - .context(RequestSnafu { - path: self.path.as_ref(), + .map_err(|source| { + let path = self.path.as_ref().into(); + Error::Request { source, path } })?; Ok(resp) } async fn do_put(self) -> Result { let response = self.send().await?; - Ok(get_put_result(response.headers(), VERSION_HEADER).context(MetadataSnafu)?) + Ok(get_put_result(response.headers(), VERSION_HEADER) + .map_err(|source| Error::Metadata { source })?) 
} } @@ -336,17 +337,17 @@ impl GoogleCloudStorageClient { .idempotent(true) .send() .await - .context(SignBlobRequestSnafu)?; + .map_err(|source| Error::SignBlobRequest { source })?; //If successful, the signature is returned in the signedBlob field in the response. let response = response .json::() .await - .context(InvalidSignBlobResponseSnafu)?; + .map_err(|source| Error::InvalidSignBlobResponse { source })?; let signed_blob = BASE64_STANDARD .decode(response.signed_blob) - .context(InvalidSignBlobSignatureSnafu)?; + .map_err(|source| Error::InvalidSignBlobSignature { source })?; Ok(hex_encode(&signed_blob)) } @@ -389,7 +390,7 @@ impl GoogleCloudStorageClient { PutMode::Overwrite => builder.idempotent(true), PutMode::Create => builder.header(&VERSION_MATCH, "0"), PutMode::Update(v) => { - let etag = v.version.as_ref().context(MissingVersionSnafu)?; + let etag = v.version.as_ref().ok_or(Error::MissingVersion)?; builder.header(&VERSION_MATCH, etag) } }; @@ -443,9 +444,14 @@ impl GoogleCloudStorageClient { .send() .await?; - let data = response.bytes().await.context(PutResponseBodySnafu)?; + let data = response + .bytes() + .await + .map_err(|source| Error::PutResponseBody { source })?; + let result: InitiateMultipartUploadResult = - quick_xml::de::from_reader(data.as_ref().reader()).context(InvalidPutResponseSnafu)?; + quick_xml::de::from_reader(data.as_ref().reader()) + .map_err(|source| Error::InvalidPutResponse { source })?; Ok(result.upload_id) } @@ -467,8 +473,9 @@ impl GoogleCloudStorageClient { .query(&[("uploadId", multipart_id)]) .send_retry(&self.config.retry_config) .await - .context(RequestSnafu { - path: path.as_ref(), + .map_err(|source| { + let path = path.as_ref().into(); + Error::Request { source, path } })?; Ok(()) @@ -498,7 +505,7 @@ impl GoogleCloudStorageClient { let credential = self.get_credential().await?; let data = quick_xml::se::to_string(&upload_info) - .context(InvalidPutRequestSnafu)? 
+ .map_err(|source| Error::InvalidPutRequest { source })? // We cannot disable the escaping that transforms "/" to ""e;" :( // https://github.com/tafia/quick-xml/issues/362 // https://github.com/tafia/quick-xml/issues/350 @@ -514,17 +521,18 @@ impl GoogleCloudStorageClient { .idempotent(true) .send() .await - .context(CompleteMultipartRequestSnafu)?; + .map_err(|source| Error::CompleteMultipartRequest { source })?; - let version = get_version(response.headers(), VERSION_HEADER).context(MetadataSnafu)?; + let version = get_version(response.headers(), VERSION_HEADER) + .map_err(|source| Error::Metadata { source })?; let data = response .bytes() .await - .context(CompleteMultipartResponseBodySnafu)?; + .map_err(|source| Error::CompleteMultipartResponseBody { source })?; - let response: CompleteMultipartUploadResult = - quick_xml::de::from_reader(data.reader()).context(InvalidMultipartResponseSnafu)?; + let response: CompleteMultipartUploadResult = quick_xml::de::from_reader(data.reader()) + .map_err(|source| Error::InvalidMultipartResponse { source })?; Ok(PutResult { e_tag: Some(response.e_tag), @@ -615,8 +623,9 @@ impl GetClient for GoogleCloudStorageClient { .with_get_options(options) .send_retry(&self.config.retry_config) .await - .context(GetRequestSnafu { - path: path.as_ref(), + .map_err(|source| { + let path = path.as_ref().into(); + Error::GetRequest { source, path } })?; Ok(response) @@ -665,13 +674,13 @@ impl ListClient for Arc { .bearer_auth(&credential.bearer) .send_retry(&self.config.retry_config) .await - .context(ListRequestSnafu)? + .map_err(|source| Error::ListRequest { source })? 
.bytes() .await - .context(ListResponseBodySnafu)?; + .map_err(|source| Error::ListResponseBody { source })?; - let mut response: ListResponse = - quick_xml::de::from_reader(response.reader()).context(InvalidListResponseSnafu)?; + let mut response: ListResponse = quick_xml::de::from_reader(response.reader()) + .map_err(|source| Error::InvalidListResponse { source })?; let token = response.next_continuation_token.take(); Ok((response.try_into()?, token)) diff --git a/object_store/src/gcp/credential.rs b/object_store/src/gcp/credential.rs index 155a80b343b2..4b21ad1d3eab 100644 --- a/object_store/src/gcp/credential.rs +++ b/object_store/src/gcp/credential.rs @@ -33,7 +33,6 @@ use percent_encoding::utf8_percent_encode; use reqwest::{Client, Method}; use ring::signature::RsaKeyPair; use serde::Deserialize; -use snafu::{ResultExt, Snafu}; use std::collections::BTreeMap; use std::env; use std::fs::File; @@ -54,36 +53,39 @@ const DEFAULT_GCS_SIGN_BLOB_HOST: &str = "storage.googleapis.com"; const DEFAULT_METADATA_HOST: &str = "metadata.google.internal"; const DEFAULT_METADATA_IP: &str = "169.254.169.254"; -#[derive(Debug, Snafu)] +#[derive(Debug, thiserror::Error)] pub enum Error { - #[snafu(display("Unable to open service account file from {}: {}", path.display(), source))] + #[error("Unable to open service account file from {}: {}", path.display(), source)] OpenCredentials { source: std::io::Error, path: PathBuf, }, - #[snafu(display("Unable to decode service account file: {}", source))] + #[error("Unable to decode service account file: {}", source)] DecodeCredentials { source: serde_json::Error }, - #[snafu(display("No RSA key found in pem file"))] + #[error("No RSA key found in pem file")] MissingKey, - #[snafu(display("Invalid RSA key: {}", source), context(false))] - InvalidKey { source: ring::error::KeyRejected }, + #[error("Invalid RSA key: {}", source)] + InvalidKey { + #[from] + source: ring::error::KeyRejected, + }, - #[snafu(display("Error signing: {}", 
source))] + #[error("Error signing: {}", source)] Sign { source: ring::error::Unspecified }, - #[snafu(display("Error encoding jwt payload: {}", source))] + #[error("Error encoding jwt payload: {}", source)] Encode { source: serde_json::Error }, - #[snafu(display("Unsupported key encoding: {}", encoding))] + #[error("Unsupported key encoding: {}", encoding)] UnsupportedKey { encoding: String }, - #[snafu(display("Error performing token request: {}", source))] + #[error("Error performing token request: {}", source)] TokenRequest { source: crate::client::retry::Error }, - #[snafu(display("Error getting token response body: {}", source))] + #[error("Error getting token response body: {}", source)] TokenResponseBody { source: reqwest::Error }, } @@ -153,7 +155,7 @@ impl ServiceAccountKey { string_to_sign.as_bytes(), &mut signature, ) - .context(SignSnafu)?; + .map_err(|source| Error::Sign { source })?; Ok(hex_encode(&signature)) } @@ -289,7 +291,7 @@ impl TokenProvider for SelfSignedJwt { message.as_bytes(), &mut sig_bytes, ) - .context(SignSnafu)?; + .map_err(|source| Error::Sign { source })?; let signature = BASE64_URL_SAFE_NO_PAD.encode(sig_bytes); let bearer = [message, signature].join("."); @@ -305,11 +307,12 @@ fn read_credentials_file(service_account_path: impl AsRef) - where T: serde::de::DeserializeOwned, { - let file = File::open(&service_account_path).context(OpenCredentialsSnafu { - path: service_account_path.as_ref().to_owned(), + let file = File::open(&service_account_path).map_err(|source| { + let path = service_account_path.as_ref().to_owned(); + Error::OpenCredentials { source, path } })?; let reader = BufReader::new(file); - serde_json::from_reader(reader).context(DecodeCredentialsSnafu) + serde_json::from_reader(reader).map_err(|source| Error::DecodeCredentials { source }) } /// A deserialized `service-account-********.json`-file. @@ -341,7 +344,7 @@ impl ServiceAccountCredentials { /// Create a new [`ServiceAccountCredentials`] from a string. 
pub(crate) fn from_key(key: &str) -> Result { - serde_json::from_str(key).context(DecodeCredentialsSnafu) + serde_json::from_str(key).map_err(|source| Error::DecodeCredentials { source }) } /// Create a [`SelfSignedJwt`] from this credentials struct. @@ -380,7 +383,7 @@ fn seconds_since_epoch() -> u64 { } fn b64_encode_obj(obj: &T) -> Result { - let string = serde_json::to_string(obj).context(EncodeSnafu)?; + let string = serde_json::to_string(obj).map_err(|source| Error::Encode { source })?; Ok(BASE64_URL_SAFE_NO_PAD.encode(string)) } @@ -404,10 +407,10 @@ async fn make_metadata_request( .query(&[("audience", "https://www.googleapis.com/oauth2/v4/token")]) .send_retry(retry) .await - .context(TokenRequestSnafu)? + .map_err(|source| Error::TokenRequest { source })? .json() .await - .context(TokenResponseBodySnafu)?; + .map_err(|source| Error::TokenResponseBody { source })?; Ok(response) } @@ -467,10 +470,10 @@ async fn make_metadata_request_for_email( .header("Metadata-Flavor", "Google") .send_retry(retry) .await - .context(TokenRequestSnafu)? + .map_err(|source| Error::TokenRequest { source })? .text() .await - .context(TokenResponseBodySnafu)?; + .map_err(|source| Error::TokenResponseBody { source })?; Ok(response) } @@ -608,10 +611,10 @@ impl AuthorizedUserSigningCredentials { .query(&[("access_token", &self.credential.refresh_token)]) .send_retry(retry) .await - .context(TokenRequestSnafu)? + .map_err(|source| Error::TokenRequest { source })? .json::() .await - .context(TokenResponseBodySnafu)?; + .map_err(|source| Error::TokenResponseBody { source })?; Ok(response.email) } @@ -659,10 +662,10 @@ impl TokenProvider for AuthorizedUserCredentials { .idempotent(true) .send() .await - .context(TokenRequestSnafu)? + .map_err(|source| Error::TokenRequest { source })? 
.json::() .await - .context(TokenResponseBodySnafu)?; + .map_err(|source| Error::TokenResponseBody { source })?; Ok(TemporaryToken { token: Arc::new(GcpCredential { diff --git a/object_store/src/gcp/mod.rs b/object_store/src/gcp/mod.rs index 00823b9c487b..a2f512415a8d 100644 --- a/object_store/src/gcp/mod.rs +++ b/object_store/src/gcp/mod.rs @@ -297,6 +297,7 @@ mod test { // https://github.com/fsouza/fake-gcs-server/issues/852 stream_get(&integration).await; multipart(&integration, &integration).await; + multipart_race_condition(&integration, true).await; // Fake GCS server doesn't currently honor preconditions get_opts(&integration).await; put_opts(&integration, true).await; diff --git a/object_store/src/http/client.rs b/object_store/src/http/client.rs index eeb7e5694228..41e6464c1999 100644 --- a/object_store/src/http/client.rs +++ b/object_store/src/http/client.rs @@ -32,42 +32,41 @@ use hyper::header::{ use percent_encoding::percent_decode_str; use reqwest::{Method, Response, StatusCode}; use serde::Deserialize; -use snafu::{OptionExt, ResultExt, Snafu}; use url::Url; -#[derive(Debug, Snafu)] +#[derive(Debug, thiserror::Error)] enum Error { - #[snafu(display("Request error: {}", source))] + #[error("Request error: {}", source)] Request { source: retry::Error }, - #[snafu(display("Request error: {}", source))] + #[error("Request error: {}", source)] Reqwest { source: reqwest::Error }, - #[snafu(display("Range request not supported by {}", href))] + #[error("Range request not supported by {}", href)] RangeNotSupported { href: String }, - #[snafu(display("Error decoding PROPFIND response: {}", source))] + #[error("Error decoding PROPFIND response: {}", source)] InvalidPropFind { source: quick_xml::de::DeError }, - #[snafu(display("Missing content size for {}", href))] + #[error("Missing content size for {}", href)] MissingSize { href: String }, - #[snafu(display("Error getting properties of \"{}\" got \"{}\"", href, status))] + #[error("Error getting properties of 
\"{}\" got \"{}\"", href, status)] PropStatus { href: String, status: String }, - #[snafu(display("Failed to parse href \"{}\": {}", href, source))] + #[error("Failed to parse href \"{}\": {}", href, source)] InvalidHref { href: String, source: url::ParseError, }, - #[snafu(display("Path \"{}\" contained non-unicode characters: {}", path, source))] + #[error("Path \"{}\" contained non-unicode characters: {}", path, source)] NonUnicode { path: String, source: std::str::Utf8Error, }, - #[snafu(display("Encountered invalid path \"{}\": {}", path, source))] + #[error("Encountered invalid path \"{}\": {}", path, source)] InvalidPath { path: String, source: crate::path::Error, @@ -129,7 +128,7 @@ impl Client { .request(method, url) .send_retry(&self.retry_config) .await - .context(RequestSnafu)?; + .map_err(|source| Error::Request { source })?; Ok(()) } @@ -236,7 +235,10 @@ impl Client { .await; let response = match result { - Ok(result) => result.bytes().await.context(ReqwestSnafu)?, + Ok(result) => result + .bytes() + .await + .map_err(|source| Error::Reqwest { source })?, Err(e) if matches!(e.status(), Some(StatusCode::NOT_FOUND)) => { return match depth { "0" => { @@ -255,7 +257,9 @@ impl Client { Err(source) => return Err(Error::Request { source }.into()), }; - let status = quick_xml::de::from_reader(response.reader()).context(InvalidPropFindSnafu)?; + let status = quick_xml::de::from_reader(response.reader()) + .map_err(|source| Error::InvalidPropFind { source })?; + Ok(status) } @@ -397,14 +401,23 @@ impl MultiStatusResponse { let url = Url::options() .base_url(Some(base_url)) .parse(&self.href) - .context(InvalidHrefSnafu { href: &self.href })?; + .map_err(|source| Error::InvalidHref { + href: self.href.clone(), + source, + })?; // Reverse any percent encoding let path = percent_decode_str(url.path()) .decode_utf8() - .context(NonUnicodeSnafu { path: url.path() })?; + .map_err(|source| Error::NonUnicode { + path: url.path().into(), + source, + })?; - 
Ok(Path::parse(path.as_ref()).context(InvalidPathSnafu { path })?) + Ok(Path::parse(path.as_ref()).map_err(|source| { + let path = path.into(); + Error::InvalidPath { path, source } + })?) } fn size(&self) -> Result { @@ -412,7 +425,10 @@ impl MultiStatusResponse { .prop_stat .prop .content_length - .context(MissingSizeSnafu { href: &self.href })?; + .ok_or_else(|| Error::MissingSize { + href: self.href.clone(), + })?; + Ok(size) } diff --git a/object_store/src/http/mod.rs b/object_store/src/http/mod.rs index 5ea890362751..899740d36db9 100644 --- a/object_store/src/http/mod.rs +++ b/object_store/src/http/mod.rs @@ -37,7 +37,6 @@ use async_trait::async_trait; use futures::stream::BoxStream; use futures::{StreamExt, TryStreamExt}; use itertools::Itertools; -use snafu::{OptionExt, ResultExt, Snafu}; use url::Url; use crate::client::get::GetClientExt; @@ -51,18 +50,18 @@ use crate::{ mod client; -#[derive(Debug, Snafu)] +#[derive(Debug, thiserror::Error)] enum Error { - #[snafu(display("Must specify a URL"))] + #[error("Must specify a URL")] MissingUrl, - #[snafu(display("Unable parse source url. Url: {}, Error: {}", url, source))] + #[error("Unable parse source url. 
Url: {}, Error: {}", url, source)] UnableToParseUrl { source: url::ParseError, url: String, }, - #[snafu(display("Unable to extract metadata from headers: {}", source))] + #[error("Unable to extract metadata from headers: {}", source)] Metadata { source: crate::client::header::Error, }, @@ -238,8 +237,8 @@ impl HttpBuilder { /// Build an [`HttpStore`] with the configured options pub fn build(self) -> Result { - let url = self.url.context(MissingUrlSnafu)?; - let parsed = Url::parse(&url).context(UnableToParseUrlSnafu { url })?; + let url = self.url.ok_or(Error::MissingUrl)?; + let parsed = Url::parse(&url).map_err(|source| Error::UnableToParseUrl { url, source })?; Ok(HttpStore { client: Arc::new(Client::new(parsed, self.client_options, self.retry_config)?), diff --git a/object_store/src/integration.rs b/object_store/src/integration.rs index 30177878306f..20e95fddc478 100644 --- a/object_store/src/integration.rs +++ b/object_store/src/integration.rs @@ -24,6 +24,8 @@ //! //! They are intended solely for testing purposes. 
+use core::str; + use crate::multipart::MultipartStore; use crate::path::Path; use crate::{ @@ -1109,3 +1111,88 @@ async fn delete_fixtures(storage: &DynObjectStore) { .await .unwrap(); } + +/// Tests a race condition where 2 threads are performing multipart writes to the same path +pub async fn multipart_race_condition(storage: &dyn ObjectStore, last_writer_wins: bool) { + let path = Path::from("test_multipart_race_condition"); + + let mut multipart_upload_1 = storage.put_multipart(&path).await.unwrap(); + let mut multipart_upload_2 = storage.put_multipart(&path).await.unwrap(); + + multipart_upload_1 + .put_part(Bytes::from(format!("1:{:05300000},", 0)).into()) + .await + .unwrap(); + multipart_upload_2 + .put_part(Bytes::from(format!("2:{:05300000},", 0)).into()) + .await + .unwrap(); + + multipart_upload_2 + .put_part(Bytes::from(format!("2:{:05300000},", 1)).into()) + .await + .unwrap(); + multipart_upload_1 + .put_part(Bytes::from(format!("1:{:05300000},", 1)).into()) + .await + .unwrap(); + + multipart_upload_1 + .put_part(Bytes::from(format!("1:{:05300000},", 2)).into()) + .await + .unwrap(); + multipart_upload_2 + .put_part(Bytes::from(format!("2:{:05300000},", 2)).into()) + .await + .unwrap(); + + multipart_upload_2 + .put_part(Bytes::from(format!("2:{:05300000},", 3)).into()) + .await + .unwrap(); + multipart_upload_1 + .put_part(Bytes::from(format!("1:{:05300000},", 3)).into()) + .await + .unwrap(); + + multipart_upload_1 + .put_part(Bytes::from(format!("1:{:05300000},", 4)).into()) + .await + .unwrap(); + multipart_upload_2 + .put_part(Bytes::from(format!("2:{:05300000},", 4)).into()) + .await + .unwrap(); + + multipart_upload_1.complete().await.unwrap(); + + if last_writer_wins { + multipart_upload_2.complete().await.unwrap(); + } else { + let err = multipart_upload_2.complete().await.unwrap_err(); + + assert!(matches!(err, crate::Error::Generic { .. 
}), "{err}"); + } + + let get_result = storage.get(&path).await.unwrap(); + let bytes = get_result.bytes().await.unwrap(); + let string_contents = str::from_utf8(&bytes).unwrap(); + + if last_writer_wins { + assert!(string_contents.starts_with( + format!( + "2:{:05300000},2:{:05300000},2:{:05300000},2:{:05300000},2:{:05300000},", + 0, 1, 2, 3, 4 + ) + .as_str() + )); + } else { + assert!(string_contents.starts_with( + format!( + "1:{:05300000},1:{:05300000},1:{:05300000},1:{:05300000},1:{:05300000},", + 0, 1, 2, 3, 4 + ) + .as_str() + )); + } +} diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs index f19cca373be3..53eda5a82fd5 100644 --- a/object_store/src/lib.rs +++ b/object_store/src/lib.rs @@ -66,10 +66,13 @@ //! By default, this crate provides the following implementations: //! //! * Memory: [`InMemory`](memory::InMemory) -//! * Local filesystem: [`LocalFileSystem`](local::LocalFileSystem) //! //! Feature flags are used to enable support for other implementations: //! +#![cfg_attr( + feature = "fs", + doc = "* Local filesystem: [`LocalFileSystem`](local::LocalFileSystem)" +)] #![cfg_attr( feature = "gcp", doc = "* [`gcp`]: [Google Cloud Storage](https://cloud.google.com/storage/) support. 
See [`GoogleCloudStorageBuilder`](gcp::GoogleCloudStorageBuilder)" @@ -513,7 +516,7 @@ pub mod gcp; #[cfg(feature = "http")] pub mod http; pub mod limit; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "fs", not(target_arch = "wasm32")))] pub mod local; pub mod memory; pub mod path; @@ -557,15 +560,14 @@ pub use upload::*; pub use util::{coalesce_ranges, collect_bytes, GetRange, OBJECT_STORE_COALESCE_DEFAULT}; use crate::path::Path; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "fs", not(target_arch = "wasm32")))] use crate::util::maybe_spawn_blocking; use async_trait::async_trait; use bytes::Bytes; use chrono::{DateTime, Utc}; use futures::{stream::BoxStream, StreamExt, TryStreamExt}; -use snafu::Snafu; use std::fmt::{Debug, Formatter}; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "fs", not(target_arch = "wasm32")))] use std::io::{Read, Seek, SeekFrom}; use std::ops::Range; use std::sync::Arc; @@ -1028,6 +1030,7 @@ pub struct GetResult { /// be able to optimise the case of a file already present on local disk pub enum GetResultPayload { /// The file, path + #[cfg(all(feature = "fs", not(target_arch = "wasm32")))] File(std::fs::File, std::path::PathBuf), /// An opaque stream of bytes Stream(BoxStream<'static, Result>), @@ -1036,6 +1039,7 @@ pub enum GetResultPayload { impl Debug for GetResultPayload { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { + #[cfg(all(feature = "fs", not(target_arch = "wasm32")))] Self::File(_, _) => write!(f, "GetResultPayload(File)"), Self::Stream(_) => write!(f, "GetResultPayload(Stream)"), } @@ -1047,7 +1051,7 @@ impl GetResult { pub async fn bytes(self) -> Result { let len = self.range.end - self.range.start; match self.payload { - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "fs", not(target_arch = "wasm32")))] GetResultPayload::File(mut file, path) => { maybe_spawn_blocking(move || { file.seek(SeekFrom::Start(self.range.start as _)) @@ -1087,7 +1091,7 @@ impl 
GetResult { /// no additional complexity or overheads pub fn into_stream(self) -> BoxStream<'static, Result> { match self.payload { - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "fs", not(target_arch = "wasm32")))] GetResultPayload::File(file, path) => { const CHUNK_SIZE: usize = 8 * 1024; local::chunked_stream(file, path, self.range, CHUNK_SIZE) @@ -1224,11 +1228,11 @@ pub struct PutResult { pub type Result = std::result::Result; /// A specialized `Error` for object store-related errors -#[derive(Debug, Snafu)] +#[derive(Debug, thiserror::Error)] #[non_exhaustive] pub enum Error { /// A fallback error type when no variant matches - #[snafu(display("Generic {} error: {}", store, source))] + #[error("Generic {} error: {}", store, source)] Generic { /// The store this error originated from store: &'static str, @@ -1237,7 +1241,7 @@ pub enum Error { }, /// Error when the object is not found at given location - #[snafu(display("Object at location {} not found: {}", path, source))] + #[error("Object at location {} not found: {}", path, source)] NotFound { /// The path to file path: String, @@ -1246,31 +1250,30 @@ pub enum Error { }, /// Error for invalid path - #[snafu( - display("Encountered object with invalid path: {}", source), - context(false) - )] + #[error("Encountered object with invalid path: {}", source)] InvalidPath { /// The wrapped error + #[from] source: path::Error, }, /// Error when `tokio::spawn` failed - #[snafu(display("Error joining spawned task: {}", source), context(false))] + #[error("Error joining spawned task: {}", source)] JoinError { /// The wrapped error + #[from] source: tokio::task::JoinError, }, /// Error when the attempted operation is not supported - #[snafu(display("Operation not supported: {}", source))] + #[error("Operation not supported: {}", source)] NotSupported { /// The wrapped error source: Box, }, /// Error when the object already exists - #[snafu(display("Object at location {} already exists: {}", path, source))] 
+ #[error("Object at location {} already exists: {}", path, source)] AlreadyExists { /// The path to the path: String, @@ -1279,7 +1282,7 @@ pub enum Error { }, /// Error when the required conditions failed for the operation - #[snafu(display("Request precondition failure for path {}: {}", path, source))] + #[error("Request precondition failure for path {}: {}", path, source)] Precondition { /// The path to the file path: String, @@ -1288,7 +1291,7 @@ pub enum Error { }, /// Error when the object at the location isn't modified - #[snafu(display("Object at location {} not modified: {}", path, source))] + #[error("Object at location {} not modified: {}", path, source)] NotModified { /// The path to the file path: String, @@ -1297,16 +1300,16 @@ pub enum Error { }, /// Error when an operation is not implemented - #[snafu(display("Operation not yet implemented."))] + #[error("Operation not yet implemented.")] NotImplemented, /// Error when the used credentials don't have enough permission /// to perform the requested operation - #[snafu(display( + #[error( "The operation lacked the necessary privileges to complete for path {}: {}", path, source - ))] + )] PermissionDenied { /// The path to the file path: String, @@ -1315,11 +1318,11 @@ pub enum Error { }, /// Error when the used credentials lack valid authentication - #[snafu(display( + #[error( "The operation lacked valid authentication credentials for path {}: {}", path, source - ))] + )] Unauthenticated { /// The path to the file path: String, @@ -1328,7 +1331,7 @@ pub enum Error { }, /// Error when a configuration key is invalid for the store used - #[snafu(display("Configuration key: '{}' is not valid for store '{}'.", key, store))] + #[error("Configuration key: '{}' is not valid for store '{}'.", key, store)] UnknownConfigurationKey { /// The object store used store: &'static str, diff --git a/object_store/src/limit.rs b/object_store/src/limit.rs index 4fae271bcb71..77f72a0e11a1 100644 --- 
a/object_store/src/limit.rs +++ b/object_store/src/limit.rs @@ -201,6 +201,7 @@ impl ObjectStore for LimitStore { fn permit_get_result(r: GetResult, permit: OwnedSemaphorePermit) -> GetResult { let payload = match r.payload { + #[cfg(all(feature = "fs", not(target_arch = "wasm32")))] v @ GetResultPayload::File(_, _) => v, GetResultPayload::Stream(s) => { GetResultPayload::Stream(PermitWrapper::new(s, permit).boxed()) diff --git a/object_store/src/local.rs b/object_store/src/local.rs index cb3da28ddeef..364026459a03 100644 --- a/object_store/src/local.rs +++ b/object_store/src/local.rs @@ -30,7 +30,6 @@ use chrono::{DateTime, Utc}; use futures::{stream::BoxStream, StreamExt}; use futures::{FutureExt, TryStreamExt}; use parking_lot::Mutex; -use snafu::{ensure, OptionExt, ResultExt, Snafu}; use url::Url; use walkdir::{DirEntry, WalkDir}; @@ -43,117 +42,80 @@ use crate::{ }; /// A specialized `Error` for filesystem object store-related errors -#[derive(Debug, Snafu)] +#[derive(Debug, thiserror::Error)] pub(crate) enum Error { - #[snafu(display("File size for {} did not fit in a usize: {}", path, source))] + #[error("File size for {} did not fit in a usize: {}", path, source)] FileSizeOverflowedUsize { source: std::num::TryFromIntError, path: String, }, - #[snafu(display("Unable to walk dir: {}", source))] - UnableToWalkDir { - source: walkdir::Error, - }, + #[error("Unable to walk dir: {}", source)] + UnableToWalkDir { source: walkdir::Error }, - #[snafu(display("Unable to access metadata for {}: {}", path, source))] + #[error("Unable to access metadata for {}: {}", path, source)] Metadata { source: Box, path: String, }, - #[snafu(display("Unable to copy data to file: {}", source))] - UnableToCopyDataToFile { - source: io::Error, - }, + #[error("Unable to copy data to file: {}", source)] + UnableToCopyDataToFile { source: io::Error }, - #[snafu(display("Unable to rename file: {}", source))] - UnableToRenameFile { - source: io::Error, - }, + #[error("Unable to rename 
file: {}", source)] + UnableToRenameFile { source: io::Error }, - #[snafu(display("Unable to create dir {}: {}", path.display(), source))] - UnableToCreateDir { - source: io::Error, - path: PathBuf, - }, + #[error("Unable to create dir {}: {}", path.display(), source)] + UnableToCreateDir { source: io::Error, path: PathBuf }, - #[snafu(display("Unable to create file {}: {}", path.display(), source))] - UnableToCreateFile { - source: io::Error, - path: PathBuf, - }, + #[error("Unable to create file {}: {}", path.display(), source)] + UnableToCreateFile { source: io::Error, path: PathBuf }, - #[snafu(display("Unable to delete file {}: {}", path.display(), source))] - UnableToDeleteFile { - source: io::Error, - path: PathBuf, - }, + #[error("Unable to delete file {}: {}", path.display(), source)] + UnableToDeleteFile { source: io::Error, path: PathBuf }, - #[snafu(display("Unable to open file {}: {}", path.display(), source))] - UnableToOpenFile { - source: io::Error, - path: PathBuf, - }, + #[error("Unable to open file {}: {}", path.display(), source)] + UnableToOpenFile { source: io::Error, path: PathBuf }, - #[snafu(display("Unable to read data from file {}: {}", path.display(), source))] - UnableToReadBytes { - source: io::Error, - path: PathBuf, - }, + #[error("Unable to read data from file {}: {}", path.display(), source)] + UnableToReadBytes { source: io::Error, path: PathBuf }, - #[snafu(display("Out of range of file {}, expected: {}, actual: {}", path.display(), expected, actual))] + #[error("Out of range of file {}, expected: {}, actual: {}", path.display(), expected, actual)] OutOfRange { path: PathBuf, expected: usize, actual: usize, }, - #[snafu(display("Requested range was invalid"))] - InvalidRange { - source: InvalidGetRange, - }, + #[error("Requested range was invalid")] + InvalidRange { source: InvalidGetRange }, - #[snafu(display("Unable to copy file from {} to {}: {}", from.display(), to.display(), source))] + #[error("Unable to copy file from {} 
to {}: {}", from.display(), to.display(), source)] UnableToCopyFile { from: PathBuf, to: PathBuf, source: io::Error, }, - NotFound { - path: PathBuf, - source: io::Error, - }, + #[error("NotFound")] + NotFound { path: PathBuf, source: io::Error }, - #[snafu(display("Error seeking file {}: {}", path.display(), source))] - Seek { - source: io::Error, - path: PathBuf, - }, + #[error("Error seeking file {}: {}", path.display(), source)] + Seek { source: io::Error, path: PathBuf }, - #[snafu(display("Unable to convert URL \"{}\" to filesystem path", url))] - InvalidUrl { - url: Url, - }, + #[error("Unable to convert URL \"{}\" to filesystem path", url)] + InvalidUrl { url: Url }, - AlreadyExists { - path: String, - source: io::Error, - }, + #[error("AlreadyExists")] + AlreadyExists { path: String, source: io::Error }, - #[snafu(display("Unable to canonicalize filesystem root: {}", path.display()))] - UnableToCanonicalize { - path: PathBuf, - source: io::Error, - }, + #[error("Unable to canonicalize filesystem root: {}", path.display())] + UnableToCanonicalize { path: PathBuf, source: io::Error }, - #[snafu(display("Filenames containing trailing '/#\\d+/' are not supported: {}", path))] - InvalidPath { - path: String, - }, + #[error("Filenames containing trailing '/#\\d+/' are not supported: {}", path)] + InvalidPath { path: String }, - #[snafu(display("Upload aborted"))] + #[error("Upload aborted")] Aborted, } @@ -276,8 +238,9 @@ impl LocalFileSystem { /// Returns an error if the path does not exist /// pub fn new_with_prefix(prefix: impl AsRef) -> Result { - let path = std::fs::canonicalize(&prefix).context(UnableToCanonicalizeSnafu { - path: prefix.as_ref(), + let path = std::fs::canonicalize(&prefix).map_err(|source| { + let path = prefix.as_ref().into(); + Error::UnableToCanonicalize { source, path } })?; Ok(Self { @@ -290,12 +253,12 @@ impl LocalFileSystem { /// Return an absolute filesystem path of the given file location pub fn path_to_filesystem(&self, location: 
&Path) -> Result { - ensure!( - is_valid_file_path(location), - InvalidPathSnafu { - path: location.as_ref() - } - ); + if !is_valid_file_path(location) { + let path = location.as_ref().into(); + let error = Error::InvalidPath { path }; + return Err(error.into()); + } + let path = self.config.prefix_to_filesystem(location)?; #[cfg(target_os = "windows")] @@ -451,7 +414,9 @@ impl ObjectStore for LocalFileSystem { options.check_preconditions(&meta)?; let range = match options.range { - Some(r) => r.as_range(meta.size).context(InvalidRangeSnafu)?, + Some(r) => r + .as_range(meta.size) + .map_err(|source| Error::InvalidRange { source })?, None => 0..meta.size, }; @@ -721,12 +686,15 @@ impl ObjectStore for LocalFileSystem { /// Creates the parent directories of `path` or returns an error based on `source` if no parent fn create_parent_dirs(path: &std::path::Path, source: io::Error) -> Result<()> { - let parent = path.parent().ok_or_else(|| Error::UnableToCreateFile { - path: path.to_path_buf(), - source, + let parent = path.parent().ok_or_else(|| { + let path = path.to_path_buf(); + Error::UnableToCreateFile { path, source } })?; - std::fs::create_dir_all(parent).context(UnableToCreateDirSnafu { path: parent })?; + std::fs::create_dir_all(parent).map_err(|source| { + let path = parent.into(); + Error::UnableToCreateDir { source, path } + })?; Ok(()) } @@ -796,12 +764,14 @@ impl MultipartUpload for LocalUpload { let s = Arc::clone(&self.state); maybe_spawn_blocking(move || { let mut file = s.file.lock(); - file.seek(SeekFrom::Start(offset)) - .context(SeekSnafu { path: &s.dest })?; + file.seek(SeekFrom::Start(offset)).map_err(|source| { + let path = s.dest.clone(); + Error::Seek { source, path } + })?; data.iter() .try_for_each(|x| file.write_all(x)) - .context(UnableToCopyDataToFileSnafu)?; + .map_err(|source| Error::UnableToCopyDataToFile { source })?; Ok(()) }) @@ -809,12 +779,13 @@ impl MultipartUpload for LocalUpload { } async fn complete(&mut self) -> Result { - 
let src = self.src.take().context(AbortedSnafu)?; + let src = self.src.take().ok_or(Error::Aborted)?; let s = Arc::clone(&self.state); maybe_spawn_blocking(move || { // Ensure no inflight writes let file = s.file.lock(); - std::fs::rename(&src, &s.dest).context(UnableToRenameFileSnafu)?; + std::fs::rename(&src, &s.dest) + .map_err(|source| Error::UnableToRenameFile { source })?; let metadata = file.metadata().map_err(|e| Error::Metadata { source: e.into(), path: src.to_string_lossy().to_string(), @@ -829,9 +800,10 @@ impl MultipartUpload for LocalUpload { } async fn abort(&mut self) -> Result<()> { - let src = self.src.take().context(AbortedSnafu)?; + let src = self.src.take().ok_or(Error::Aborted)?; maybe_spawn_blocking(move || { - std::fs::remove_file(&src).context(UnableToDeleteFileSnafu { path: &src })?; + std::fs::remove_file(&src) + .map_err(|source| Error::UnableToDeleteFile { source, path: src })?; Ok(()) }) .await @@ -898,22 +870,30 @@ pub(crate) fn chunked_stream( pub(crate) fn read_range(file: &mut File, path: &PathBuf, range: Range) -> Result { let to_read = range.end - range.start; file.seek(SeekFrom::Start(range.start as u64)) - .context(SeekSnafu { path })?; + .map_err(|source| { + let path = path.into(); + Error::Seek { source, path } + })?; let mut buf = Vec::with_capacity(to_read); let read = file .take(to_read as u64) .read_to_end(&mut buf) - .context(UnableToReadBytesSnafu { path })?; + .map_err(|source| { + let path = path.into(); + Error::UnableToReadBytes { source, path } + })?; - ensure!( - read == to_read, - OutOfRangeSnafu { - path, + if read != to_read { + let error = Error::OutOfRange { + path: path.into(), expected: to_read, - actual: read - } - ); + actual: read, + }; + + return Err(error.into()); + } + Ok(buf.into()) } @@ -982,8 +962,9 @@ fn get_etag(metadata: &Metadata) -> String { fn convert_metadata(metadata: Metadata, location: Path) -> Result { let last_modified = last_modified(&metadata); - let size = 
usize::try_from(metadata.len()).context(FileSizeOverflowedUsizeSnafu { - path: location.as_ref(), + let size = usize::try_from(metadata.len()).map_err(|source| { + let path = location.as_ref().into(); + Error::FileSizeOverflowedUsize { source, path } })?; Ok(ObjectMeta { @@ -1004,7 +985,7 @@ fn get_inode(metadata: &Metadata) -> u64 { #[cfg(not(unix))] /// On platforms where an inode isn't available, fallback to just relying on size and mtime -fn get_inode(metadata: &Metadata) -> u64 { +fn get_inode(_metadata: &Metadata) -> u64 { 0 } @@ -1060,7 +1041,10 @@ mod tests { use std::fs; use futures::TryStreamExt; - use tempfile::{NamedTempFile, TempDir}; + use tempfile::TempDir; + + #[cfg(target_family = "unix")] + use tempfile::NamedTempFile; use crate::integration::*; @@ -1248,6 +1232,7 @@ mod tests { fs.list_with_delimiter(None).await.unwrap(); } + #[cfg(target_family = "unix")] async fn check_list(integration: &LocalFileSystem, prefix: Option<&Path>, expected: &[&str]) { let result: Vec<_> = integration.list(prefix).try_collect().await.unwrap(); diff --git a/object_store/src/memory.rs b/object_store/src/memory.rs index f6ec57fad747..6402f924346f 100644 --- a/object_store/src/memory.rs +++ b/object_store/src/memory.rs @@ -25,7 +25,6 @@ use bytes::Bytes; use chrono::{DateTime, Utc}; use futures::{stream::BoxStream, StreamExt}; use parking_lot::RwLock; -use snafu::{OptionExt, ResultExt, Snafu}; use crate::multipart::{MultipartStore, PartId}; use crate::util::InvalidGetRange; @@ -37,24 +36,24 @@ use crate::{ use crate::{GetOptions, PutPayload}; /// A specialized `Error` for in-memory object store-related errors -#[derive(Debug, Snafu)] +#[derive(Debug, thiserror::Error)] enum Error { - #[snafu(display("No data in memory found. Location: {path}"))] + #[error("No data in memory found. 
Location: {path}")] NoDataInMemory { path: String }, - #[snafu(display("Invalid range: {source}"))] + #[error("Invalid range: {source}")] Range { source: InvalidGetRange }, - #[snafu(display("Object already exists at that location: {path}"))] + #[error("Object already exists at that location: {path}")] AlreadyExists { path: String }, - #[snafu(display("ETag required for conditional update"))] + #[error("ETag required for conditional update")] MissingETag, - #[snafu(display("MultipartUpload not found: {id}"))] + #[error("MultipartUpload not found: {id}")] UploadNotFound { id: String }, - #[snafu(display("Missing part at index: {part}"))] + #[error("Missing part at index: {part}")] MissingPart { part: usize }, } @@ -158,7 +157,7 @@ impl Storage { }), Some(e) => { let existing = e.e_tag.to_string(); - let expected = v.e_tag.context(MissingETagSnafu)?; + let expected = v.e_tag.ok_or(Error::MissingETag)?; if existing == expected { *e = entry; Ok(()) @@ -177,7 +176,7 @@ impl Storage { .parse() .ok() .and_then(|x| self.uploads.get_mut(&x)) - .context(UploadNotFoundSnafu { id })?; + .ok_or_else(|| Error::UploadNotFound { id: id.into() })?; Ok(parts) } @@ -186,7 +185,7 @@ impl Storage { .parse() .ok() .and_then(|x| self.uploads.remove(&x)) - .context(UploadNotFoundSnafu { id })?; + .ok_or_else(|| Error::UploadNotFound { id: id.into() })?; Ok(parts) } } @@ -250,7 +249,9 @@ impl ObjectStore for InMemory { let (range, data) = match options.range { Some(range) => { - let r = range.as_range(entry.data.len()).context(RangeSnafu)?; + let r = range + .as_range(entry.data.len()) + .map_err(|source| Error::Range { source })?; (r.clone(), entry.data.slice(r)) } None => (0..entry.data.len(), entry.data), @@ -272,7 +273,7 @@ impl ObjectStore for InMemory { .map(|range| { let r = GetRange::Bounded(range.clone()) .as_range(entry.data.len()) - .context(RangeSnafu)?; + .map_err(|source| Error::Range { source })?; Ok(entry.data.slice(r)) }) @@ -435,7 +436,7 @@ impl MultipartStore for 
InMemory { let mut cap = 0; for (part, x) in upload.parts.iter().enumerate() { - cap += x.as_ref().context(MissingPartSnafu { part })?.len(); + cap += x.as_ref().ok_or(Error::MissingPart { part })?.len(); } let mut buf = Vec::with_capacity(cap); for x in &upload.parts { @@ -468,19 +469,13 @@ impl InMemory { Self { storage } } - /// Creates a clone of the store - #[deprecated(note = "Use fork() instead")] - pub async fn clone(&self) -> Self { - self.fork() - } - async fn entry(&self, location: &Path) -> Result { let storage = self.storage.read(); let value = storage .map .get(location) .cloned() - .context(NoDataInMemorySnafu { + .ok_or_else(|| Error::NoDataInMemory { path: location.to_string(), })?; diff --git a/object_store/src/parse.rs b/object_store/src/parse.rs index debc9e529312..bc65a0b8d1c8 100644 --- a/object_store/src/parse.rs +++ b/object_store/src/parse.rs @@ -15,21 +15,23 @@ // specific language governing permissions and limitations // under the License. -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "fs", not(target_arch = "wasm32")))] use crate::local::LocalFileSystem; use crate::memory::InMemory; use crate::path::Path; use crate::ObjectStore; -use snafu::Snafu; use url::Url; -#[derive(Debug, Snafu)] +#[derive(Debug, thiserror::Error)] pub enum Error { - #[snafu(display("Unable to recognise URL \"{}\"", url))] + #[error("Unable to recognise URL \"{}\"", url)] Unrecognised { url: Url }, - #[snafu(context(false))] - Path { source: crate::path::Error }, + #[error(transparent)] + Path { + #[from] + source: crate::path::Error, + }, } impl From for super::Error { @@ -179,7 +181,7 @@ where let path = Path::parse(path)?; let store = match scheme { - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "fs", not(target_arch = "wasm32")))] ObjectStoreScheme::Local => Box::new(LocalFileSystem::new()) as _, ObjectStoreScheme::Memory => Box::new(InMemory::new()) as _, #[cfg(feature = "aws")] diff --git a/object_store/src/path/mod.rs 
b/object_store/src/path/mod.rs index 4c9bb5f05186..f8affe8dfbb9 100644 --- a/object_store/src/path/mod.rs +++ b/object_store/src/path/mod.rs @@ -19,7 +19,6 @@ use itertools::Itertools; use percent_encoding::percent_decode; -use snafu::{ensure, ResultExt, Snafu}; use std::fmt::Formatter; #[cfg(not(target_arch = "wasm32"))] use url::Url; @@ -35,18 +34,18 @@ mod parts; pub use parts::{InvalidPart, PathPart}; /// Error returned by [`Path::parse`] -#[derive(Debug, Snafu)] +#[derive(Debug, thiserror::Error)] #[non_exhaustive] pub enum Error { /// Error when there's an empty segment between two slashes `/` in the path - #[snafu(display("Path \"{}\" contained empty path segment", path))] + #[error("Path \"{}\" contained empty path segment", path)] EmptySegment { /// The source path path: String, }, /// Error when an invalid segment is encountered in the given path - #[snafu(display("Error parsing Path \"{}\": {}", path, source))] + #[error("Error parsing Path \"{}\": {}", path, source)] BadSegment { /// The source path path: String, @@ -55,7 +54,7 @@ pub enum Error { }, /// Error when path cannot be canonicalized - #[snafu(display("Failed to canonicalize path \"{}\": {}", path.display(), source))] + #[error("Failed to canonicalize path \"{}\": {}", path.display(), source)] Canonicalize { /// The source path path: std::path::PathBuf, @@ -64,14 +63,14 @@ pub enum Error { }, /// Error when the path is not a valid URL - #[snafu(display("Unable to convert path \"{}\" to URL", path.display()))] + #[error("Unable to convert path \"{}\" to URL", path.display())] InvalidPath { /// The source path path: std::path::PathBuf, }, /// Error when a path contains non-unicode characters - #[snafu(display("Path \"{}\" contained non-unicode characters: {}", path, source))] + #[error("Path \"{}\" contained non-unicode characters: {}", path, source)] NonUnicode { /// The source path path: String, @@ -80,7 +79,7 @@ pub enum Error { }, /// Error when the a path doesn't start with given prefix - 
#[snafu(display("Path {} does not start with prefix {}", path, prefix))] + #[error("Path {} does not start with prefix {}", path, prefix)] PrefixMismatch { /// The source path path: String, @@ -173,8 +172,14 @@ impl Path { let stripped = stripped.strip_suffix(DELIMITER).unwrap_or(stripped); for segment in stripped.split(DELIMITER) { - ensure!(!segment.is_empty(), EmptySegmentSnafu { path }); - PathPart::parse(segment).context(BadSegmentSnafu { path })?; + if segment.is_empty() { + return Err(Error::EmptySegment { path: path.into() }); + } + + PathPart::parse(segment).map_err(|source| { + let path = path.into(); + Error::BadSegment { source, path } + })?; } Ok(Self { @@ -190,8 +195,9 @@ impl Path { /// /// Note: this will canonicalize the provided path, resolving any symlinks pub fn from_filesystem_path(path: impl AsRef) -> Result { - let absolute = std::fs::canonicalize(&path).context(CanonicalizeSnafu { - path: path.as_ref(), + let absolute = std::fs::canonicalize(&path).map_err(|source| { + let path = path.as_ref().into(); + Error::Canonicalize { source, path } })?; Self::from_absolute_path(absolute) @@ -241,7 +247,10 @@ impl Path { let path = path.as_ref(); let decoded = percent_decode(path.as_bytes()) .decode_utf8() - .context(NonUnicodeSnafu { path })?; + .map_err(|source| { + let path = path.into(); + Error::NonUnicode { source, path } + })?; Self::parse(decoded) } diff --git a/object_store/src/path/parts.rs b/object_store/src/path/parts.rs index df7097cbe9db..9c6612bf9331 100644 --- a/object_store/src/path/parts.rs +++ b/object_store/src/path/parts.rs @@ -19,15 +19,14 @@ use percent_encoding::{percent_encode, AsciiSet, CONTROLS}; use std::borrow::Cow; use crate::path::DELIMITER_BYTE; -use snafu::Snafu; /// Error returned by [`PathPart::parse`] -#[derive(Debug, Snafu)] -#[snafu(display( +#[derive(Debug, thiserror::Error)] +#[error( "Encountered illegal character sequence \"{}\" whilst parsing path segment \"{}\"", illegal, segment -))] +)] 
#[allow(missing_copy_implementations)] pub struct InvalidPart { segment: String, @@ -126,7 +125,7 @@ impl From for PathPart<'static> { } } -impl<'a> AsRef for PathPart<'a> { +impl AsRef for PathPart<'_> { fn as_ref(&self) -> &str { self.raw.as_ref() } diff --git a/object_store/src/prefix.rs b/object_store/src/prefix.rs index 3c2dd937acbd..a0b67ca4b58e 100644 --- a/object_store/src/prefix.rs +++ b/object_store/src/prefix.rs @@ -26,10 +26,6 @@ use crate::{ PutOptions, PutPayload, PutResult, Result, }; -#[doc(hidden)] -#[deprecated(note = "Use PrefixStore")] -pub type PrefixObjectStore = PrefixStore; - /// Store wrapper that applies a constant prefix to all paths handled by the store. #[derive(Debug, Clone)] pub struct PrefixStore { diff --git a/object_store/src/throttle.rs b/object_store/src/throttle.rs index 58f7ced5312e..29cd32705ccc 100644 --- a/object_store/src/throttle.rs +++ b/object_store/src/throttle.rs @@ -311,8 +311,10 @@ fn usize_to_u32_saturate(x: usize) -> u32 { } fn throttle_get(result: GetResult, wait_get_per_byte: Duration) -> GetResult { + #[allow(clippy::infallible_destructuring_match)] let s = match result.payload { GetResultPayload::Stream(s) => s, + #[cfg(all(feature = "fs", not(target_arch = "wasm32")))] GetResultPayload::File(_, _) => unimplemented!(), }; diff --git a/object_store/src/util.rs b/object_store/src/util.rs index ecf90f95d7c7..6d638f3cb2b8 100644 --- a/object_store/src/util.rs +++ b/object_store/src/util.rs @@ -24,7 +24,6 @@ use std::{ use super::Result; use bytes::Bytes; use futures::{stream::StreamExt, Stream, TryStreamExt}; -use snafu::Snafu; #[cfg(any(feature = "azure", feature = "http"))] pub(crate) static RFC1123_FMT: &str = "%a, %d %h %Y %T GMT"; @@ -75,7 +74,7 @@ where } } -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "fs", not(target_arch = "wasm32")))] /// Takes a function and spawns it to a tokio blocking pool if available pub(crate) async fn maybe_spawn_blocking(f: F) -> Result where @@ -204,14 +203,12 @@ pub 
enum GetRange { Suffix(usize), } -#[derive(Debug, Snafu)] +#[derive(Debug, thiserror::Error)] pub(crate) enum InvalidGetRange { - #[snafu(display( - "Wanted range starting at {requested}, but object was only {length} bytes long" - ))] + #[error("Wanted range starting at {requested}, but object was only {length} bytes long")] StartTooLarge { requested: usize, length: usize }, - #[snafu(display("Range started at {start} and ended at {end}"))] + #[error("Range started at {start} and ended at {end}")] Inconsistent { start: usize, end: usize }, } diff --git a/object_store/tests/http.rs b/object_store/tests/http.rs new file mode 100644 index 000000000000..a9b3145bb660 --- /dev/null +++ b/object_store/tests/http.rs @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests the HTTP store implementation + +#[cfg(feature = "http")] +use object_store::{http::HttpBuilder, path::Path, GetOptions, GetRange, ObjectStore}; + +/// Tests that even when reqwest has the `gzip` feature enabled, the HTTP store +/// does not error on a missing `Content-Length` header. 
+#[tokio::test] +#[cfg(feature = "http")] +async fn test_http_store_gzip() { + let http_store = HttpBuilder::new() + .with_url("https://raw.githubusercontent.com/apache/arrow-rs/refs/heads/main") + .build() + .unwrap(); + + let _ = http_store + .get_opts( + &Path::parse("LICENSE.txt").unwrap(), + GetOptions { + range: Some(GetRange::Bounded(0..100)), + ..Default::default() + }, + ) + .await + .unwrap(); +} diff --git a/parquet-testing b/parquet-testing index 550368ca77b9..4439a223a315 160000 --- a/parquet-testing +++ b/parquet-testing @@ -1 +1 @@ -Subproject commit 550368ca77b97231efead39251a96bd6f8f08c6e +Subproject commit 4439a223a315cf874746d3b5da25e6a6b2a2b16e diff --git a/parquet/Cargo.toml b/parquet/Cargo.toml index 4064baba0947..19f890710778 100644 --- a/parquet/Cargo.toml +++ b/parquet/Cargo.toml @@ -67,7 +67,7 @@ hashbrown = { version = "0.15", default-features = false } twox-hash = { version = "1.6", default-features = false } paste = { version = "1.0" } half = { version = "2.1", default-features = false, features = ["num-traits"] } -sysinfo = { version = "0.32.0", optional = true, default-features = false, features = ["system"] } +sysinfo = { version = "0.33.0", optional = true, default-features = false, features = ["system"] } crc32fast = { version = "1.4.2", optional = true, default-features = false } [dev-dependencies] diff --git a/parquet/LICENSE.txt b/parquet/LICENSE.txt new file mode 120000 index 000000000000..4ab43736a839 --- /dev/null +++ b/parquet/LICENSE.txt @@ -0,0 +1 @@ +../LICENSE.txt \ No newline at end of file diff --git a/parquet/NOTICE.txt b/parquet/NOTICE.txt new file mode 120000 index 000000000000..eb9f24e040b5 --- /dev/null +++ b/parquet/NOTICE.txt @@ -0,0 +1 @@ +../NOTICE.txt \ No newline at end of file diff --git a/parquet/README.md b/parquet/README.md index a0441ee6026d..9ff1d921d692 100644 --- a/parquet/README.md +++ b/parquet/README.md @@ -36,7 +36,7 @@ This crate is tested with the latest stable version of Rust. 
We do not currently The `parquet` crate follows the [SemVer standard] defined by Cargo and works well within the Rust crate ecosystem. See the [repository README] for more details on -the release schedule and version. +the release schedule, version and deprecation policy. [semver standard]: https://doc.rust-lang.org/cargo/reference/semver.html [repository readme]: https://github.com/apache/arrow-rs @@ -59,7 +59,7 @@ The `parquet` crate provides the following features which may be enabled in your - `lz4` (default) - support for parquet using `lz4` compression - `zstd` (default) - support for parquet using `zstd` compression - `snap` (default) - support for parquet using `snappy` compression -- `cli` - parquet [CLI tools](https://github.com/apache/arrow-rs/tree/master/parquet/src/bin) +- `cli` - parquet [CLI tools](https://github.com/apache/arrow-rs/tree/main/parquet/src/bin) - `crc` - enables functionality to automatically verify checksums of each page (if present) when decoding - `experimental` - Experimental APIs which may change, even between minor releases diff --git a/parquet/benches/arrow_reader.rs b/parquet/benches/arrow_reader.rs index c424d000694a..e5165fee212c 100644 --- a/parquet/benches/arrow_reader.rs +++ b/parquet/benches/arrow_reader.rs @@ -680,7 +680,7 @@ fn create_string_list_reader( column_desc: ColumnDescPtr, ) -> Box { let items = create_byte_array_reader(page_iterator, column_desc); - let field = Field::new("item", DataType::Utf8, true); + let field = Field::new_list_field(DataType::Utf8, true); let data_type = DataType::List(Arc::new(field)); Box::new(ListArrayReader::::new(items, data_type, 2, 1, true)) } diff --git a/parquet/benches/arrow_writer.rs b/parquet/benches/arrow_writer.rs index cf39ee66f31a..bfa333db722c 100644 --- a/parquet/benches/arrow_writer.rs +++ b/parquet/benches/arrow_writer.rs @@ -189,17 +189,17 @@ fn create_list_primitive_bench_batch( let fields = vec![ Field::new( "_1", - DataType::List(Arc::new(Field::new("item", 
DataType::Int32, true))), + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))), true, ), Field::new( "_2", - DataType::List(Arc::new(Field::new("item", DataType::Boolean, true))), + DataType::List(Arc::new(Field::new_list_field(DataType::Boolean, true))), true, ), Field::new( "_3", - DataType::LargeList(Arc::new(Field::new("item", DataType::Utf8, true))), + DataType::LargeList(Arc::new(Field::new_list_field(DataType::Utf8, true))), true, ), ]; @@ -220,17 +220,17 @@ fn create_list_primitive_bench_batch_non_null( let fields = vec![ Field::new( "_1", - DataType::List(Arc::new(Field::new("item", DataType::Int32, false))), + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))), false, ), Field::new( "_2", - DataType::List(Arc::new(Field::new("item", DataType::Boolean, false))), + DataType::List(Arc::new(Field::new_list_field(DataType::Boolean, false))), false, ), Field::new( "_3", - DataType::LargeList(Arc::new(Field::new("item", DataType::Utf8, false))), + DataType::LargeList(Arc::new(Field::new_list_field(DataType::Utf8, false))), false, ), ]; @@ -274,10 +274,8 @@ fn _create_nested_bench_batch( ), Field::new( "_2", - DataType::LargeList(Arc::new(Field::new( - "item", - DataType::List(Arc::new(Field::new( - "item", + DataType::LargeList(Arc::new(Field::new_list_field( + DataType::List(Arc::new(Field::new_list_field( DataType::Struct(Fields::from(vec![ Field::new( "_1", diff --git a/parquet/examples/async_read_parquet.rs b/parquet/examples/async_read_parquet.rs index e59cad8055cb..0a2e9ba994dd 100644 --- a/parquet/examples/async_read_parquet.rs +++ b/parquet/examples/async_read_parquet.rs @@ -45,7 +45,7 @@ async fn main() -> Result<()> { builder = builder.with_projection(mask); // Highlight: set `RowFilter`, it'll push down filter predicates to skip IO and decode. - // For more specific usage: please refer to https://github.com/apache/arrow-datafusion/blob/master/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs. 
+ // For more specific usage: please refer to https://github.com/apache/datafusion/blob/main/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs. let scalar = Int32Array::from(vec![1]); let filter = ArrowPredicateFn::new( ProjectionMask::roots(file_metadata.schema_descr(), [0]), diff --git a/parquet/examples/write_parquet.rs b/parquet/examples/write_parquet.rs index 1b51d40c8134..ebdd9527b6f1 100644 --- a/parquet/examples/write_parquet.rs +++ b/parquet/examples/write_parquet.rs @@ -28,7 +28,7 @@ use parquet::arrow::ArrowWriter as ParquetWriter; use parquet::basic::Encoding; use parquet::errors::Result; use parquet::file::properties::{BloomFilterPosition, WriterProperties}; -use sysinfo::{MemoryRefreshKind, ProcessRefreshKind, ProcessesToUpdate, RefreshKind, System}; +use sysinfo::{ProcessRefreshKind, ProcessesToUpdate, RefreshKind, System}; #[derive(ValueEnum, Clone)] enum BloomFilterPositionArg { @@ -97,8 +97,7 @@ fn main() -> Result<()> { let file = File::create(args.path).unwrap(); let mut writer = ParquetWriter::try_new(file, schema.clone(), Some(properties))?; - let mut system = - System::new_with_specifics(RefreshKind::new().with_memory(MemoryRefreshKind::everything())); + let mut system = System::new_with_specifics(RefreshKind::everything()); eprintln!( "{} Writing {} batches of {} rows. 
RSS = {}", now(), diff --git a/parquet/src/arrow/array_reader/byte_view_array.rs b/parquet/src/arrow/array_reader/byte_view_array.rs index 5845e2c08cec..92a8b0592d0d 100644 --- a/parquet/src/arrow/array_reader/byte_view_array.rs +++ b/parquet/src/arrow/array_reader/byte_view_array.rs @@ -316,9 +316,8 @@ impl ByteViewArrayDecoderPlain { } pub fn read(&mut self, output: &mut ViewBuffer, len: usize) -> Result { - // Here we convert `bytes::Bytes` into `arrow_buffer::Bytes`, which is zero copy - // Then we convert `arrow_buffer::Bytes` into `arrow_buffer:Buffer`, which is also zero copy - let buf = arrow_buffer::Buffer::from_bytes(self.buf.clone().into()); + // Zero copy convert `bytes::Bytes` into `arrow_buffer::Buffer` + let buf = arrow_buffer::Buffer::from(self.buf.clone()); let block_id = output.append_block(buf); let to_read = len.min(self.max_remaining_values); @@ -549,9 +548,8 @@ impl ByteViewArrayDecoderDeltaLength { let src_lengths = &self.lengths[self.length_offset..self.length_offset + to_read]; - // Here we convert `bytes::Bytes` into `arrow_buffer::Bytes`, which is zero copy - // Then we convert `arrow_buffer::Bytes` into `arrow_buffer:Buffer`, which is also zero copy - let bytes = arrow_buffer::Buffer::from_bytes(self.data.clone().into()); + // Zero copy convert `bytes::Bytes` into `arrow_buffer::Buffer` + let bytes = Buffer::from(self.data.clone()); let block_id = output.append_block(bytes); let mut current_offset = self.data_offset; diff --git a/parquet/src/arrow/array_reader/fixed_len_byte_array.rs b/parquet/src/arrow/array_reader/fixed_len_byte_array.rs index 4be07ed68f1d..6b437be943d4 100644 --- a/parquet/src/arrow/array_reader/fixed_len_byte_array.rs +++ b/parquet/src/arrow/array_reader/fixed_len_byte_array.rs @@ -508,8 +508,7 @@ mod tests { ); // [[], [1], [2, 3], null, [4], null, [6, 7, 8]] - let data = ArrayDataBuilder::new(ArrowType::List(Arc::new(Field::new( - "item", + let data = 
ArrayDataBuilder::new(ArrowType::List(Arc::new(Field::new_list_field( decimals.data_type().clone(), false, )))) diff --git a/parquet/src/arrow/array_reader/fixed_size_list_array.rs b/parquet/src/arrow/array_reader/fixed_size_list_array.rs index 75099d018fc9..43a9037d4a74 100644 --- a/parquet/src/arrow/array_reader/fixed_size_list_array.rs +++ b/parquet/src/arrow/array_reader/fixed_size_list_array.rs @@ -277,7 +277,7 @@ mod tests { let mut list_array_reader = FixedSizeListArrayReader::new( Box::new(item_array_reader), 3, - ArrowType::FixedSizeList(Arc::new(Field::new("item", ArrowType::Int32, true)), 3), + ArrowType::FixedSizeList(Arc::new(Field::new_list_field(ArrowType::Int32, true)), 3), 2, 1, true, @@ -323,7 +323,7 @@ mod tests { let mut list_array_reader = FixedSizeListArrayReader::new( Box::new(item_array_reader), 2, - ArrowType::FixedSizeList(Arc::new(Field::new("item", ArrowType::Int32, true)), 2), + ArrowType::FixedSizeList(Arc::new(Field::new_list_field(ArrowType::Int32, true)), 2), 1, 1, false, @@ -347,9 +347,9 @@ mod tests { // [[null, null]], // ] let l2_type = - ArrowType::FixedSizeList(Arc::new(Field::new("item", ArrowType::Int32, true)), 2); + ArrowType::FixedSizeList(Arc::new(Field::new_list_field(ArrowType::Int32, true)), 2); let l1_type = - ArrowType::FixedSizeList(Arc::new(Field::new("item", l2_type.clone(), false)), 1); + ArrowType::FixedSizeList(Arc::new(Field::new_list_field(l2_type.clone(), false)), 1); let array = PrimitiveArray::::from(vec![ None, @@ -436,7 +436,7 @@ mod tests { let mut list_array_reader = FixedSizeListArrayReader::new( Box::new(item_array_reader), 0, - ArrowType::FixedSizeList(Arc::new(Field::new("item", ArrowType::Int32, true)), 0), + ArrowType::FixedSizeList(Arc::new(Field::new_list_field(ArrowType::Int32, true)), 0), 2, 1, true, @@ -481,9 +481,9 @@ mod tests { None, ])); - let inner_type = ArrowType::List(Arc::new(Field::new("item", ArrowType::Int32, true))); + let inner_type = 
ArrowType::List(Arc::new(Field::new_list_field(ArrowType::Int32, true))); let list_type = - ArrowType::FixedSizeList(Arc::new(Field::new("item", inner_type.clone(), true)), 2); + ArrowType::FixedSizeList(Arc::new(Field::new_list_field(inner_type.clone(), true)), 2); let item_array_reader = InMemoryArrayReader::new( ArrowType::Int32, @@ -534,7 +534,10 @@ mod tests { let schema = Arc::new(Schema::new(vec![ Field::new( "list", - ArrowType::FixedSizeList(Arc::new(Field::new("item", ArrowType::Int32, true)), 4), + ArrowType::FixedSizeList( + Arc::new(Field::new_list_field(ArrowType::Int32, true)), + 4, + ), true, ), Field::new("primitive", ArrowType::Int32, true), @@ -599,7 +602,7 @@ mod tests { let schema = Arc::new(Schema::new(vec![Field::new( "list", - ArrowType::FixedSizeList(Arc::new(Field::new("item", ArrowType::Int32, true)), 4), + ArrowType::FixedSizeList(Arc::new(Field::new_list_field(ArrowType::Int32, true)), 4), true, )])); diff --git a/parquet/src/arrow/array_reader/list_array.rs b/parquet/src/arrow/array_reader/list_array.rs index ebff3286bed5..6e583ed00c19 100644 --- a/parquet/src/arrow/array_reader/list_array.rs +++ b/parquet/src/arrow/array_reader/list_array.rs @@ -265,7 +265,7 @@ mod tests { data_type: ArrowType, item_nullable: bool, ) -> ArrowType { - let field = Arc::new(Field::new("item", data_type, item_nullable)); + let field = Arc::new(Field::new_list_field(data_type, item_nullable)); GenericListArray::::DATA_TYPE_CONSTRUCTOR(field) } diff --git a/parquet/src/arrow/array_reader/primitive_array.rs b/parquet/src/arrow/array_reader/primitive_array.rs index 010e9c2eed3f..a952e00e12ef 100644 --- a/parquet/src/arrow/array_reader/primitive_array.rs +++ b/parquet/src/arrow/array_reader/primitive_array.rs @@ -208,10 +208,10 @@ where // As there is not always a 1:1 mapping between Arrow and Parquet, there // are datatypes which we must convert explicitly. // These are: - // - date64: we should cast int32 to date32, then date32 to date64. 
- // - decimal: cast in32 to decimal, int64 to decimal + // - date64: cast int32 to date32, then date32 to date64. + // - decimal: cast int32 to decimal, int64 to decimal let array = match target_type { - ArrowType::Date64 => { + ArrowType::Date64 if *(array.data_type()) == ArrowType::Int32 => { // this is cheap as it internally reinterprets the data let a = arrow_cast::cast(&array, &ArrowType::Date32)?; arrow_cast::cast(&a, target_type)? @@ -305,9 +305,9 @@ mod tests { use crate::util::test_common::rand_gen::make_pages; use crate::util::InMemoryPageIterator; use arrow::datatypes::ArrowPrimitiveType; - use arrow_array::{Array, PrimitiveArray}; + use arrow_array::{Array, Date32Array, PrimitiveArray}; - use arrow::datatypes::DataType::Decimal128; + use arrow::datatypes::DataType::{Date32, Decimal128}; use rand::distributions::uniform::SampleUniform; use std::collections::VecDeque; @@ -783,4 +783,54 @@ mod tests { assert_ne!(array, &data_decimal_array) } } + + #[test] + fn test_primitive_array_reader_date32_type() { + // parquet `INT32` to date + let message_type = " + message test_schema { + REQUIRED INT32 date1 (DATE); + } + "; + let schema = parse_message_type(message_type) + .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) + .unwrap(); + let column_desc = schema.column(0); + + // create the array reader + { + let mut data = Vec::new(); + let mut page_lists = Vec::new(); + make_column_chunks::( + column_desc.clone(), + Encoding::PLAIN, + 100, + -99999999, + 99999999, + &mut Vec::new(), + &mut Vec::new(), + &mut data, + &mut page_lists, + true, + 2, + ); + let page_iterator = InMemoryPageIterator::new(page_lists); + + let mut array_reader = + PrimitiveArrayReader::::new(Box::new(page_iterator), column_desc, None) + .unwrap(); + + // read data from the reader + // the data type is date + let array = array_reader.next_batch(50).unwrap(); + assert_eq!(array.data_type(), &Date32); + let array = array.as_any().downcast_ref::().unwrap(); + let data_date_array = 
data[0..50] + .iter() + .copied() + .map(Some) + .collect::(); + assert_eq!(array, &data_date_array); + } + } } diff --git a/parquet/src/arrow/arrow_reader/mod.rs b/parquet/src/arrow/arrow_reader/mod.rs index d3709c03e99a..6eba04c86f91 100644 --- a/parquet/src/arrow/arrow_reader/mod.rs +++ b/parquet/src/arrow/arrow_reader/mod.rs @@ -932,12 +932,12 @@ mod tests { use arrow_array::builder::*; use arrow_array::cast::AsArray; use arrow_array::types::{ - Decimal128Type, Decimal256Type, DecimalType, Float16Type, Float32Type, Float64Type, - Time32MillisecondType, Time64MicrosecondType, + Date32Type, Date64Type, Decimal128Type, Decimal256Type, DecimalType, Float16Type, + Float32Type, Float64Type, Time32MillisecondType, Time64MicrosecondType, }; use arrow_array::*; use arrow_buffer::{i256, ArrowNativeType, Buffer, IntervalDayTime}; - use arrow_data::ArrayDataBuilder; + use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ ArrowError, DataType as ArrowDataType, Field, Fields, Schema, SchemaRef, TimeUnit, }; @@ -989,6 +989,21 @@ mod tests { assert_eq!(original_schema.fields()[1], reader.schema().fields()[0]); } + #[test] + fn test_arrow_reader_single_column_by_name() { + let file = get_test_file("parquet/generated_simple_numerics/blogs.parquet"); + + let builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap(); + let original_schema = Arc::clone(builder.schema()); + + let mask = ProjectionMask::columns(builder.parquet_schema(), ["blog_id"]); + let reader = builder.with_projection(mask).build().unwrap(); + + // Verify that the schema was correctly parsed + assert_eq!(1, reader.schema().fields().len()); + assert_eq!(original_schema.fields()[1], reader.schema().fields()[0]); + } + #[test] fn test_null_column_reader_test() { let mut file = tempfile::tempfile().unwrap(); @@ -1272,6 +1287,117 @@ mod tests { Ok(()) } + #[test] + fn test_date32_roundtrip() -> Result<()> { + use arrow_array::Date32Array; + + let schema = Arc::new(Schema::new(vec![Field::new( + 
"date32", + ArrowDataType::Date32, + false, + )])); + + let mut buf = Vec::with_capacity(1024); + + let mut writer = ArrowWriter::try_new(&mut buf, schema.clone(), None)?; + + let original = RecordBatch::try_new( + schema, + vec![Arc::new(Date32Array::from(vec![ + -1_000_000, -100_000, -10_000, -1_000, 0, 1_000, 10_000, 100_000, 1_000_000, + ]))], + )?; + + writer.write(&original)?; + writer.close()?; + + let mut reader = ParquetRecordBatchReader::try_new(Bytes::from(buf), 1024)?; + let ret = reader.next().unwrap()?; + assert_eq!(ret, original); + + // Ensure can be downcast to the correct type + ret.column(0).as_primitive::(); + + Ok(()) + } + + #[test] + fn test_date64_roundtrip() -> Result<()> { + use arrow_array::Date64Array; + + let schema = Arc::new(Schema::new(vec![ + Field::new("small-date64", ArrowDataType::Date64, false), + Field::new("big-date64", ArrowDataType::Date64, false), + Field::new("invalid-date64", ArrowDataType::Date64, false), + ])); + + let mut default_buf = Vec::with_capacity(1024); + let mut coerce_buf = Vec::with_capacity(1024); + + let coerce_props = WriterProperties::builder().set_coerce_types(true).build(); + + let mut default_writer = ArrowWriter::try_new(&mut default_buf, schema.clone(), None)?; + let mut coerce_writer = + ArrowWriter::try_new(&mut coerce_buf, schema.clone(), Some(coerce_props))?; + + static NUM_MILLISECONDS_IN_DAY: i64 = 1000 * 60 * 60 * 24; + + let original = RecordBatch::try_new( + schema, + vec![ + // small-date64 + Arc::new(Date64Array::from(vec![ + -1_000_000 * NUM_MILLISECONDS_IN_DAY, + -1_000 * NUM_MILLISECONDS_IN_DAY, + 0, + 1_000 * NUM_MILLISECONDS_IN_DAY, + 1_000_000 * NUM_MILLISECONDS_IN_DAY, + ])), + // big-date64 + Arc::new(Date64Array::from(vec![ + -10_000_000_000 * NUM_MILLISECONDS_IN_DAY, + -1_000_000_000 * NUM_MILLISECONDS_IN_DAY, + 0, + 1_000_000_000 * NUM_MILLISECONDS_IN_DAY, + 10_000_000_000 * NUM_MILLISECONDS_IN_DAY, + ])), + // invalid-date64 + Arc::new(Date64Array::from(vec![ + -1_000_000 * 
NUM_MILLISECONDS_IN_DAY + 1, + -1_000 * NUM_MILLISECONDS_IN_DAY + 1, + 1, + 1_000 * NUM_MILLISECONDS_IN_DAY + 1, + 1_000_000 * NUM_MILLISECONDS_IN_DAY + 1, + ])), + ], + )?; + + default_writer.write(&original)?; + coerce_writer.write(&original)?; + + default_writer.close()?; + coerce_writer.close()?; + + let mut default_reader = ParquetRecordBatchReader::try_new(Bytes::from(default_buf), 1024)?; + let mut coerce_reader = ParquetRecordBatchReader::try_new(Bytes::from(coerce_buf), 1024)?; + + let default_ret = default_reader.next().unwrap()?; + let coerce_ret = coerce_reader.next().unwrap()?; + + // Roundtrip should be successful when default writer used + assert_eq!(default_ret, original); + + // Only small-date64 should roundtrip successfully when coerce_types writer is used + assert_eq!(coerce_ret.column(0), original.column(0)); + assert_ne!(coerce_ret.column(1), original.column(1)); + assert_ne!(coerce_ret.column(2), original.column(2)); + + // Ensure both can be downcast to the correct type + default_ret.column(0).as_primitive::(); + coerce_ret.column(0).as_primitive::(); + + Ok(()) + } struct RandFixedLenGen {} impl RandGen for RandFixedLenGen { @@ -1542,8 +1668,7 @@ mod tests { let decimals = Decimal128Array::from_iter_values([1, 2, 3, 4, 5, 6, 7, 8]); // [[], [1], [2, 3], null, [4], null, [6, 7, 8]] - let data = ArrayDataBuilder::new(ArrowDataType::List(Arc::new(Field::new( - "item", + let data = ArrayDataBuilder::new(ArrowDataType::List(Arc::new(Field::new_list_field( decimals.data_type().clone(), false, )))) @@ -2453,6 +2578,59 @@ mod tests { } } + #[test] + // same as test_read_structs but constructs projection mask via column names + fn test_read_structs_by_name() { + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{testdata}/nested_structs.rust.parquet"); + let file = File::open(&path).unwrap(); + let record_batch_reader = ParquetRecordBatchReader::try_new(file, 60).unwrap(); + + for batch in record_batch_reader { + 
batch.unwrap(); + } + + let file = File::open(&path).unwrap(); + let builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap(); + + let mask = ProjectionMask::columns( + builder.parquet_schema(), + ["roll_num.count", "PC_CUR.mean", "PC_CUR.sum"], + ); + let projected_reader = builder + .with_projection(mask) + .with_batch_size(60) + .build() + .unwrap(); + + let expected_schema = Schema::new(vec![ + Field::new( + "roll_num", + ArrowDataType::Struct(Fields::from(vec![Field::new( + "count", + ArrowDataType::UInt64, + false, + )])), + false, + ), + Field::new( + "PC_CUR", + ArrowDataType::Struct(Fields::from(vec![ + Field::new("mean", ArrowDataType::Int64, false), + Field::new("sum", ArrowDataType::Int64, false), + ])), + false, + ), + ]); + + assert_eq!(&expected_schema, projected_reader.schema().as_ref()); + + for batch in projected_reader { + let batch = batch.unwrap(); + assert_eq!(batch.schema().as_ref(), &expected_schema); + } + } + #[test] fn test_read_maps() { let testdata = arrow::util::test_util::parquet_test_data(); @@ -2874,7 +3052,7 @@ mod tests { let arrow_field = Field::new( "emptylist", - ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Null, true))), + ArrowDataType::List(Arc::new(Field::new_list_field(ArrowDataType::Null, true))), true, ); @@ -3346,7 +3524,7 @@ mod tests { fn test_row_group_batch(row_group_size: usize, batch_size: usize) { let schema = Arc::new(Schema::new(vec![Field::new( "list", - ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int32, true))), + ArrowDataType::List(Arc::new(Field::new_list_field(ArrowDataType::Int32, true))), true, )])); @@ -3584,9 +3762,7 @@ mod tests { .unwrap(); // Although `Vec>` of each row group is empty, // we should read the file successfully. - // FIXME: this test will fail when metadata parsing returns `None` for missing page - // indexes. 
https://github.com/apache/arrow-rs/issues/6447 - assert!(builder.metadata().offset_index().unwrap()[0].is_empty()); + assert!(builder.metadata().offset_index().is_none()); let reader = builder.build().unwrap(); let batches = reader.collect::, _>>().unwrap(); assert_eq!(batches.len(), 1); @@ -3905,7 +4081,7 @@ mod tests { fn test_list_selection() { let schema = Arc::new(Schema::new(vec![Field::new_list( "list", - Field::new("item", ArrowDataType::Utf8, true), + Field::new_list_field(ArrowDataType::Utf8, true), false, )])); let mut buf = Vec::with_capacity(1024); @@ -3961,7 +4137,11 @@ mod tests { let mut rng = thread_rng(); let schema = Arc::new(Schema::new(vec![Field::new_list( "list", - Field::new_list("item", Field::new("item", ArrowDataType::Int32, true), true), + Field::new_list( + Field::LIST_FIELD_DEFAULT_NAME, + Field::new_list_field(ArrowDataType::Int32, true), + true, + ), true, )])); let mut buf = Vec::with_capacity(1024); @@ -4065,4 +4245,93 @@ mod tests { } } } + + #[test] + fn test_read_old_nested_list() { + use arrow::datatypes::DataType; + use arrow::datatypes::ToByteSlice; + + let testdata = arrow::util::test_util::parquet_test_data(); + // message my_record { + // REQUIRED group a (LIST) { + // REPEATED group array (LIST) { + // REPEATED INT32 array; + // } + // } + // } + // should be read as list> + let path = format!("{testdata}/old_list_structure.parquet"); + let test_file = File::open(path).unwrap(); + + // create expected ListArray + let a_values = Int32Array::from(vec![1, 2, 3, 4]); + + // Construct a buffer for value offsets, for the nested array: [[1, 2], [3, 4]] + let a_value_offsets = arrow::buffer::Buffer::from([0, 2, 4].to_byte_slice()); + + // Construct a list array from the above two + let a_list_data = ArrayData::builder(DataType::List(Arc::new(Field::new( + "array", + DataType::Int32, + false, + )))) + .len(2) + .add_buffer(a_value_offsets) + .add_child_data(a_values.into_data()) + .build() + .unwrap(); + let a = 
ListArray::from(a_list_data); + + let builder = ParquetRecordBatchReaderBuilder::try_new(test_file).unwrap(); + let mut reader = builder.build().unwrap(); + let out = reader.next().unwrap().unwrap(); + assert_eq!(out.num_rows(), 1); + assert_eq!(out.num_columns(), 1); + // grab first column + let c0 = out.column(0); + let c0arr = c0.as_any().downcast_ref::().unwrap(); + // get first row: [[1, 2], [3, 4]] + let r0 = c0arr.value(0); + let r0arr = r0.as_any().downcast_ref::().unwrap(); + assert_eq!(r0arr, &a); + } + + #[test] + fn test_map_no_value() { + // File schema: + // message schema { + // required group my_map (MAP) { + // repeated group key_value { + // required int32 key; + // optional int32 value; + // } + // } + // required group my_map_no_v (MAP) { + // repeated group key_value { + // required int32 key; + // } + // } + // required group my_list (LIST) { + // repeated group list { + // required int32 element; + // } + // } + // } + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{testdata}/map_no_value.parquet"); + let file = File::open(path).unwrap(); + + let mut reader = ParquetRecordBatchReaderBuilder::try_new(file) + .unwrap() + .build() + .unwrap(); + let out = reader.next().unwrap().unwrap(); + assert_eq!(out.num_rows(), 3); + assert_eq!(out.num_columns(), 3); + // my_map_no_v and my_list columns should now be equivalent + let c0 = out.column(1).as_list::(); + let c1 = out.column(2).as_list::(); + assert_eq!(c0.len(), c1.len()); + c0.iter().zip(c1.iter()).for_each(|(l, r)| assert_eq!(l, r)); + } } diff --git a/parquet/src/arrow/arrow_reader/statistics.rs b/parquet/src/arrow/arrow_reader/statistics.rs index 8a7511be2afe..09f8ec7cc274 100644 --- a/parquet/src/arrow/arrow_reader/statistics.rs +++ b/parquet/src/arrow/arrow_reader/statistics.rs @@ -21,6 +21,7 @@ /// `arrow-rs/parquet/tests/arrow_reader/statistics.rs`. 
use crate::arrow::buffer::bit_util::sign_extend_be; use crate::arrow::parquet_column; +use crate::basic::Type as PhysicalType; use crate::data_type::{ByteArray, FixedLenByteArray}; use crate::errors::{ParquetError, Result}; use crate::file::metadata::{ParquetColumnIndex, ParquetOffsetIndex, RowGroupMetaData}; @@ -318,7 +319,7 @@ make_decimal_stats_iterator!( /// data_type: The data type of the statistics (e.g. `DataType::Int32`) /// iterator: The iterator of [`ParquetStatistics`] to extract the statistics from. macro_rules! get_statistics { - ($stat_type_prefix: ident, $data_type: ident, $iterator: ident) => { + ($stat_type_prefix: ident, $data_type: ident, $iterator: ident, $physical_type: ident) => { paste! { match $data_type { DataType::Boolean => Ok(Arc::new(BooleanArray::from_iter( @@ -370,10 +371,11 @@ macro_rules! get_statistics { DataType::Date32 => Ok(Arc::new(Date32Array::from_iter( [<$stat_type_prefix Int32StatsIterator>]::new($iterator).map(|x| x.copied()), ))), - DataType::Date64 => Ok(Arc::new(Date64Array::from_iter( + DataType::Date64 if $physical_type == Some(PhysicalType::INT32) => Ok(Arc::new(Date64Array::from_iter( [<$stat_type_prefix Int32StatsIterator>]::new($iterator) - .map(|x| x.map(|x| i64::from(*x) * 24 * 60 * 60 * 1000)), - ))), + .map(|x| x.map(|x| i64::from(*x) * 24 * 60 * 60 * 1000))))), + DataType::Date64 if $physical_type == Some(PhysicalType::INT64) => Ok(Arc::new(Date64Array::from_iter( + [<$stat_type_prefix Int64StatsIterator>]::new($iterator).map(|x| x.copied()),))), DataType::Timestamp(unit, timezone) =>{ let iter = [<$stat_type_prefix Int64StatsIterator>]::new($iterator).map(|x| x.copied()); Ok(match unit { @@ -487,7 +489,7 @@ macro_rules! 
get_statistics { Ok(Arc::new(arr)) }, DataType::Dictionary(_, value_type) => { - [<$stat_type_prefix:lower _ statistics>](value_type, $iterator) + [<$stat_type_prefix:lower _ statistics>](value_type, $iterator, $physical_type) }, DataType::Utf8View => { let iterator = [<$stat_type_prefix ByteArrayStatsIterator>]::new($iterator); @@ -524,6 +526,7 @@ macro_rules! get_statistics { DataType::Map(_,_) | DataType::Duration(_) | DataType::Interval(_) | + DataType::Date64 | // required to cover $physical_type match guard DataType::Null | DataType::List(_) | DataType::ListView(_) | @@ -790,7 +793,7 @@ get_decimal_page_stats_iterator!( ); macro_rules! get_data_page_statistics { - ($stat_type_prefix: ident, $data_type: ident, $iterator: ident) => { + ($stat_type_prefix: ident, $data_type: ident, $iterator: ident, $physical_type: ident) => { paste! { match $data_type { DataType::Boolean => { @@ -929,7 +932,7 @@ macro_rules! get_data_page_statistics { Ok(Arc::new(builder.finish())) }, DataType::Dictionary(_, value_type) => { - [<$stat_type_prefix:lower _ page_statistics>](value_type, $iterator) + [<$stat_type_prefix:lower _ page_statistics>](value_type, $iterator, $physical_type) }, DataType::Timestamp(unit, timezone) => { let iter = [<$stat_type_prefix Int64DataPageStatsIterator>]::new($iterator).flatten(); @@ -941,7 +944,7 @@ macro_rules! get_data_page_statistics { }) }, DataType::Date32 => Ok(Arc::new(Date32Array::from_iter([<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator).flatten()))), - DataType::Date64 => Ok( + DataType::Date64 if $physical_type == Some(PhysicalType::INT32)=> Ok( Arc::new( Date64Array::from_iter([<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator) .map(|x| { @@ -954,6 +957,7 @@ macro_rules! 
get_data_page_statistics { ) ) ), + DataType::Date64 if $physical_type == Some(PhysicalType::INT64) => Ok(Arc::new(Date64Array::from_iter([<$stat_type_prefix Int64DataPageStatsIterator>]::new($iterator).flatten()))), DataType::Decimal128(precision, scale) => Ok(Arc::new( Decimal128Array::from_iter([<$stat_type_prefix Decimal128DataPageStatsIterator>]::new($iterator).flatten()).with_precision_and_scale(*precision, *scale)?)), DataType::Decimal256(precision, scale) => Ok(Arc::new( @@ -1040,6 +1044,7 @@ macro_rules! get_data_page_statistics { } Ok(Arc::new(builder.finish())) }, + DataType::Date64 | // required to cover $physical_type match guard DataType::Null | DataType::Duration(_) | DataType::Interval(_) | @@ -1067,8 +1072,9 @@ macro_rules! get_data_page_statistics { fn min_statistics<'a, I: Iterator>>( data_type: &DataType, iterator: I, + physical_type: Option, ) -> Result { - get_statistics!(Min, data_type, iterator) + get_statistics!(Min, data_type, iterator, physical_type) } /// Extracts the max statistics from an iterator of [`ParquetStatistics`] to an [`ArrayRef`] @@ -1077,26 +1083,35 @@ fn min_statistics<'a, I: Iterator>>( fn max_statistics<'a, I: Iterator>>( data_type: &DataType, iterator: I, + physical_type: Option, ) -> Result { - get_statistics!(Max, data_type, iterator) + get_statistics!(Max, data_type, iterator, physical_type) } /// Extracts the min statistics from an iterator /// of parquet page [`Index`]'es to an [`ArrayRef`] -pub(crate) fn min_page_statistics<'a, I>(data_type: &DataType, iterator: I) -> Result +pub(crate) fn min_page_statistics<'a, I>( + data_type: &DataType, + iterator: I, + physical_type: Option, +) -> Result where I: Iterator, { - get_data_page_statistics!(Min, data_type, iterator) + get_data_page_statistics!(Min, data_type, iterator, physical_type) } /// Extracts the max statistics from an iterator /// of parquet page [`Index`]'es to an [`ArrayRef`] -pub(crate) fn max_page_statistics<'a, I>(data_type: &DataType, iterator: I) -> 
Result +pub(crate) fn max_page_statistics<'a, I>( + data_type: &DataType, + iterator: I, + physical_type: Option, +) -> Result where I: Iterator, { - get_data_page_statistics!(Max, data_type, iterator) + get_data_page_statistics!(Max, data_type, iterator, physical_type) } /// Extracts the null count statistics from an iterator @@ -1177,6 +1192,8 @@ pub struct StatisticsConverter<'a> { arrow_field: &'a Field, /// treat missing null_counts as 0 nulls missing_null_counts_as_zero: bool, + /// The physical type of the matched column in the Parquet schema + physical_type: Option, } impl<'a> StatisticsConverter<'a> { @@ -1304,6 +1321,7 @@ impl<'a> StatisticsConverter<'a> { parquet_column_index: parquet_index, arrow_field, missing_null_counts_as_zero: true, + physical_type: parquet_index.map(|idx| parquet_schema.column(idx).physical_type()), }) } @@ -1346,7 +1364,7 @@ impl<'a> StatisticsConverter<'a> { /// // get the minimum value for the column "foo" in the parquet file /// let min_values: ArrayRef = converter /// .row_group_mins(metadata.row_groups().iter()) - /// .unwrap(); + /// .unwrap(); /// // if "foo" is a Float64 value, the returned array will contain Float64 values /// assert_eq!(min_values, Arc::new(Float64Array::from(vec![Some(1.0), Some(2.0)])) as _); /// ``` @@ -1363,7 +1381,7 @@ impl<'a> StatisticsConverter<'a> { let iter = metadatas .into_iter() .map(|x| x.column(parquet_index).statistics()); - min_statistics(data_type, iter) + min_statistics(data_type, iter, self.physical_type) } /// Extract the maximum values from row group statistics in [`RowGroupMetaData`] @@ -1382,7 +1400,7 @@ impl<'a> StatisticsConverter<'a> { let iter = metadatas .into_iter() .map(|x| x.column(parquet_index).statistics()); - max_statistics(data_type, iter) + max_statistics(data_type, iter, self.physical_type) } /// Extract the null counts from row group statistics in [`RowGroupMetaData`] @@ -1490,7 +1508,7 @@ impl<'a> StatisticsConverter<'a> { (*num_data_pages, 
column_page_index_per_row_group_per_column) }); - min_page_statistics(data_type, iter) + min_page_statistics(data_type, iter, self.physical_type) } /// Extract the maximum values from Data Page statistics. @@ -1521,7 +1539,7 @@ impl<'a> StatisticsConverter<'a> { (*num_data_pages, column_page_index_per_row_group_per_column) }); - max_page_statistics(data_type, iter) + max_page_statistics(data_type, iter, self.physical_type) } /// Returns a [`UInt64Array`] with null counts for each data page. diff --git a/parquet/src/arrow/arrow_writer/levels.rs b/parquet/src/arrow/arrow_writer/levels.rs index 3e828bbddd17..e4662b8f316c 100644 --- a/parquet/src/arrow/arrow_writer/levels.rs +++ b/parquet/src/arrow/arrow_writer/levels.rs @@ -632,7 +632,7 @@ mod tests { // based on the example at https://blog.twitter.com/engineering/en_us/a/2013/dremel-made-simple-with-parquet.html // [[a, b, c], [d, e, f, g]], [[h], [i,j]] - let leaf_type = Field::new("item", DataType::Int32, false); + let leaf_type = Field::new_list_field(DataType::Int32, false); let inner_type = DataType::List(Arc::new(leaf_type)); let inner_field = Field::new("l2", inner_type.clone(), false); let outer_type = DataType::List(Arc::new(inner_field)); @@ -676,7 +676,7 @@ mod tests { fn test_calculate_one_level_1() { // This test calculates the levels for a non-null primitive array let array = Arc::new(Int32Array::from_iter(0..10)) as ArrayRef; - let field = Field::new("item", DataType::Int32, false); + let field = Field::new_list_field(DataType::Int32, false); let levels = calculate_array_levels(&array, &field).unwrap(); assert_eq!(levels.len(), 1); @@ -702,7 +702,7 @@ mod tests { Some(0), None, ])) as ArrayRef; - let field = Field::new("item", DataType::Int32, true); + let field = Field::new_list_field(DataType::Int32, true); let levels = calculate_array_levels(&array, &field).unwrap(); assert_eq!(levels.len(), 1); @@ -720,7 +720,7 @@ mod tests { #[test] fn test_calculate_array_levels_1() { - let leaf_field = 
Field::new("item", DataType::Int32, false); + let leaf_field = Field::new_list_field(DataType::Int32, false); let list_type = DataType::List(Arc::new(leaf_field)); // if all array values are defined (e.g. batch>) @@ -1046,7 +1046,7 @@ mod tests { let a_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); let a_value_offsets = arrow::buffer::Buffer::from_iter([0_i32, 1, 3, 3, 6, 10]); - let a_list_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let a_list_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); let a_list_data = ArrayData::builder(a_list_type.clone()) .len(5) .add_buffer(a_value_offsets) @@ -1059,7 +1059,7 @@ mod tests { let a = ListArray::from(a_list_data); - let item_field = Field::new("item", a_list_type, true); + let item_field = Field::new_list_field(a_list_type, true); let mut builder = levels(&item_field, a); builder.write(2..4); let levels = builder.finish(); @@ -1334,7 +1334,7 @@ mod tests { // define schema let int_field = Field::new("a", DataType::Int32, true); let fields = Fields::from([Arc::new(int_field)]); - let item_field = Field::new("item", DataType::Struct(fields.clone()), true); + let item_field = Field::new_list_field(DataType::Struct(fields.clone()), true); let list_field = Field::new("list", DataType::List(Arc::new(item_field)), true); let int_builder = Int32Builder::with_capacity(10); @@ -1568,7 +1568,7 @@ mod tests { let a = builder.finish(); let values = a.values().clone(); - let item_field = Field::new("item", a.data_type().clone(), true); + let item_field = Field::new_list_field(a.data_type().clone(), true); let mut builder = levels(&item_field, a); builder.write(1..4); let levels = builder.finish(); @@ -1594,7 +1594,7 @@ mod tests { let field_a = Field::new("a", DataType::Int32, true); let field_b = Field::new("b", DataType::Int64, false); let fields = Fields::from([Arc::new(field_a), Arc::new(field_b)]); - let item_field = Field::new("item", 
DataType::Struct(fields.clone()), true); + let item_field = Field::new_list_field(DataType::Struct(fields.clone()), true); let list_field = Field::new( "list", DataType::FixedSizeList(Arc::new(item_field), 2), @@ -1758,7 +1758,7 @@ mod tests { let array = builder.finish(); let values = array.values().clone(); - let item_field = Field::new("item", array.data_type().clone(), true); + let item_field = Field::new_list_field(array.data_type().clone(), true); let mut builder = levels(&item_field, array); builder.write(0..3); let levels = builder.finish(); @@ -1797,7 +1797,7 @@ mod tests { let a = builder.finish(); let values = a.values().as_list::().values().clone(); - let item_field = Field::new("item", a.data_type().clone(), true); + let item_field = Field::new_list_field(a.data_type().clone(), true); let mut builder = levels(&item_field, a); builder.write(0..4); let levels = builder.finish(); @@ -1827,7 +1827,7 @@ mod tests { // [NULL, NULL, 3, 0] let dict = DictionaryArray::new(keys, Arc::new(values)); - let item_field = Field::new("item", dict.data_type().clone(), true); + let item_field = Field::new_list_field(dict.data_type().clone(), true); let mut builder = levels(&item_field, dict.clone()); builder.write(0..4); @@ -1846,7 +1846,7 @@ mod tests { #[test] fn mismatched_types() { let array = Arc::new(Int32Array::from_iter(0..10)) as ArrayRef; - let field = Field::new("item", DataType::Float64, false); + let field = Field::new_list_field(DataType::Float64, false); let err = LevelInfoBuilder::try_new(&field, Default::default(), &array) .unwrap_err() diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs index 99d54eef3bb5..871b140768cb 100644 --- a/parquet/src/arrow/arrow_writer/mod.rs +++ b/parquet/src/arrow/arrow_writer/mod.rs @@ -30,12 +30,10 @@ use arrow_array::types::*; use arrow_array::{ArrayRef, RecordBatch, RecordBatchWriter}; use arrow_schema::{ArrowError, DataType as ArrowDataType, Field, IntervalUnit, SchemaRef}; -use 
super::schema::{ - add_encoded_arrow_schema_to_metadata, arrow_to_parquet_schema, - arrow_to_parquet_schema_with_root, decimal_length_from_precision, -}; +use super::schema::{add_encoded_arrow_schema_to_metadata, decimal_length_from_precision}; use crate::arrow::arrow_writer::byte_array::ByteArrayEncoder; +use crate::arrow::ArrowSchemaConverter; use crate::column::page::{CompressedPage, PageWriteSpec, PageWriter}; use crate::column::writer::encoder::ColumnValueEncoder; use crate::column::writer::{ @@ -180,11 +178,12 @@ impl ArrowWriter { arrow_schema: SchemaRef, options: ArrowWriterOptions, ) -> Result { - let schema = match options.schema_root { - Some(s) => arrow_to_parquet_schema_with_root(&arrow_schema, &s)?, - None => arrow_to_parquet_schema(&arrow_schema)?, - }; let mut props = options.properties; + let mut converter = ArrowSchemaConverter::new().with_coerce_types(props.coerce_types()); + if let Some(schema_root) = &options.schema_root { + converter = converter.schema_root(schema_root); + } + let schema = converter.convert(&arrow_schema)?; if !options.skip_arrow_metadata { // add serialized arrow schema add_encoded_arrow_schema_to_metadata(&arrow_schema, &mut props); @@ -390,9 +389,9 @@ impl ArrowWriterOptions { } /// Set the name of the root parquet schema element (defaults to `"arrow_schema"`) - pub fn with_schema_root(self, name: String) -> Self { + pub fn with_schema_root(self, schema_root: String) -> Self { Self { - schema_root: Some(name), + schema_root: Some(schema_root), ..self } } @@ -538,7 +537,7 @@ impl ArrowColumnChunk { /// # use std::sync::Arc; /// # use arrow_array::*; /// # use arrow_schema::*; -/// # use parquet::arrow::arrow_to_parquet_schema; +/// # use parquet::arrow::ArrowSchemaConverter; /// # use parquet::arrow::arrow_writer::{ArrowLeafColumn, compute_leaves, get_column_writers}; /// # use parquet::file::properties::WriterProperties; /// # use parquet::file::writer::SerializedFileWriter; @@ -549,8 +548,11 @@ impl ArrowColumnChunk { /// 
])); /// /// // Compute the parquet schema -/// let parquet_schema = arrow_to_parquet_schema(schema.as_ref()).unwrap(); /// let props = Arc::new(WriterProperties::default()); +/// let parquet_schema = ArrowSchemaConverter::new() +/// .with_coerce_types(props.coerce_types()) +/// .convert(&schema) +/// .unwrap(); /// /// // Create writers for each of the leaf columns /// let col_writers = get_column_writers(&parquet_schema, &props, &schema).unwrap(); @@ -858,6 +860,12 @@ fn write_leaf(writer: &mut ColumnWriter<'_>, levels: &ArrayLevels) -> Result { match column.data_type() { + ArrowDataType::Date64 => { + let array = arrow_cast::cast(column, &ArrowDataType::Int64)?; + + let array = array.as_primitive::(); + write_primitive(typed, array.values(), levels) + } ArrowDataType::Int64 => { let array = column.as_primitive::(); write_primitive(typed, array.values(), levels) @@ -1082,6 +1090,7 @@ mod tests { use arrow::datatypes::ToByteSlice; use arrow::datatypes::{DataType, Schema}; use arrow::error::Result as ArrowResult; + use arrow::util::data_gen::create_random_array; use arrow::util::pretty::pretty_format_batches; use arrow::{array::*, buffer::Buffer}; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, NullBuffer}; @@ -1194,7 +1203,7 @@ mod tests { // define schema let schema = Schema::new(vec![Field::new( "a", - DataType::List(Arc::new(Field::new("item", DataType::Int32, false))), + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))), true, )]); @@ -1206,8 +1215,7 @@ mod tests { let a_value_offsets = arrow::buffer::Buffer::from([0, 1, 3, 3, 6, 10].to_byte_slice()); // Construct a list array from the above two - let a_list_data = ArrayData::builder(DataType::List(Arc::new(Field::new( - "item", + let a_list_data = ArrayData::builder(DataType::List(Arc::new(Field::new_list_field( DataType::Int32, false, )))) @@ -1234,7 +1242,7 @@ mod tests { // define schema let schema = Schema::new(vec![Field::new( "a", - 
DataType::List(Arc::new(Field::new("item", DataType::Int32, false))), + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))), false, )]); @@ -1246,8 +1254,7 @@ mod tests { let a_value_offsets = arrow::buffer::Buffer::from([0, 1, 3, 3, 6, 10].to_byte_slice()); // Construct a list array from the above two - let a_list_data = ArrayData::builder(DataType::List(Arc::new(Field::new( - "item", + let a_list_data = ArrayData::builder(DataType::List(Arc::new(Field::new_list_field( DataType::Int32, false, )))) @@ -1365,12 +1372,12 @@ mod tests { let struct_field_f = Arc::new(Field::new("f", DataType::Float32, true)); let struct_field_g = Arc::new(Field::new_list( "g", - Field::new("item", DataType::Int16, true), + Field::new_list_field(DataType::Int16, true), false, )); let struct_field_h = Arc::new(Field::new_list( "h", - Field::new("item", DataType::Int16, false), + Field::new_list_field(DataType::Int16, false), true, )); let struct_field_e = Arc::new(Field::new_struct( @@ -1743,7 +1750,7 @@ mod tests { "Expected a dictionary page" ); - let offset_indexes = read_offset_indexes(&file, column).unwrap(); + let offset_indexes = read_offset_indexes(&file, column).unwrap().unwrap(); let page_locations = offset_indexes[0].page_locations.clone(); @@ -2377,7 +2384,7 @@ mod tests { #[test] fn null_list_single_column() { - let null_field = Field::new("item", DataType::Null, true); + let null_field = Field::new_list_field(DataType::Null, true); let list_field = Field::new("emptylist", DataType::List(Arc::new(null_field)), true); let schema = Schema::new(vec![list_field]); @@ -2385,8 +2392,7 @@ mod tests { // Build [[], null, [null, null]] let a_values = NullArray::new(2); let a_value_offsets = arrow::buffer::Buffer::from([0, 0, 0, 2].to_byte_slice()); - let a_list_data = ArrayData::builder(DataType::List(Arc::new(Field::new( - "item", + let a_list_data = ArrayData::builder(DataType::List(Arc::new(Field::new_list_field( DataType::Null, true, )))) @@ -2415,8 +2421,7 @@ 
mod tests { fn list_single_column() { let a_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); let a_value_offsets = arrow::buffer::Buffer::from([0, 1, 3, 3, 6, 10].to_byte_slice()); - let a_list_data = ArrayData::builder(DataType::List(Arc::new(Field::new( - "item", + let a_list_data = ArrayData::builder(DataType::List(Arc::new(Field::new_list_field( DataType::Int32, false, )))) @@ -2489,6 +2494,56 @@ mod tests { one_column_roundtrip(values, false); } + #[test] + fn list_and_map_coerced_names() { + // Create map and list with non-Parquet naming + let list_field = + Field::new_list("my_list", Field::new("item", DataType::Int32, false), false); + let map_field = Field::new_map( + "my_map", + "entries", + Field::new("keys", DataType::Int32, false), + Field::new("values", DataType::Int32, true), + false, + true, + ); + + let list_array = create_random_array(&list_field, 100, 0.0, 0.0).unwrap(); + let map_array = create_random_array(&map_field, 100, 0.0, 0.0).unwrap(); + + let arrow_schema = Arc::new(Schema::new(vec![list_field, map_field])); + + // Write data to Parquet but coerce names to match spec + let props = Some(WriterProperties::builder().set_coerce_types(true).build()); + let file = tempfile::tempfile().unwrap(); + let mut writer = + ArrowWriter::try_new(file.try_clone().unwrap(), arrow_schema.clone(), props).unwrap(); + + let batch = RecordBatch::try_new(arrow_schema, vec![list_array, map_array]).unwrap(); + writer.write(&batch).unwrap(); + let file_metadata = writer.close().unwrap(); + + // Coerced name of "item" should be "element" + assert_eq!(file_metadata.schema[3].name, "element"); + // Coerced name of "entries" should be "key_value" + assert_eq!(file_metadata.schema[5].name, "key_value"); + // Coerced name of "keys" should be "key" + assert_eq!(file_metadata.schema[6].name, "key"); + // Coerced name of "values" should be "value" + assert_eq!(file_metadata.schema[7].name, "value"); + + // Double check schema after reading from the file + 
let reader = SerializedFileReader::new(file).unwrap(); + let file_schema = reader.metadata().file_metadata().schema(); + let fields = file_schema.get_fields(); + let list_field = &fields[0].get_fields()[0]; + assert_eq!(list_field.get_fields()[0].name(), "element"); + let map_field = &fields[1].get_fields()[0]; + assert_eq!(map_field.name(), "key_value"); + assert_eq!(map_field.get_fields()[0].name(), "key"); + assert_eq!(map_field.get_fields()[1].name(), "value"); + } + #[test] fn fallback_flush_data_page() { //tests if the Fallback::flush_data_page clears all buffers correctly @@ -2534,6 +2589,7 @@ mod tests { #[test] fn arrow_writer_string_dictionary() { // define schema + #[allow(deprecated)] let schema = Arc::new(Schema::new(vec![Field::new_dict( "dictionary", DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), @@ -2555,6 +2611,7 @@ mod tests { #[test] fn arrow_writer_primitive_dictionary() { // define schema + #[allow(deprecated)] let schema = Arc::new(Schema::new(vec![Field::new_dict( "dictionary", DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::UInt32)), @@ -2577,6 +2634,7 @@ mod tests { #[test] fn arrow_writer_string_dictionary_unsigned_index() { // define schema + #[allow(deprecated)] let schema = Arc::new(Schema::new(vec![Field::new_dict( "dictionary", DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), diff --git a/parquet/src/arrow/async_reader/metadata.rs b/parquet/src/arrow/async_reader/metadata.rs index b19f9830a7c9..526818845b5c 100644 --- a/parquet/src/arrow/async_reader/metadata.rs +++ b/parquet/src/arrow/async_reader/metadata.rs @@ -119,7 +119,7 @@ impl MetadataLoader { return Err(ParquetError::EOF(format!( "file size of {} is less than footer + metadata {}", file_size, - length + 8 + length + FOOTER_SIZE ))); } diff --git a/parquet/src/arrow/async_reader/mod.rs b/parquet/src/arrow/async_reader/mod.rs index 8b315cc9f784..4f3befe42662 100644 --- 
a/parquet/src/arrow/async_reader/mod.rs +++ b/parquet/src/arrow/async_reader/mod.rs @@ -158,7 +158,8 @@ pub trait AsyncFileReader: Send { fn get_metadata(&mut self) -> BoxFuture<'_, Result>>; } -impl AsyncFileReader for Box { +/// This allows Box to be used as an AsyncFileReader, +impl AsyncFileReader for Box { fn get_bytes(&mut self, range: Range) -> BoxFuture<'_, Result> { self.as_mut().get_bytes(range) } @@ -612,6 +613,9 @@ impl std::fmt::Debug for StreamState { /// An asynchronous [`Stream`](https://docs.rs/futures/latest/futures/stream/trait.Stream.html) of [`RecordBatch`] /// for a parquet file that can be constructed using [`ParquetRecordBatchStreamBuilder`]. +/// +/// `ParquetRecordBatchStream` also provides [`ParquetRecordBatchStream::next_row_group`] for fetching row groups, +/// allowing users to decode record batches separately from I/O. pub struct ParquetRecordBatchStream { metadata: Arc, @@ -653,6 +657,70 @@ impl ParquetRecordBatchStream { } } +impl ParquetRecordBatchStream +where + T: AsyncFileReader + Unpin + Send + 'static, +{ + /// Fetches the next row group from the stream. + /// + /// Users can continue to call this function to get row groups and decode them concurrently. + /// + /// ## Notes + /// + /// ParquetRecordBatchStream should be used either as a `Stream` or with `next_row_group`; they should not be used simultaneously. + /// + /// ## Returns + /// + /// - `Ok(None)` if the stream has ended. + /// - `Err(error)` if the stream has errored. All subsequent calls will return `Ok(None)`. + /// - `Ok(Some(reader))` which holds all the data for the row group. 
+ pub async fn next_row_group(&mut self) -> Result> { + loop { + match &mut self.state { + StreamState::Decoding(_) | StreamState::Reading(_) => { + return Err(ParquetError::General( + "Cannot combine the use of next_row_group with the Stream API".to_string(), + )) + } + StreamState::Init => { + let row_group_idx = match self.row_groups.pop_front() { + Some(idx) => idx, + None => return Ok(None), + }; + + let row_count = self.metadata.row_group(row_group_idx).num_rows() as usize; + + let selection = self.selection.as_mut().map(|s| s.split_off(row_count)); + + let reader_factory = self.reader.take().expect("lost reader"); + + let (reader_factory, maybe_reader) = reader_factory + .read_row_group( + row_group_idx, + selection, + self.projection.clone(), + self.batch_size, + ) + .await + .map_err(|err| { + self.state = StreamState::Error; + err + })?; + self.reader = Some(reader_factory); + + if let Some(reader) = maybe_reader { + return Ok(Some(reader)); + } else { + // All rows skipped, read next row group + continue; + } + } + StreamState::Error => return Ok(None), // Ends the stream as error happens. 
+ } + } + } +} + impl Stream for ParquetRecordBatchStream where T: AsyncFileReader + Unpin + Send + 'static, @@ -724,7 +792,7 @@ struct InMemoryRowGroup<'a> { row_count: usize, } -impl<'a> InMemoryRowGroup<'a> { +impl InMemoryRowGroup<'_> { /// Fetches the necessary column data into memory async fn fetch( &mut self, @@ -927,7 +995,6 @@ mod tests { use crate::arrow::schema::parquet_to_arrow_schema_and_fields; use crate::arrow::ArrowWriter; use crate::file::metadata::ParquetMetaDataReader; - use crate::file::page_index::index_reader; use crate::file::properties::WriterProperties; use arrow::compute::kernels::cmp::eq; use arrow::error::Result as ArrowResult; @@ -1020,6 +1087,71 @@ mod tests { ); } + #[tokio::test] + async fn test_async_reader_with_next_row_group() { + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{testdata}/alltypes_plain.parquet"); + let data = Bytes::from(std::fs::read(path).unwrap()); + + let metadata = ParquetMetaDataReader::new() + .parse_and_finish(&data) + .unwrap(); + let metadata = Arc::new(metadata); + + assert_eq!(metadata.num_row_groups(), 1); + + let async_reader = TestReader { + data: data.clone(), + metadata: metadata.clone(), + requests: Default::default(), + }; + + let requests = async_reader.requests.clone(); + let builder = ParquetRecordBatchStreamBuilder::new(async_reader) + .await + .unwrap(); + + let mask = ProjectionMask::leaves(builder.parquet_schema(), vec![1, 2]); + let mut stream = builder + .with_projection(mask.clone()) + .with_batch_size(1024) + .build() + .unwrap(); + + let mut readers = vec![]; + while let Some(reader) = stream.next_row_group().await.unwrap() { + readers.push(reader); + } + + let async_batches: Vec<_> = readers + .into_iter() + .flat_map(|r| r.map(|v| v.unwrap()).collect::>()) + .collect(); + + let sync_batches = ParquetRecordBatchReaderBuilder::try_new(data) + .unwrap() + .with_projection(mask) + .with_batch_size(104) + .build() + .unwrap() + .collect::>>() + 
.unwrap(); + + assert_eq!(async_batches, sync_batches); + + let requests = requests.lock().unwrap(); + let (offset_1, length_1) = metadata.row_group(0).column(1).byte_range(); + let (offset_2, length_2) = metadata.row_group(0).column(2).byte_range(); + + assert_eq!( + &requests[..], + &[ + offset_1 as usize..(offset_1 + length_1) as usize, + offset_2 as usize..(offset_2 + length_2) as usize + ] + ); + } + #[tokio::test] async fn test_async_reader_with_index() { let testdata = arrow::util::test_util::parquet_test_data(); @@ -1565,12 +1697,11 @@ mod tests { let data = Bytes::from(std::fs::read(path).unwrap()); let metadata = ParquetMetaDataReader::new() + .with_page_indexes(true) .parse_and_finish(&data) .unwrap(); - let offset_index = - index_reader::read_offset_indexes(&data, metadata.row_group(0).columns()) - .expect("reading offset index"); + let offset_index = metadata.offset_index().expect("reading offset index")[0].clone(); let mut metadata_builder = metadata.into_builder(); let mut row_groups = metadata_builder.take_row_groups(); @@ -1870,7 +2001,7 @@ mod tests { async fn test_nested_skip() { let schema = Arc::new(Schema::new(vec![ Field::new("col_1", DataType::UInt64, false), - Field::new_list("col_2", Field::new("item", DataType::Utf8, true), true), + Field::new_list("col_2", Field::new_list_field(DataType::Utf8, true), true), ])); // Default writer properties diff --git a/parquet/src/arrow/async_writer/mod.rs b/parquet/src/arrow/async_writer/mod.rs index 8155b57d9ac6..c04d5710a971 100644 --- a/parquet/src/arrow/async_writer/mod.rs +++ b/parquet/src/arrow/async_writer/mod.rs @@ -89,7 +89,7 @@ pub trait AsyncFileWriter: Send { fn complete(&mut self) -> BoxFuture<'_, Result<()>>; } -impl AsyncFileWriter for Box { +impl AsyncFileWriter for Box { fn write(&mut self, bs: Bytes) -> BoxFuture<'_, Result<()>> { self.as_mut().write(bs) } diff --git a/parquet/src/arrow/buffer/view_buffer.rs b/parquet/src/arrow/buffer/view_buffer.rs index 2256f4877d68..fd7d6c213f04 
100644 --- a/parquet/src/arrow/buffer/view_buffer.rs +++ b/parquet/src/arrow/buffer/view_buffer.rs @@ -130,7 +130,7 @@ mod tests { #[test] fn test_view_buffer_append_view() { let mut buffer = ViewBuffer::default(); - let string_buffer = Buffer::from(&b"0123456789long string to test string view"[..]); + let string_buffer = Buffer::from(b"0123456789long string to test string view"); let block_id = buffer.append_block(string_buffer); unsafe { @@ -157,7 +157,7 @@ mod tests { #[test] fn test_view_buffer_pad_null() { let mut buffer = ViewBuffer::default(); - let string_buffer = Buffer::from(&b"0123456789long string to test string view"[..]); + let string_buffer = Buffer::from(b"0123456789long string to test string view"); let block_id = buffer.append_block(string_buffer); unsafe { diff --git a/parquet/src/arrow/mod.rs b/parquet/src/arrow/mod.rs index 2d09cd19203f..35f5897c18f8 100644 --- a/parquet/src/arrow/mod.rs +++ b/parquet/src/arrow/mod.rs @@ -108,17 +108,23 @@ pub mod async_writer; mod record_reader; experimental!(mod schema); +use std::sync::Arc; + pub use self::arrow_writer::ArrowWriter; #[cfg(feature = "async")] pub use self::async_reader::ParquetRecordBatchStreamBuilder; #[cfg(feature = "async")] pub use self::async_writer::AsyncArrowWriter; -use crate::schema::types::SchemaDescriptor; +use crate::schema::types::{SchemaDescriptor, Type}; use arrow_schema::{FieldRef, Schema}; +// continue to export deprecated methods until they are removed +#[allow(deprecated)] +pub use self::schema::arrow_to_parquet_schema; + pub use self::schema::{ - arrow_to_parquet_schema, parquet_to_arrow_field_levels, parquet_to_arrow_schema, - parquet_to_arrow_schema_by_columns, FieldLevels, + add_encoded_arrow_schema_to_metadata, encode_arrow_schema, parquet_to_arrow_field_levels, + parquet_to_arrow_schema, parquet_to_arrow_schema_by_columns, ArrowSchemaConverter, FieldLevels, }; /// Schema metadata key used to store serialized Arrow IPC schema @@ -206,10 +212,114 @@ impl ProjectionMask 
{ Self { mask: Some(mask) } } + // Given a starting point in the schema, do a DFS for that node adding leaf paths to `paths`. + fn find_leaves(root: &Arc, parent: Option<&String>, paths: &mut Vec) { + let path = parent + .map(|p| [p, root.name()].join(".")) + .unwrap_or(root.name().to_string()); + if root.is_group() { + for child in root.get_fields() { + Self::find_leaves(child, Some(&path), paths); + } + } else { + // Reached a leaf, add to paths + paths.push(path); + } + } + + /// Create a [`ProjectionMask`] which selects only the named columns + /// + /// All leaf columns that fall below a given name will be selected. For example, given + /// the schema + /// ```ignore + /// message schema { + /// OPTIONAL group a (MAP) { + /// REPEATED group key_value { + /// REQUIRED BYTE_ARRAY key (UTF8); // leaf index 0 + /// OPTIONAL group value (MAP) { + /// REPEATED group key_value { + /// REQUIRED INT32 key; // leaf index 1 + /// REQUIRED BOOLEAN value; // leaf index 2 + /// } + /// } + /// } + /// } + /// REQUIRED INT32 b; // leaf index 3 + /// REQUIRED DOUBLE c; // leaf index 4 + /// } + /// ``` + /// `["a.key_value.value", "c"]` would return leaf columns 1, 2, and 4. `["a"]` would return + /// columns 0, 1, and 2. + /// + /// Note: repeated or out of order indices will not impact the final mask. + /// + /// i.e. `["b", "c"]` will construct the same mask as `["c", "b", "c"]`. 
+ pub fn columns<'a>( + schema: &SchemaDescriptor, + names: impl IntoIterator, + ) -> Self { + // first make vector of paths for leaf columns + let mut paths: Vec = vec![]; + for root in schema.root_schema().get_fields() { + Self::find_leaves(root, None, &mut paths); + } + assert_eq!(paths.len(), schema.num_columns()); + + let mut mask = vec![false; schema.num_columns()]; + for name in names { + for idx in 0..schema.num_columns() { + if paths[idx].starts_with(name) { + mask[idx] = true; + } + } + } + + Self { mask: Some(mask) } + } + /// Returns true if the leaf column `leaf_idx` is included by the mask pub fn leaf_included(&self, leaf_idx: usize) -> bool { self.mask.as_ref().map(|m| m[leaf_idx]).unwrap_or(true) } + + /// Union two projection masks + /// + /// Example: + /// ```text + /// mask1 = [true, false, true] + /// mask2 = [false, true, true] + /// union(mask1, mask2) = [true, true, true] + /// ``` + pub fn union(&mut self, other: &Self) { + match (self.mask.as_ref(), other.mask.as_ref()) { + (None, _) | (_, None) => self.mask = None, + (Some(a), Some(b)) => { + debug_assert_eq!(a.len(), b.len()); + let mask = a.iter().zip(b.iter()).map(|(&a, &b)| a || b).collect(); + self.mask = Some(mask); + } + } + } + + /// Intersect two projection masks + /// + /// Example: + /// ```text + /// mask1 = [true, false, true] + /// mask2 = [false, true, true] + /// intersect(mask1, mask2) = [false, false, true] + /// ``` + pub fn intersect(&mut self, other: &Self) { + match (self.mask.as_ref(), other.mask.as_ref()) { + (None, _) => self.mask = other.mask.clone(), + (_, None) => {} + (Some(a), Some(b)) => { + debug_assert_eq!(a.len(), b.len()); + let mask = a.iter().zip(b.iter()).map(|(&a, &b)| a && b).collect(); + self.mask = Some(mask); + } + } + } } /// Lookups up the parquet column by name @@ -242,10 +352,14 @@ mod test { use crate::arrow::ArrowWriter; use crate::file::metadata::{ParquetMetaData, ParquetMetaDataReader, ParquetMetaDataWriter}; use 
crate::file::properties::{EnabledStatistics, WriterProperties}; + use crate::schema::parser::parse_message_type; + use crate::schema::types::SchemaDescriptor; use arrow_array::{ArrayRef, Int32Array, RecordBatch}; use bytes::Bytes; use std::sync::Arc; + use super::ProjectionMask; + #[test] // Reproducer for https://github.com/apache/arrow-rs/issues/6464 fn test_metadata_read_write_partial_offset() { @@ -371,4 +485,171 @@ mod test { .unwrap(); Bytes::from(buf) } + + #[test] + fn test_mask_from_column_names() { + let message_type = " + message test_schema { + OPTIONAL group a (MAP) { + REPEATED group key_value { + REQUIRED BYTE_ARRAY key (UTF8); + OPTIONAL group value (MAP) { + REPEATED group key_value { + REQUIRED INT32 key; + REQUIRED BOOLEAN value; + } + } + } + } + REQUIRED INT32 b; + REQUIRED DOUBLE c; + } + "; + let parquet_group_type = parse_message_type(message_type).unwrap(); + let schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); + + let mask = ProjectionMask::columns(&schema, ["foo", "bar"]); + assert_eq!(mask.mask.unwrap(), vec![false; 5]); + + let mask = ProjectionMask::columns(&schema, []); + assert_eq!(mask.mask.unwrap(), vec![false; 5]); + + let mask = ProjectionMask::columns(&schema, ["a", "c"]); + assert_eq!(mask.mask.unwrap(), [true, true, true, false, true]); + + let mask = ProjectionMask::columns(&schema, ["a.key_value.key", "c"]); + assert_eq!(mask.mask.unwrap(), [true, false, false, false, true]); + + let mask = ProjectionMask::columns(&schema, ["a.key_value.value", "b"]); + assert_eq!(mask.mask.unwrap(), [false, true, true, true, false]); + + let message_type = " + message test_schema { + OPTIONAL group a (LIST) { + REPEATED group list { + OPTIONAL group element (LIST) { + REPEATED group list { + OPTIONAL group element (LIST) { + REPEATED group list { + OPTIONAL BYTE_ARRAY element (UTF8); + } + } + } + } + } + } + REQUIRED INT32 b; + } + "; + let parquet_group_type = parse_message_type(message_type).unwrap(); + let schema = 
SchemaDescriptor::new(Arc::new(parquet_group_type)); + + let mask = ProjectionMask::columns(&schema, ["a", "b"]); + assert_eq!(mask.mask.unwrap(), [true, true]); + + let mask = ProjectionMask::columns(&schema, ["a.list.element", "b"]); + assert_eq!(mask.mask.unwrap(), [true, true]); + + let mask = + ProjectionMask::columns(&schema, ["a.list.element.list.element.list.element", "b"]); + assert_eq!(mask.mask.unwrap(), [true, true]); + + let mask = ProjectionMask::columns(&schema, ["b"]); + assert_eq!(mask.mask.unwrap(), [false, true]); + + let message_type = " + message test_schema { + OPTIONAL INT32 a; + OPTIONAL INT32 b; + OPTIONAL INT32 c; + OPTIONAL INT32 d; + OPTIONAL INT32 e; + } + "; + let parquet_group_type = parse_message_type(message_type).unwrap(); + let schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); + + let mask = ProjectionMask::columns(&schema, ["a", "b"]); + assert_eq!(mask.mask.unwrap(), [true, true, false, false, false]); + + let mask = ProjectionMask::columns(&schema, ["d", "b", "d"]); + assert_eq!(mask.mask.unwrap(), [false, true, false, true, false]); + + let message_type = " + message test_schema { + OPTIONAL INT32 a; + OPTIONAL INT32 b; + OPTIONAL INT32 a; + OPTIONAL INT32 d; + OPTIONAL INT32 e; + } + "; + let parquet_group_type = parse_message_type(message_type).unwrap(); + let schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); + + let mask = ProjectionMask::columns(&schema, ["a", "e"]); + assert_eq!(mask.mask.unwrap(), [true, false, true, false, true]); + } + + #[test] + fn test_projection_mask_union() { + let mut mask1 = ProjectionMask { + mask: Some(vec![true, false, true]), + }; + let mask2 = ProjectionMask { + mask: Some(vec![false, true, true]), + }; + mask1.union(&mask2); + assert_eq!(mask1.mask, Some(vec![true, true, true])); + + let mut mask1 = ProjectionMask { mask: None }; + let mask2 = ProjectionMask { + mask: Some(vec![false, true, true]), + }; + mask1.union(&mask2); + assert_eq!(mask1.mask, None); + + let 
mut mask1 = ProjectionMask { + mask: Some(vec![true, false, true]), + }; + let mask2 = ProjectionMask { mask: None }; + mask1.union(&mask2); + assert_eq!(mask1.mask, None); + + let mut mask1 = ProjectionMask { mask: None }; + let mask2 = ProjectionMask { mask: None }; + mask1.union(&mask2); + assert_eq!(mask1.mask, None); + } + + #[test] + fn test_projection_mask_intersect() { + let mut mask1 = ProjectionMask { + mask: Some(vec![true, false, true]), + }; + let mask2 = ProjectionMask { + mask: Some(vec![false, true, true]), + }; + mask1.intersect(&mask2); + assert_eq!(mask1.mask, Some(vec![false, false, true])); + + let mut mask1 = ProjectionMask { mask: None }; + let mask2 = ProjectionMask { + mask: Some(vec![false, true, true]), + }; + mask1.intersect(&mask2); + assert_eq!(mask1.mask, Some(vec![false, true, true])); + + let mut mask1 = ProjectionMask { + mask: Some(vec![true, false, true]), + }; + let mask2 = ProjectionMask { mask: None }; + mask1.intersect(&mask2); + assert_eq!(mask1.mask, Some(vec![true, false, true])); + + let mut mask1 = ProjectionMask { mask: None }; + let mask2 = ProjectionMask { mask: None }; + mask1.intersect(&mask2); + assert_eq!(mask1.mask, None); + } } diff --git a/parquet/src/arrow/schema/complex.rs b/parquet/src/arrow/schema/complex.rs index e487feabb848..16d46bd852dc 100644 --- a/parquet/src/arrow/schema/complex.rs +++ b/parquet/src/arrow/schema/complex.rs @@ -271,8 +271,13 @@ impl Visitor { return Err(arrow_err!("Child of map field must be repeated")); } + // According to the specification the values are optional (#1642). + // In this case, return the keys as a list. 
+ if map_key_value.get_fields().len() == 1 { + return self.visit_list(map_type, context); + } + if map_key_value.get_fields().len() != 2 { - // According to the specification the values are optional (#1642) return Err(arrow_err!( "Child of map field must have two children, found {}", map_key_value.get_fields().len() @@ -448,15 +453,21 @@ impl Visitor { }; } + // test to see if the repeated field is a struct or one-tuple let items = repeated_field.get_fields(); if items.len() != 1 - || repeated_field.name() == "array" - || repeated_field.name() == format!("{}_tuple", list_type.name()) + || (!repeated_field.is_list() + && !repeated_field.has_single_repeated_child() + && (repeated_field.name() == "array" + || repeated_field.name() == format!("{}_tuple", list_type.name()))) { - // If the repeated field is a group with multiple fields, then its type is the element type and elements are required. + // If the repeated field is a group with multiple fields, then its type is the element + // type and elements are required. // - // If the repeated field is a group with one field and is named either array or uses the LIST-annotated group's name - // with _tuple appended then the repeated type is the element type and elements are required. + // If the repeated field is a group with one field and is named either array or uses + // the LIST-annotated group's name with _tuple appended then the repeated type is the + // element type and elements are required. But this rule only applies if the + // repeated field is not annotated, and the single child field is not `repeated`. 
let context = VisitorContext { rep_level: context.rep_level, def_level, @@ -541,8 +552,11 @@ fn convert_field(parquet_type: &Type, field: &ParquetField, arrow_hint: Option<& match arrow_hint { Some(hint) => { // If the inferred type is a dictionary, preserve dictionary metadata + #[allow(deprecated)] let field = match (&data_type, hint.dict_id(), hint.dict_is_ordered()) { - (DataType::Dictionary(_, _), Some(id), Some(ordered)) => { + (DataType::Dictionary(_, _), Some(id), Some(ordered)) => + { + #[allow(deprecated)] Field::new_dict(name, data_type, nullable, id, ordered) } _ => Field::new(name, data_type, nullable), diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs index 3ed3bd24e0a8..8be2439002be 100644 --- a/parquet/src/arrow/schema/mod.rs +++ b/parquet/src/arrow/schema/mod.rs @@ -15,13 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Provides API for converting parquet schema to arrow schema and vice versa. -//! -//! The main interfaces for converting parquet schema to arrow schema are -//! `parquet_to_arrow_schema`, `parquet_to_arrow_schema_by_columns` and -//! `parquet_to_arrow_field`. -//! -//! The interfaces for converting arrow schema to parquet schema is coming. +//! 
Converting Parquet schema <--> Arrow schema: [`ArrowSchemaConverter`] and [parquet_to_arrow_schema] use base64::prelude::BASE64_STANDARD; use base64::Engine; @@ -176,8 +170,9 @@ fn get_arrow_schema_from_metadata(encoded_meta: &str) -> Result { } /// Encodes the Arrow schema into the IPC format, and base64 encodes it -fn encode_arrow_schema(schema: &Schema) -> String { +pub fn encode_arrow_schema(schema: &Schema) -> String { let options = writer::IpcWriteOptions::default(); + #[allow(deprecated)] let mut dictionary_tracker = writer::DictionaryTracker::new_with_preserve_dict_id(true, options.preserve_dict_id()); let data_gen = writer::IpcDataGenerator::default(); @@ -197,7 +192,7 @@ fn encode_arrow_schema(schema: &Schema) -> String { /// Mutates writer metadata by storing the encoded Arrow schema. /// If there is an existing Arrow schema metadata, it is replaced. -pub(crate) fn add_encoded_arrow_schema_to_metadata(schema: &Schema, props: &mut WriterProperties) { +pub fn add_encoded_arrow_schema_to_metadata(schema: &Schema, props: &mut WriterProperties) { let encoded = encode_arrow_schema(schema); let schema_kv = KeyValue { @@ -225,23 +220,134 @@ pub(crate) fn add_encoded_arrow_schema_to_metadata(schema: &Schema, props: &mut } } +/// Converter for Arrow schema to Parquet schema +/// +/// Example: +/// ``` +/// # use std::sync::Arc; +/// # use arrow_schema::{Field, Schema, DataType}; +/// # use parquet::arrow::ArrowSchemaConverter; +/// use parquet::schema::types::{SchemaDescriptor, Type}; +/// use parquet::basic; // note there are two `Type`s in the following example +/// // create an Arrow Schema +/// let arrow_schema = Schema::new(vec![ +/// Field::new("a", DataType::Int64, true), +/// Field::new("b", DataType::Date32, true), +/// ]); +/// // convert the Arrow schema to a Parquet schema +/// let parquet_schema = ArrowSchemaConverter::new() +/// .convert(&arrow_schema) +/// .unwrap(); +/// +/// let expected_parquet_schema = SchemaDescriptor::new( +/// Arc::new( +/// 
Type::group_type_builder("arrow_schema") +/// .with_fields(vec![ +/// Arc::new( +/// Type::primitive_type_builder("a", basic::Type::INT64) +/// .build().unwrap() +/// ), +/// Arc::new( +/// Type::primitive_type_builder("b", basic::Type::INT32) +/// .with_converted_type(basic::ConvertedType::DATE) +/// .with_logical_type(Some(basic::LogicalType::Date)) +/// .build().unwrap() +/// ), +/// ]) +/// .build().unwrap() +/// ) +/// ); +/// assert_eq!(parquet_schema, expected_parquet_schema); +/// ``` +#[derive(Debug)] +pub struct ArrowSchemaConverter<'a> { + /// Name of the root schema in Parquet + schema_root: &'a str, + /// Should we coerce Arrow types to compatible Parquet types? + /// + /// See docs on [Self::with_coerce_types]` + coerce_types: bool, +} + +impl Default for ArrowSchemaConverter<'_> { + fn default() -> Self { + Self::new() + } +} + +impl<'a> ArrowSchemaConverter<'a> { + /// Create a new converter + pub fn new() -> Self { + Self { + schema_root: "arrow_schema", + coerce_types: false, + } + } + + /// Should Arrow types be coerced into Parquet native types (default `false`). + /// + /// Setting this option to `true` will result in Parquet files that can be + /// read by more readers, but may lose precision for Arrow types such as + /// [`DataType::Date64`] which have no direct [corresponding Parquet type]. + /// + /// By default, this converter does not coerce to native Parquet types. Enabling type + /// coercion allows for meaningful representations that do not require + /// downstream readers to consider the embedded Arrow schema, and can allow + /// for greater compatibility with other Parquet implementations. However, + /// type coercion also prevents data from being losslessly round-tripped. + /// + /// # Discussion + /// + /// Some Arrow types such as `Date64`, `Timestamp` and `Interval` have no + /// corresponding Parquet logical type. Thus, they can not be losslessly + /// round-tripped when stored using the appropriate Parquet logical type. 
+ /// For example, some Date64 values may be truncated when stored with + /// parquet's native 32 bit date type. + /// + /// For [`List`] and [`Map`] types, some Parquet readers expect certain + /// schema elements to have specific names (earlier versions of the spec + /// were somewhat ambiguous on this point). Type coercion will use the names + /// prescribed by the Parquet specification, potentially losing naming + /// metadata from the Arrow schema. + /// + /// [`List`]: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists + /// [`Map`]: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#maps + /// [corresponding Parquet type]: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#date + /// + pub fn with_coerce_types(mut self, coerce_types: bool) -> Self { + self.coerce_types = coerce_types; + self + } + + /// Set the root schema element name (defaults to `"arrow_schema"`). + pub fn schema_root(mut self, schema_root: &'a str) -> Self { + self.schema_root = schema_root; + self + } + + /// Convert the specified Arrow [`Schema`] to the desired Parquet [`SchemaDescriptor`] + /// + /// See example in [`ArrowSchemaConverter`] + pub fn convert(&self, schema: &Schema) -> Result { + let fields = schema + .fields() + .iter() + .map(|field| arrow_to_parquet_type(field, self.coerce_types).map(Arc::new)) + .collect::>()?; + let group = Type::group_type_builder(self.schema_root) + .with_fields(fields) + .build()?; + Ok(SchemaDescriptor::new(Arc::new(group))) + } +} + /// Convert arrow schema to parquet schema /// /// The name of the root schema element defaults to `"arrow_schema"`, this can be -/// overridden with [`arrow_to_parquet_schema_with_root`] +/// overridden with [`ArrowSchemaConverter`] +#[deprecated(since = "54.0.0", note = "Use `ArrowSchemaConverter` instead")] pub fn arrow_to_parquet_schema(schema: &Schema) -> Result { - arrow_to_parquet_schema_with_root(schema, "arrow_schema") -} - -/// Convert arrow 
schema to parquet schema specifying the name of the root schema element -pub fn arrow_to_parquet_schema_with_root(schema: &Schema, root: &str) -> Result { - let fields = schema - .fields() - .iter() - .map(|field| arrow_to_parquet_type(field).map(Arc::new)) - .collect::>()?; - let group = Type::group_type_builder(root).with_fields(fields).build()?; - Ok(SchemaDescriptor::new(Arc::new(group))) + ArrowSchemaConverter::new().convert(schema) } fn parse_key_value_metadata( @@ -298,7 +404,12 @@ pub fn decimal_length_from_precision(precision: u8) -> usize { } /// Convert an arrow field to a parquet `Type` -fn arrow_to_parquet_type(field: &Field) -> Result { +fn arrow_to_parquet_type(field: &Field, coerce_types: bool) -> Result { + const PARQUET_LIST_ELEMENT_NAME: &str = "element"; + const PARQUET_MAP_STRUCT_NAME: &str = "key_value"; + const PARQUET_KEY_FIELD_NAME: &str = "key"; + const PARQUET_VALUE_FIELD_NAME: &str = "value"; + let name = field.name().as_str(); let repetition = if field.is_nullable() { Repetition::OPTIONAL @@ -415,12 +526,20 @@ fn arrow_to_parquet_type(field: &Field) -> Result { .with_repetition(repetition) .with_id(id) .build(), - // date64 is cast to date32 (#1666) - DataType::Date64 => Type::primitive_type_builder(name, PhysicalType::INT32) - .with_logical_type(Some(LogicalType::Date)) - .with_repetition(repetition) - .with_id(id) - .build(), + DataType::Date64 => { + if coerce_types { + Type::primitive_type_builder(name, PhysicalType::INT32) + .with_logical_type(Some(LogicalType::Date)) + .with_repetition(repetition) + .with_id(id) + .build() + } else { + Type::primitive_type_builder(name, PhysicalType::INT64) + .with_repetition(repetition) + .with_id(id) + .build() + } + } DataType::Time32(TimeUnit::Second) => { // Cannot represent seconds in LogicalType Type::primitive_type_builder(name, PhysicalType::INT32) @@ -515,10 +634,18 @@ fn arrow_to_parquet_type(field: &Field) -> Result { .with_id(id) .build(), DataType::List(f) | 
DataType::FixedSizeList(f, _) | DataType::LargeList(f) => { + let field_ref = if coerce_types && f.name() != PARQUET_LIST_ELEMENT_NAME { + // Ensure proper naming per the Parquet specification + let ff = f.as_ref().clone().with_name(PARQUET_LIST_ELEMENT_NAME); + Arc::new(arrow_to_parquet_type(&ff, coerce_types)?) + } else { + Arc::new(arrow_to_parquet_type(f, coerce_types)?) + }; + Type::group_type_builder(name) .with_fields(vec![Arc::new( Type::group_type_builder("list") - .with_fields(vec![Arc::new(arrow_to_parquet_type(f)?)]) + .with_fields(vec![field_ref]) .with_repetition(Repetition::REPEATED) .build()?, )]) @@ -537,7 +664,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result { // recursively convert children to types/nodes let fields = fields .iter() - .map(|f| arrow_to_parquet_type(f).map(Arc::new)) + .map(|f| arrow_to_parquet_type(f, coerce_types).map(Arc::new)) .collect::>()?; Type::group_type_builder(name) .with_fields(fields) @@ -547,13 +674,29 @@ fn arrow_to_parquet_type(field: &Field) -> Result { } DataType::Map(field, _) => { if let DataType::Struct(struct_fields) = field.data_type() { + // If coercing then set inner struct name to "key_value" + let map_struct_name = if coerce_types { + PARQUET_MAP_STRUCT_NAME + } else { + field.name() + }; + + // If coercing then ensure struct fields are named "key" and "value" + let fix_map_field = |name: &str, fld: &Arc| -> Result> { + if coerce_types && fld.name() != name { + let f = fld.as_ref().clone().with_name(name); + Ok(Arc::new(arrow_to_parquet_type(&f, coerce_types)?)) + } else { + Ok(Arc::new(arrow_to_parquet_type(fld, coerce_types)?)) + } + }; + let key_field = fix_map_field(PARQUET_KEY_FIELD_NAME, &struct_fields[0])?; + let val_field = fix_map_field(PARQUET_VALUE_FIELD_NAME, &struct_fields[1])?; + Type::group_type_builder(name) .with_fields(vec![Arc::new( - Type::group_type_builder(field.name()) - .with_fields(vec![ - Arc::new(arrow_to_parquet_type(&struct_fields[0])?), - 
Arc::new(arrow_to_parquet_type(&struct_fields[1])?), - ]) + Type::group_type_builder(map_struct_name) + .with_fields(vec![key_field, val_field]) .with_repetition(Repetition::REPEATED) .build()?, )]) @@ -571,7 +714,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result { DataType::Dictionary(_, ref value) => { // Dictionary encoding not handled at the schema level let dict_field = field.clone().with_data_type(value.as_ref().clone()); - arrow_to_parquet_type(&dict_field) + arrow_to_parquet_type(&dict_field, coerce_types) } DataType::RunEndEncoded(_, _) => Err(arrow_err!( "Converting RunEndEncodedType to parquet not supported", @@ -1256,6 +1399,17 @@ mod tests { for i in 0..arrow_fields.len() { assert_eq!(&arrow_fields[i], converted_fields[i].as_ref()); } + + let mask = + ProjectionMask::columns(&parquet_schema, ["group2.leaf4", "group1.leaf1", "leaf5"]); + let converted_arrow_schema = + parquet_to_arrow_schema_by_columns(&parquet_schema, mask, None).unwrap(); + let converted_fields = converted_arrow_schema.fields(); + + assert_eq!(arrow_fields.len(), converted_fields.len()); + for i in 0..arrow_fields.len() { + assert_eq!(&arrow_fields[i], converted_fields[i].as_ref()); + } } #[test] @@ -1408,6 +1562,81 @@ mod tests { assert_eq!(arrow_fields, converted_arrow_fields); } + #[test] + fn test_coerced_map_list() { + // Create Arrow schema with non-Parquet naming + let arrow_fields = vec![ + Field::new_list( + "my_list", + Field::new("item", DataType::Boolean, true), + false, + ), + Field::new_map( + "my_map", + "entries", + Field::new("keys", DataType::Utf8, false), + Field::new("values", DataType::Int32, true), + false, + true, + ), + ]; + let arrow_schema = Schema::new(arrow_fields); + + // Create Parquet schema with coerced names + let message_type = " + message parquet_schema { + REQUIRED GROUP my_list (LIST) { + REPEATED GROUP list { + OPTIONAL BOOLEAN element; + } + } + OPTIONAL GROUP my_map (MAP) { + REPEATED GROUP key_value { + REQUIRED BINARY key (STRING); + 
OPTIONAL INT32 value; + } + } + } + "; + let parquet_group_type = parse_message_type(message_type).unwrap(); + let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); + let converted_arrow_schema = ArrowSchemaConverter::new() + .with_coerce_types(true) + .convert(&arrow_schema) + .unwrap(); + assert_eq!( + parquet_schema.columns().len(), + converted_arrow_schema.columns().len() + ); + + // Create Parquet schema without coerced names + let message_type = " + message parquet_schema { + REQUIRED GROUP my_list (LIST) { + REPEATED GROUP list { + OPTIONAL BOOLEAN item; + } + } + OPTIONAL GROUP my_map (MAP) { + REPEATED GROUP entries { + REQUIRED BINARY keys (STRING); + OPTIONAL INT32 values; + } + } + } + "; + let parquet_group_type = parse_message_type(message_type).unwrap(); + let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); + let converted_arrow_schema = ArrowSchemaConverter::new() + .with_coerce_types(false) + .convert(&arrow_schema) + .unwrap(); + assert_eq!( + parquet_schema.columns().len(), + converted_arrow_schema.columns().len() + ); + } + #[test] fn test_field_to_column_desc() { let message_type = " @@ -1557,7 +1786,7 @@ mod tests { Field::new("decimal256", DataType::Decimal256(39, 2), false), ]; let arrow_schema = Schema::new(arrow_fields); - let converted_arrow_schema = arrow_to_parquet_schema(&arrow_schema).unwrap(); + let converted_arrow_schema = ArrowSchemaConverter::new().convert(&arrow_schema).unwrap(); assert_eq!( parquet_schema.columns().len(), @@ -1594,9 +1823,10 @@ mod tests { false, )]; let arrow_schema = Schema::new(arrow_fields); - let converted_arrow_schema = arrow_to_parquet_schema(&arrow_schema); + let converted_arrow_schema = ArrowSchemaConverter::new() + .with_coerce_types(true) + .convert(&arrow_schema); - assert!(converted_arrow_schema.is_err()); converted_arrow_schema.unwrap(); } @@ -1665,7 +1895,7 @@ mod tests { Field::new("c20", DataType::Interval(IntervalUnit::YearMonth), false), 
Field::new_list( "c21", - Field::new("item", DataType::Boolean, true) + Field::new_list_field(DataType::Boolean, true) .with_metadata(meta(&[("Key", "Bar"), (PARQUET_FIELD_ID_META_KEY, "5")])), false, ) @@ -1673,7 +1903,7 @@ mod tests { Field::new( "c22", DataType::FixedSizeList( - Arc::new(Field::new("item", DataType::Boolean, true)), + Arc::new(Field::new_list_field(DataType::Boolean, true)), 5, ), false, @@ -1682,8 +1912,7 @@ mod tests { "c23", Field::new_large_list( "inner", - Field::new( - "item", + Field::new_list_field( DataType::Struct( vec![ Field::new("a", DataType::Int16, true), @@ -1714,6 +1943,7 @@ mod tests { // Field::new("c28", DataType::Duration(TimeUnit::Millisecond), false), // Field::new("c29", DataType::Duration(TimeUnit::Microsecond), false), // Field::new("c30", DataType::Duration(TimeUnit::Nanosecond), false), + #[allow(deprecated)] Field::new_dict( "c31", DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), @@ -1728,8 +1958,7 @@ mod tests { "c34", Field::new_list( "inner", - Field::new( - "item", + Field::new_list_field( DataType::Struct( vec![ Field::new("a", DataType::Int16, true), @@ -1762,7 +1991,7 @@ mod tests { .with_metadata(meta(&[(PARQUET_FIELD_ID_META_KEY, "8")])), Field::new_list( "my_value", - Field::new("item", DataType::Utf8, true) + Field::new_list_field(DataType::Utf8, true) .with_metadata(meta(&[(PARQUET_FIELD_ID_META_KEY, "10")])), true, ) @@ -1777,7 +2006,7 @@ mod tests { Field::new("my_key", DataType::Utf8, false), Field::new_list( "my_value", - Field::new("item", DataType::Utf8, true) + Field::new_list_field(DataType::Utf8, true) .with_metadata(meta(&[(PARQUET_FIELD_ID_META_KEY, "11")])), true, ), @@ -1868,7 +2097,9 @@ mod tests { // don't pass metadata so field ids are read from Parquet and not from serialized Arrow schema let arrow_schema = crate::arrow::parquet_to_arrow_schema(&schema_descriptor, None)?; - let parq_schema_descr = crate::arrow::arrow_to_parquet_schema(&arrow_schema)?; + let 
parq_schema_descr = ArrowSchemaConverter::new() + .with_coerce_types(true) + .convert(&arrow_schema)?; let parq_fields = parq_schema_descr.root_schema().get_fields(); assert_eq!(parq_fields.len(), 2); assert_eq!(parq_fields[0].get_basic_info().id(), 1); diff --git a/parquet/src/basic.rs b/parquet/src/basic.rs index 1926b87623bf..99f122fe4c3e 100644 --- a/parquet/src/basic.rs +++ b/parquet/src/basic.rs @@ -302,6 +302,7 @@ pub enum Encoding { /// /// The RLE/bit-packing hybrid is more cpu and memory efficient and should be used instead. #[deprecated( + since = "51.0.0", note = "Please see documentation for compatibility issues and use the RLE/bit-packing hybrid encoding instead" )] BIT_PACKED, @@ -425,14 +426,19 @@ fn split_compression_string(str_setting: &str) -> Result<(&str, Option), Pa fn check_level_is_none(level: &Option) -> Result<(), ParquetError> { if level.is_some() { - return Err(ParquetError::General("level is not support".to_string())); + return Err(ParquetError::General( + "compression level is not supported".to_string(), + )); } Ok(()) } fn require_level(codec: &str, level: Option) -> Result { - level.ok_or(ParquetError::General(format!("{} require level", codec))) + level.ok_or(ParquetError::General(format!( + "{} requires a compression level", + codec + ))) } impl FromStr for Compression { diff --git a/parquet/src/bin/parquet-rewrite.rs b/parquet/src/bin/parquet-rewrite.rs index ad0f7ae0df7d..5a1ec94d5502 100644 --- a/parquet/src/bin/parquet-rewrite.rs +++ b/parquet/src/bin/parquet-rewrite.rs @@ -199,6 +199,10 @@ struct Args { /// Sets writer version. 
#[clap(long)] writer_version: Option, + + /// Sets whether to coerce Arrow types to match Parquet specification + #[clap(long)] + coerce_types: Option, } fn main() { @@ -238,6 +242,7 @@ fn main() { if let Some(value) = args.dictionary_page_size_limit { writer_properties_builder = writer_properties_builder.set_dictionary_page_size_limit(value); } + #[allow(deprecated)] if let Some(value) = args.max_statistics_size { writer_properties_builder = writer_properties_builder.set_max_statistics_size(value); } @@ -262,6 +267,9 @@ fn main() { if let Some(value) = args.writer_version { writer_properties_builder = writer_properties_builder.set_writer_version(value.into()); } + if let Some(value) = args.coerce_types { + writer_properties_builder = writer_properties_builder.set_coerce_types(value); + } let writer_properties = writer_properties_builder.build(); let mut parquet_writer = ArrowWriter::try_new( File::create(&args.output).expect("Unable to open output file"), diff --git a/parquet/src/column/reader.rs b/parquet/src/column/reader.rs index 2b43b4c3e45c..953dc057d7a3 100644 --- a/parquet/src/column/reader.rs +++ b/parquet/src/column/reader.rs @@ -185,31 +185,6 @@ where } } - /// Reads a batch of values of at most `batch_size`, returning a tuple containing the - /// actual number of non-null values read, followed by the corresponding number of levels, - /// i.e, the total number of values including nulls, empty lists, etc... - /// - /// If the max definition level is 0, `def_levels` will be ignored, otherwise it will be - /// populated with the number of levels read, with an error returned if it is `None`. - /// - /// If the max repetition level is 0, `rep_levels` will be ignored, otherwise it will be - /// populated with the number of levels read, with an error returned if it is `None`. - /// - /// `values` will be contiguously populated with the non-null values. 
Note that if the column - /// is not required, this may be less than either `batch_size` or the number of levels read - #[deprecated(note = "Use read_records")] - pub fn read_batch( - &mut self, - batch_size: usize, - def_levels: Option<&mut D::Buffer>, - rep_levels: Option<&mut R::Buffer>, - values: &mut V::Buffer, - ) -> Result<(usize, usize)> { - let (_, values, levels) = self.read_records(batch_size, def_levels, rep_levels, values)?; - - Ok((values, levels)) - } - /// Read up to `max_records` whole records, returning the number of complete /// records, non-null values and levels decoded. All levels for a given record /// will be read, i.e. the next repetition level, if any, will be 0 diff --git a/parquet/src/column/writer/mod.rs b/parquet/src/column/writer/mod.rs index 9bd79840f760..5f34f34cbb7a 100644 --- a/parquet/src/column/writer/mod.rs +++ b/parquet/src/column/writer/mod.rs @@ -347,7 +347,7 @@ pub struct GenericColumnWriter<'a, E: ColumnValueEncoder> { data_pages: VecDeque, // column index and offset index column_index_builder: ColumnIndexBuilder, - offset_index_builder: OffsetIndexBuilder, + offset_index_builder: Option, // Below fields used to incrementally check boundary order across data pages. // We assume they are ascending/descending until proven wrong. @@ -394,6 +394,12 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> { column_index_builder.to_invalid() } + // Disable offset_index_builder if requested by user. 
+ let offset_index_builder = match props.offset_index_disabled() { + false => Some(OffsetIndexBuilder::new()), + _ => None, + }; + Self { descr, props, @@ -408,7 +414,7 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> { page_metrics, column_metrics, column_index_builder, - offset_index_builder: OffsetIndexBuilder::new(), + offset_index_builder, encodings, data_page_boundary_ascending: true, data_page_boundary_descending: true, @@ -568,7 +574,11 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> { /// anticipated encoded size. #[cfg(feature = "arrow")] pub(crate) fn get_estimated_total_bytes(&self) -> u64 { - self.column_metrics.total_bytes_written + self.data_pages + .iter() + .map(|page| page.data().len() as u64) + .sum::() + + self.column_metrics.total_bytes_written + self.encoder.estimated_data_page_size() as u64 + self.encoder.estimated_dict_page_size().unwrap_or_default() as u64 } @@ -613,7 +623,8 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> { .column_index_builder .valid() .then(|| self.column_index_builder.build_to_thrift()); - let offset_index = Some(self.offset_index_builder.build_to_thrift()); + + let offset_index = self.offset_index_builder.map(|b| b.build_to_thrift()); Ok(ColumnCloseResult { bytes_written: self.column_metrics.total_bytes_written, @@ -841,11 +852,10 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> { ); // Update the offset index - self.offset_index_builder - .append_row_count(self.page_metrics.num_buffered_rows as i64); - - self.offset_index_builder - .append_unencoded_byte_array_data_bytes(page_variable_length_bytes); + if let Some(builder) = self.offset_index_builder.as_mut() { + builder.append_row_count(self.page_metrics.num_buffered_rows as i64); + builder.append_unencoded_byte_array_data_bytes(page_variable_length_bytes); + } } /// Determine if we should allow truncating min/max values for this column's statistics @@ -868,24 +878,67 @@ impl<'a, E: ColumnValueEncoder> 
GenericColumnWriter<'a, E> { } } + /// Returns `true` if this column's logical type is a UTF-8 string. + fn is_utf8(&self) -> bool { + self.get_descriptor().logical_type() == Some(LogicalType::String) + || self.get_descriptor().converted_type() == ConvertedType::UTF8 + } + + /// Truncates a binary statistic to at most `truncation_length` bytes. + /// + /// If truncation is not possible, returns `data`. + /// + /// The `bool` in the returned tuple indicates whether truncation occurred or not. + /// + /// UTF-8 Note: + /// If the column type indicates UTF-8, and `data` contains valid UTF-8, then the result will + /// also remain valid UTF-8, but may be less tnan `truncation_length` bytes to avoid splitting + /// on non-character boundaries. fn truncate_min_value(&self, truncation_length: Option, data: &[u8]) -> (Vec, bool) { truncation_length .filter(|l| data.len() > *l) - .and_then(|l| match str::from_utf8(data) { - Ok(str_data) => truncate_utf8(str_data, l), - Err(_) => Some(data[..l].to_vec()), - }) + .and_then(|l| + // don't do extra work if this column isn't UTF-8 + if self.is_utf8() { + match str::from_utf8(data) { + Ok(str_data) => truncate_utf8(str_data, l), + Err(_) => Some(data[..l].to_vec()), + } + } else { + Some(data[..l].to_vec()) + } + ) .map(|truncated| (truncated, true)) .unwrap_or_else(|| (data.to_vec(), false)) } + /// Truncates a binary statistic to at most `truncation_length` bytes, and then increment the + /// final byte(s) to yield a valid upper bound. This may result in a result of less than + /// `truncation_length` bytes if the last byte(s) overflows. + /// + /// If truncation is not possible, returns `data`. + /// + /// The `bool` in the returned tuple indicates whether truncation occurred or not. + /// + /// UTF-8 Note: + /// If the column type indicates UTF-8, and `data` contains valid UTF-8, then the result will + /// also remain valid UTF-8 (but again may be less than `truncation_length` bytes). 
If `data` + /// does not contain valid UTF-8, then truncation will occur as if the column is non-string + /// binary. fn truncate_max_value(&self, truncation_length: Option, data: &[u8]) -> (Vec, bool) { truncation_length .filter(|l| data.len() > *l) - .and_then(|l| match str::from_utf8(data) { - Ok(str_data) => truncate_utf8(str_data, l).and_then(increment_utf8), - Err(_) => increment(data[..l].to_vec()), - }) + .and_then(|l| + // don't do extra work if this column isn't UTF-8 + if self.is_utf8() { + match str::from_utf8(data) { + Ok(str_data) => truncate_and_increment_utf8(str_data, l), + Err(_) => increment(data[..l].to_vec()), + } + } else { + increment(data[..l].to_vec()) + } + ) .map(|truncated| (truncated, true)) .unwrap_or_else(|| (data.to_vec(), false)) } @@ -1174,8 +1227,10 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> { let page_spec = self.page_writer.write_page(page)?; // update offset index // compressed_size = header_size + compressed_data_size - self.offset_index_builder - .append_offset_and_size(page_spec.offset as i64, page_spec.compressed_size as i32); + if let Some(builder) = self.offset_index_builder.as_mut() { + builder + .append_offset_and_size(page_spec.offset as i64, page_spec.compressed_size as i32) + } self.update_metrics_for_page(page_spec); Ok(()) } @@ -1406,13 +1461,50 @@ fn compare_greater_byte_array_decimals(a: &[u8], b: &[u8]) -> bool { (a[1..]) > (b[1..]) } -/// Truncate a UTF8 slice to the longest prefix that is still a valid UTF8 string, -/// while being less than `length` bytes and non-empty +/// Truncate a UTF-8 slice to the longest prefix that is still a valid UTF-8 string, +/// while being less than `length` bytes and non-empty. Returns `None` if truncation +/// is not possible within those constraints. +/// +/// The caller guarantees that data.len() > length. 
fn truncate_utf8(data: &str, length: usize) -> Option> { let split = (1..=length).rfind(|x| data.is_char_boundary(*x))?; Some(data.as_bytes()[..split].to_vec()) } +/// Truncate a UTF-8 slice and increment it's final character. The returned value is the +/// longest such slice that is still a valid UTF-8 string while being less than `length` +/// bytes and non-empty. Returns `None` if no such transformation is possible. +/// +/// The caller guarantees that data.len() > length. +fn truncate_and_increment_utf8(data: &str, length: usize) -> Option> { + // UTF-8 is max 4 bytes, so start search 3 back from desired length + let lower_bound = length.saturating_sub(3); + let split = (lower_bound..=length).rfind(|x| data.is_char_boundary(*x))?; + increment_utf8(data.get(..split)?) +} + +/// Increment the final character in a UTF-8 string in such a way that the returned result +/// is still a valid UTF-8 string. The returned string may be shorter than the input if the +/// last character(s) cannot be incremented (due to overflow or producing invalid code points). +/// Returns `None` if the string cannot be incremented. +/// +/// Note that this implementation will not promote an N-byte code point to (N+1) bytes. +fn increment_utf8(data: &str) -> Option> { + for (idx, original_char) in data.char_indices().rev() { + let original_len = original_char.len_utf8(); + if let Some(next_char) = char::from_u32(original_char as u32 + 1) { + // do not allow increasing byte width of incremented char + if next_char.len_utf8() == original_len { + let mut result = data.as_bytes()[..idx + original_len].to_vec(); + next_char.encode_utf8(&mut result[idx..]); + return Some(result); + } + } + } + + None +} + /// Try and increment the bytes from right to left. /// /// Returns `None` if all bytes are set to `u8::MAX`. 
@@ -1429,29 +1521,15 @@ fn increment(mut data: Vec) -> Option> { None } -/// Try and increment the the string's bytes from right to left, returning when the result -/// is a valid UTF8 string. Returns `None` when it can't increment any byte. -fn increment_utf8(mut data: Vec) -> Option> { - for idx in (0..data.len()).rev() { - let original = data[idx]; - let (byte, overflow) = original.overflowing_add(1); - if !overflow { - data[idx] = byte; - if str::from_utf8(&data).is_ok() { - return Some(data); - } - data[idx] = original; - } - } - - None -} - #[cfg(test)] mod tests { - use crate::file::properties::DEFAULT_COLUMN_INDEX_TRUNCATE_LENGTH; + use crate::{ + file::{properties::DEFAULT_COLUMN_INDEX_TRUNCATE_LENGTH, writer::SerializedFileWriter}, + schema::parser::parse_message_type, + }; + use core::str; use rand::distributions::uniform::SampleUniform; - use std::sync::Arc; + use std::{fs::File, sync::Arc}; use crate::column::{ page::PageReader, @@ -3128,49 +3206,78 @@ mod tests { #[test] fn test_increment_utf8() { + let test_inc = |o: &str, expected: &str| { + if let Ok(v) = String::from_utf8(increment_utf8(o).unwrap()) { + // Got the expected result... 
+ assert_eq!(v, expected); + // and it's greater than the original string + assert!(*v > *o); + // Also show that BinaryArray level comparison works here + let mut greater = ByteArray::new(); + greater.set_data(Bytes::from(v)); + let mut original = ByteArray::new(); + original.set_data(Bytes::from(o.as_bytes().to_vec())); + assert!(greater > original); + } else { + panic!("Expected incremented UTF8 string to also be valid."); + } + }; + // Basic ASCII case - let v = increment_utf8("hello".as_bytes().to_vec()).unwrap(); - assert_eq!(&v, "hellp".as_bytes()); + test_inc("hello", "hellp"); - // Also show that BinaryArray level comparison works here - let mut greater = ByteArray::new(); - greater.set_data(Bytes::from(v)); - let mut original = ByteArray::new(); - original.set_data(Bytes::from("hello".as_bytes().to_vec())); - assert!(greater > original); + // 1-byte ending in max 1-byte + test_inc("a\u{7f}", "b"); + + // 1-byte max should not truncate as it would need 2-byte code points + assert!(increment_utf8("\u{7f}\u{7f}").is_none()); // UTF8 string - let s = "❤️🧡💛💚💙💜"; - let v = increment_utf8(s.as_bytes().to_vec()).unwrap(); + test_inc("❤️🧡💛💚💙💜", "❤️🧡💛💚💙💝"); - if let Ok(new) = String::from_utf8(v) { - assert_ne!(&new, s); - assert_eq!(new, "❤️🧡💛💚💙💝"); - assert!(new.as_bytes().last().unwrap() > s.as_bytes().last().unwrap()); - } else { - panic!("Expected incremented UTF8 string to also be valid.") - } + // 2-byte without overflow + test_inc("éééé", "éééê"); - // Max UTF8 character - should be a No-Op - let s = char::MAX.to_string(); - assert_eq!(s.len(), 4); - let v = increment_utf8(s.as_bytes().to_vec()); - assert!(v.is_none()); + // 2-byte that overflows lowest byte + test_inc("\u{ff}\u{ff}", "\u{ff}\u{100}"); + + // 2-byte ending in max 2-byte + test_inc("a\u{7ff}", "b"); + + // Max 2-byte should not truncate as it would need 3-byte code points + assert!(increment_utf8("\u{7ff}\u{7ff}").is_none()); + + // 3-byte without overflow [U+800, U+800] -> [U+800, U+801] 
(note that these + // characters should render right to left). + test_inc("ࠀࠀ", "ࠀࠁ"); + + // 3-byte ending in max 3-byte + test_inc("a\u{ffff}", "b"); + + // Max 3-byte should not truncate as it would need 4-byte code points + assert!(increment_utf8("\u{ffff}\u{ffff}").is_none()); + + // 4-byte without overflow + test_inc("𐀀𐀀", "𐀀𐀁"); + + // 4-byte ending in max unicode + test_inc("a\u{10ffff}", "b"); - // Handle multi-byte UTF8 characters - let s = "a\u{10ffff}"; - let v = increment_utf8(s.as_bytes().to_vec()); - assert_eq!(&v.unwrap(), "b\u{10ffff}".as_bytes()); + // Max 4-byte should not truncate + assert!(increment_utf8("\u{10ffff}\u{10ffff}").is_none()); + + // Skip over surrogate pair range (0xD800..=0xDFFF) + //test_inc("a\u{D7FF}", "a\u{e000}"); + test_inc("a\u{D7FF}", "b"); } #[test] fn test_truncate_utf8() { // No-op let data = "❤️🧡💛💚💙💜"; - let r = truncate_utf8(data, data.as_bytes().len()).unwrap(); - assert_eq!(r.len(), data.as_bytes().len()); + let r = truncate_utf8(data, data.len()).unwrap(); + assert_eq!(r.len(), data.len()); assert_eq!(&r, data.as_bytes()); - println!("len is {}", data.len()); // We slice it away from the UTF8 boundary let r = truncate_utf8(data, 13).unwrap(); @@ -3180,6 +3287,90 @@ mod tests { // One multi-byte code point, and a length shorter than it, so we can't slice it let r = truncate_utf8("\u{0836}", 1); assert!(r.is_none()); + + // Test truncate and increment for max bounds on UTF-8 statistics + // 7-bit (i.e. 
ASCII) + let r = truncate_and_increment_utf8("yyyyyyyyy", 8).unwrap(); + assert_eq!(&r, "yyyyyyyz".as_bytes()); + + // 2-byte without overflow + let r = truncate_and_increment_utf8("ééééé", 7).unwrap(); + assert_eq!(&r, "ééê".as_bytes()); + + // 2-byte that overflows lowest byte + let r = truncate_and_increment_utf8("\u{ff}\u{ff}\u{ff}\u{ff}\u{ff}", 8).unwrap(); + assert_eq!(&r, "\u{ff}\u{ff}\u{ff}\u{100}".as_bytes()); + + // max 2-byte should not truncate as it would need 3-byte code points + let r = truncate_and_increment_utf8("߿߿߿߿߿", 8); + assert!(r.is_none()); + + // 3-byte without overflow [U+800, U+800, U+800] -> [U+800, U+801] (note that these + // characters should render right to left). + let r = truncate_and_increment_utf8("ࠀࠀࠀࠀ", 8).unwrap(); + assert_eq!(&r, "ࠀࠁ".as_bytes()); + + // max 3-byte should not truncate as it would need 4-byte code points + let r = truncate_and_increment_utf8("\u{ffff}\u{ffff}\u{ffff}", 8); + assert!(r.is_none()); + + // 4-byte without overflow + let r = truncate_and_increment_utf8("𐀀𐀀𐀀𐀀", 9).unwrap(); + assert_eq!(&r, "𐀀𐀁".as_bytes()); + + // max 4-byte should not truncate + let r = truncate_and_increment_utf8("\u{10ffff}\u{10ffff}", 8); + assert!(r.is_none()); + } + + #[test] + // Check fallback truncation of statistics that should be UTF-8, but aren't + // (see https://github.com/apache/arrow-rs/pull/6870). 
+ fn test_byte_array_truncate_invalid_utf8_statistics() { + let message_type = " + message test_schema { + OPTIONAL BYTE_ARRAY a (UTF8); + } + "; + let schema = Arc::new(parse_message_type(message_type).unwrap()); + + // Create Vec containing non-UTF8 bytes + let data = vec![ByteArray::from(vec![128u8; 32]); 7]; + let def_levels = [1, 1, 1, 1, 0, 1, 0, 1, 0, 1]; + let file: File = tempfile::tempfile().unwrap(); + let props = Arc::new( + WriterProperties::builder() + .set_statistics_enabled(EnabledStatistics::Chunk) + .set_statistics_truncate_length(Some(8)) + .build(), + ); + + let mut writer = SerializedFileWriter::new(&file, schema, props).unwrap(); + let mut row_group_writer = writer.next_row_group().unwrap(); + + let mut col_writer = row_group_writer.next_column().unwrap().unwrap(); + col_writer + .typed::() + .write_batch(&data, Some(&def_levels), None) + .unwrap(); + col_writer.close().unwrap(); + row_group_writer.close().unwrap(); + let file_metadata = writer.close().unwrap(); + assert!(file_metadata.row_groups[0].columns[0].meta_data.is_some()); + let stats = file_metadata.row_groups[0].columns[0] + .meta_data + .as_ref() + .unwrap() + .statistics + .as_ref() + .unwrap(); + assert!(!stats.is_max_value_exact.unwrap()); + // Truncation of invalid UTF-8 should fall back to binary truncation, so last byte should + // be incremented by 1. 
+ assert_eq!( + stats.max_value, + Some([128, 128, 128, 128, 128, 128, 128, 129].to_vec()) + ); } #[test] @@ -3215,6 +3406,52 @@ mod tests { assert!(column_close_result.column_index.is_none()); } + #[test] + fn test_no_offset_index_when_disabled() { + // Test that offset indexes can be disabled + let descr = Arc::new(get_test_column_descr::(1, 0)); + let props = Arc::new( + WriterProperties::builder() + .set_statistics_enabled(EnabledStatistics::None) + .set_offset_index_disabled(true) + .build(), + ); + let column_writer = get_column_writer(descr, props, get_test_page_writer()); + let mut writer = get_typed_column_writer::(column_writer); + + let data = Vec::new(); + let def_levels = vec![0; 10]; + writer.write_batch(&data, Some(&def_levels), None).unwrap(); + writer.flush_data_pages().unwrap(); + + let column_close_result = writer.close().unwrap(); + assert!(column_close_result.offset_index.is_none()); + assert!(column_close_result.column_index.is_none()); + } + + #[test] + fn test_offset_index_overridden() { + // Test that offset indexes are not disabled when gathering page statistics + let descr = Arc::new(get_test_column_descr::(1, 0)); + let props = Arc::new( + WriterProperties::builder() + .set_statistics_enabled(EnabledStatistics::Page) + .set_offset_index_disabled(true) + .build(), + ); + let column_writer = get_column_writer(descr, props, get_test_page_writer()); + let mut writer = get_typed_column_writer::(column_writer); + + let data = Vec::new(); + let def_levels = vec![0; 10]; + writer.write_batch(&data, Some(&def_levels), None).unwrap(); + writer.flush_data_pages().unwrap(); + + let column_close_result = writer.close().unwrap(); + assert!(column_close_result.offset_index.is_some()); + assert!(column_close_result.column_index.is_some()); + } + #[test] fn test_boundary_order() -> Result<()> { let descr = Arc::new(get_test_column_descr::(1, 0)); @@ -3368,6 +3605,26 @@ mod tests { assert!(stats.max_bytes_opt().is_none()); } + #[test] + #[cfg(feature = 
"arrow")] + fn test_column_writer_get_estimated_total_bytes() { + let page_writer = get_test_page_writer(); + let props = Default::default(); + let mut writer = get_test_column_writer::(page_writer, 0, 0, props); + assert_eq!(writer.get_estimated_total_bytes(), 0); + + writer.write_batch(&[1, 2, 3, 4], None, None).unwrap(); + writer.add_data_page().unwrap(); + let size_with_one_page = writer.get_estimated_total_bytes(); + assert_eq!(size_with_one_page, 20); + + writer.write_batch(&[5, 6, 7, 8], None, None).unwrap(); + writer.add_data_page().unwrap(); + let size_with_two_pages = writer.get_estimated_total_bytes(); + // different pages have different compressed lengths + assert_eq!(size_with_two_pages, 20 + 21); + } + fn write_multiple_pages( column_descr: &Arc, pages: &[&[Option]], diff --git a/parquet/src/encodings/rle.rs b/parquet/src/encodings/rle.rs index 0c708c126503..d089ba7836e1 100644 --- a/parquet/src/encodings/rle.rs +++ b/parquet/src/encodings/rle.rs @@ -369,17 +369,17 @@ impl RleDecoder { } #[inline(never)] - pub fn get_batch(&mut self, buffer: &mut [T]) -> Result { + pub fn get_batch(&mut self, buffer: &mut [T]) -> Result { assert!(size_of::() <= 8); let mut values_read = 0; while values_read < buffer.len() { if self.rle_left > 0 { let num_values = cmp::min(buffer.len() - values_read, self.rle_left as usize); + let repeated_value = + T::try_from_le_slice(&self.current_value.as_mut().unwrap().to_ne_bytes())?; for i in 0..num_values { - let repeated_value = - T::try_from_le_slice(&self.current_value.as_mut().unwrap().to_ne_bytes())?; - buffer[values_read + i] = repeated_value; + buffer[values_read + i] = repeated_value.clone(); } self.rle_left -= num_values as u32; values_read += num_values; diff --git a/parquet/src/errors.rs b/parquet/src/errors.rs index 6adbffa2a2e5..d749287bba62 100644 --- a/parquet/src/errors.rs +++ b/parquet/src/errors.rs @@ -17,6 +17,7 @@ //! Common Parquet errors and macros. 
+use core::num::TryFromIntError; use std::error::Error; use std::{cell, io, result, str}; @@ -27,6 +28,7 @@ use arrow_schema::ArrowError; // Note: we don't implement PartialEq as the semantics for the // external variant are not well defined (#4469) #[derive(Debug)] +#[non_exhaustive] pub enum ParquetError { /// General Parquet error. /// Returned when code violates normal workflow of working with Parquet files. @@ -47,6 +49,9 @@ pub enum ParquetError { IndexOutOfBound(usize, usize), /// An external error variant External(Box), + /// Returned when a function needs more data to complete properly. The `usize` field indicates + /// the total number of bytes required, not the number of additional bytes. + NeedMoreData(usize), } impl std::fmt::Display for ParquetError { @@ -63,6 +68,7 @@ impl std::fmt::Display for ParquetError { write!(fmt, "Index {index} out of bound: {bound}") } ParquetError::External(e) => write!(fmt, "External: {e}"), + ParquetError::NeedMoreData(needed) => write!(fmt, "NeedMoreData: {needed}"), } } } @@ -76,6 +82,12 @@ impl Error for ParquetError { } } +impl From for ParquetError { + fn from(e: TryFromIntError) -> ParquetError { + ParquetError::General(format!("Integer overflow: {e}")) + } +} + impl From for ParquetError { fn from(e: io::Error) -> ParquetError { ParquetError::External(Box::new(e)) diff --git a/parquet/src/file/metadata/mod.rs b/parquet/src/file/metadata/mod.rs index 32b985710023..252cb99f3f36 100644 --- a/parquet/src/file/metadata/mod.rs +++ b/parquet/src/file/metadata/mod.rs @@ -190,7 +190,7 @@ impl ParquetMetaData { /// Creates Parquet metadata from file metadata, a list of row /// group metadata, and the column index structures. - #[deprecated(note = "Use ParquetMetaDataBuilder")] + #[deprecated(since = "53.1.0", note = "Use ParquetMetaDataBuilder")] pub fn new_with_page_index( file_metadata: FileMetaData, row_groups: Vec, @@ -230,12 +230,6 @@ impl ParquetMetaData { &self.row_groups } - /// Returns page indexes in this file. 
- #[deprecated(note = "Use Self::column_index")] - pub fn page_indexes(&self) -> Option<&ParquetColumnIndex> { - self.column_index.as_ref() - } - /// Returns the column index for this file if loaded /// /// Returns `None` if the parquet file does not have a `ColumnIndex` or @@ -246,12 +240,6 @@ impl ParquetMetaData { self.column_index.as_ref() } - /// Returns the offset index for this file if loaded - #[deprecated(note = "Use Self::offset_index")] - pub fn offset_indexes(&self) -> Option<&ParquetOffsetIndex> { - self.offset_index.as_ref() - } - /// Returns offset indexes in this file, if loaded /// /// Returns `None` if the parquet file does not have a `OffsetIndex` or diff --git a/parquet/src/file/metadata/reader.rs b/parquet/src/file/metadata/reader.rs index 2a927f15fb64..c6715a33b5ae 100644 --- a/parquet/src/file/metadata/reader.rs +++ b/parquet/src/file/metadata/reader.rs @@ -178,8 +178,10 @@ impl ParquetMetaDataReader { /// /// # Errors /// - /// This function will return [`ParquetError::IndexOutOfBound`] in the event `reader` does not - /// provide enough data to fully parse the metadata (see example below). + /// This function will return [`ParquetError::NeedMoreData`] in the event `reader` does not + /// provide enough data to fully parse the metadata (see example below). The returned error + /// will be populated with a `usize` field indicating the number of bytes required from the + /// tail of the file to completely parse the requested metadata. /// /// Other errors returned include [`ParquetError::General`] and [`ParquetError::EOF`]. 
/// @@ -192,11 +194,13 @@ impl ParquetMetaDataReader { /// # fn open_parquet_file(path: &str) -> std::fs::File { unimplemented!(); } /// let file = open_parquet_file("some_path.parquet"); /// let len = file.len() as usize; - /// let bytes = get_bytes(&file, 1000..len); + /// // Speculatively read 1 kilobyte from the end of the file + /// let bytes = get_bytes(&file, len - 1024..len); /// let mut reader = ParquetMetaDataReader::new().with_page_indexes(true); /// match reader.try_parse_sized(&bytes, len) { /// Ok(_) => (), - /// Err(ParquetError::IndexOutOfBound(needed, _)) => { + /// Err(ParquetError::NeedMoreData(needed)) => { + /// // Read the needed number of bytes from the end of the file /// let bytes = get_bytes(&file, len - needed..len); /// reader.try_parse_sized(&bytes, len).unwrap(); /// } @@ -204,15 +208,44 @@ impl ParquetMetaDataReader { /// } /// let metadata = reader.finish().unwrap(); /// ``` + /// + /// Note that it is possible for the file metadata to be completely read, but there are + /// insufficient bytes available to read the page indexes. [`Self::has_metadata()`] can be used + /// to test for this. In the event the file metadata is present, re-parsing of the file + /// metadata can be skipped by using [`Self::read_page_indexes_sized()`], as shown below. 
+ /// ```no_run + /// # use parquet::file::metadata::ParquetMetaDataReader; + /// # use parquet::errors::ParquetError; + /// # use crate::parquet::file::reader::Length; + /// # fn get_bytes(file: &std::fs::File, range: std::ops::Range) -> bytes::Bytes { unimplemented!(); } + /// # fn open_parquet_file(path: &str) -> std::fs::File { unimplemented!(); } + /// let file = open_parquet_file("some_path.parquet"); + /// let len = file.len() as usize; + /// // Speculatively read 1 kilobyte from the end of the file + /// let mut bytes = get_bytes(&file, len - 1024..len); + /// let mut reader = ParquetMetaDataReader::new().with_page_indexes(true); + /// // Loop until `bytes` is large enough + /// loop { + /// match reader.try_parse_sized(&bytes, len) { + /// Ok(_) => break, + /// Err(ParquetError::NeedMoreData(needed)) => { + /// // Read the needed number of bytes from the end of the file + /// bytes = get_bytes(&file, len - needed..len); + /// // If file metadata was read only read page indexes, otherwise continue loop + /// if reader.has_metadata() { + /// reader.read_page_indexes_sized(&bytes, len); + /// break; + /// } + /// } + /// _ => panic!("unexpected error") + /// } + /// } + /// let metadata = reader.finish().unwrap(); + /// ``` pub fn try_parse_sized(&mut self, reader: &R, file_size: usize) -> Result<()> { self.metadata = match self.parse_metadata(reader) { Ok(metadata) => Some(metadata), - // FIXME: throughout this module ParquetError::IndexOutOfBound is used to indicate the - // need for more data. This is not it's intended use. The plan is to add a NeedMoreData - // value to the enum, but this would be a breaking change. This will be done as - // 54.0.0 draws nearer. - // https://github.com/apache/arrow-rs/issues/6447 - Err(ParquetError::IndexOutOfBound(needed, _)) => { + Err(ParquetError::NeedMoreData(needed)) => { // If reader is the same length as `file_size` then presumably there is no more to // read, so return an EOF error. 
if file_size == reader.len() as usize || needed > file_size { @@ -223,7 +256,7 @@ impl ParquetMetaDataReader { )); } else { // Ask for a larger buffer - return Err(ParquetError::IndexOutOfBound(needed, file_size)); + return Err(ParquetError::NeedMoreData(needed)); } } Err(e) => return Err(e), @@ -246,7 +279,8 @@ impl ParquetMetaDataReader { /// Read the page index structures when a [`ParquetMetaData`] has already been obtained. /// This variant is used when `reader` cannot access the entire Parquet file (e.g. it is /// a [`Bytes`] struct containing the tail of the file). - /// See [`Self::new_with_metadata()`] and [`Self::has_metadata()`]. + /// See [`Self::new_with_metadata()`] and [`Self::has_metadata()`]. Like + /// [`Self::try_parse_sized()`] this function may return [`ParquetError::NeedMoreData`]. pub fn read_page_indexes_sized( &mut self, reader: &R, @@ -269,7 +303,6 @@ impl ParquetMetaDataReader { // Get bounds needed for page indexes (if any are present in the file). let Some(range) = self.range_for_page_index() else { - self.empty_page_indexes(); return Ok(()); }; @@ -285,10 +318,7 @@ impl ParquetMetaDataReader { )); } else { // Ask for a larger buffer - return Err(ParquetError::IndexOutOfBound( - file_size - range.start, - file_size, - )); + return Err(ParquetError::NeedMoreData(file_size - range.start)); } } @@ -446,20 +476,6 @@ impl ParquetMetaDataReader { Ok(()) } - /// Set the column_index and offset_indexes to empty `Vec` for backwards compatibility - /// - /// See for details - fn empty_page_indexes(&mut self) { - let metadata = self.metadata.as_mut().unwrap(); - let num_row_groups = metadata.num_row_groups(); - if self.column_index { - metadata.set_column_index(Some(vec![vec![]; num_row_groups])); - } - if self.offset_index { - metadata.set_offset_index(Some(vec![vec![]; num_row_groups])); - } - } - fn range_for_page_index(&self) -> Option> { // sanity check self.metadata.as_ref()?; @@ -484,10 +500,7 @@ impl ParquetMetaDataReader { // check file is 
large enough to hold footer let file_size = chunk_reader.len(); if file_size < (FOOTER_SIZE as u64) { - return Err(ParquetError::IndexOutOfBound( - FOOTER_SIZE, - file_size as usize, - )); + return Err(ParquetError::NeedMoreData(FOOTER_SIZE)); } let mut footer = [0_u8; 8]; @@ -500,10 +513,7 @@ impl ParquetMetaDataReader { self.metadata_size = Some(footer_metadata_len); if footer_metadata_len > file_size as usize { - return Err(ParquetError::IndexOutOfBound( - footer_metadata_len, - file_size as usize, - )); + return Err(ParquetError::NeedMoreData(footer_metadata_len)); } let start = file_size - footer_metadata_len as u64; @@ -617,7 +627,8 @@ impl ParquetMetaDataReader { for rg in t_file_metadata.row_groups { row_groups.push(RowGroupMetaData::from_thrift(schema_descr.clone(), rg)?); } - let column_orders = Self::parse_column_orders(t_file_metadata.column_orders, &schema_descr); + let column_orders = + Self::parse_column_orders(t_file_metadata.column_orders, &schema_descr)?; let file_metadata = FileMetaData::new( t_file_metadata.version, @@ -635,15 +646,13 @@ impl ParquetMetaDataReader { fn parse_column_orders( t_column_orders: Option>, schema_descr: &SchemaDescriptor, - ) -> Option> { + ) -> Result>> { match t_column_orders { Some(orders) => { // Should always be the case - assert_eq!( - orders.len(), - schema_descr.num_columns(), - "Column order length mismatch" - ); + if orders.len() != schema_descr.num_columns() { + return Err(general_err!("Column order length mismatch")); + }; let mut res = Vec::new(); for (i, column) in schema_descr.columns().iter().enumerate() { match orders[i] { @@ -657,9 +666,9 @@ impl ParquetMetaDataReader { } } } - Some(res) + Ok(Some(res)) } - None => None, + None => Ok(None), } } } @@ -682,7 +691,7 @@ mod tests { let err = ParquetMetaDataReader::new() .parse_metadata(&test_file) .unwrap_err(); - assert!(matches!(err, ParquetError::IndexOutOfBound(8, _))); + assert!(matches!(err, ParquetError::NeedMoreData(8))); } #[test] @@ -701,7 +710,7 
@@ mod tests { let err = ParquetMetaDataReader::new() .parse_metadata(&test_file) .unwrap_err(); - assert!(matches!(err, ParquetError::IndexOutOfBound(263, _))); + assert!(matches!(err, ParquetError::NeedMoreData(263))); } #[test] @@ -731,7 +740,7 @@ mod tests { ]); assert_eq!( - ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr), + ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr).unwrap(), Some(vec![ ColumnOrder::TYPE_DEFINED_ORDER(SortOrder::SIGNED), ColumnOrder::TYPE_DEFINED_ORDER(SortOrder::SIGNED) @@ -740,20 +749,21 @@ mod tests { // Test when no column orders are defined. assert_eq!( - ParquetMetaDataReader::parse_column_orders(None, &schema_descr), + ParquetMetaDataReader::parse_column_orders(None, &schema_descr).unwrap(), None ); } #[test] - #[should_panic(expected = "Column order length mismatch")] fn test_metadata_column_orders_len_mismatch() { let schema = SchemaType::group_type_builder("schema").build().unwrap(); let schema_descr = SchemaDescriptor::new(Arc::new(schema)); let t_column_orders = Some(vec![TColumnOrder::TYPEORDER(TypeDefinedOrder::new())]); - ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr); + let res = ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr); + assert!(res.is_err()); + assert!(format!("{:?}", res.unwrap_err()).contains("Column order length mismatch")); } #[test] @@ -794,7 +804,7 @@ mod tests { // should fail match reader.try_parse_sized(&bytes, len).unwrap_err() { // expected error, try again with provided bounds - ParquetError::IndexOutOfBound(needed, _) => { + ParquetError::NeedMoreData(needed) => { let bytes = bytes_for_range(len - needed..len); reader.try_parse_sized(&bytes, len).unwrap(); let metadata = reader.finish().unwrap(); @@ -804,6 +814,26 @@ mod tests { _ => panic!("unexpected error"), }; + // not enough for file metadata, but keep trying until page indexes are read + let mut reader = 
ParquetMetaDataReader::new().with_page_indexes(true); + let mut bytes = bytes_for_range(452505..len); + loop { + match reader.try_parse_sized(&bytes, len) { + Ok(_) => break, + Err(ParquetError::NeedMoreData(needed)) => { + bytes = bytes_for_range(len - needed..len); + if reader.has_metadata() { + reader.read_page_indexes_sized(&bytes, len).unwrap(); + break; + } + } + _ => panic!("unexpected error"), + } + } + let metadata = reader.finish().unwrap(); + assert!(metadata.column_index.is_some()); + assert!(metadata.offset_index.is_some()); + // not enough for page index but lie about file size let bytes = bytes_for_range(323584..len); let reader_result = reader.try_parse_sized(&bytes, len - 323584).unwrap_err(); @@ -818,7 +848,7 @@ mod tests { // should fail match reader.try_parse_sized(&bytes, len).unwrap_err() { // expected error, try again with provided bounds - ParquetError::IndexOutOfBound(needed, _) => { + ParquetError::NeedMoreData(needed) => { let bytes = bytes_for_range(len - needed..len); reader.try_parse_sized(&bytes, len).unwrap(); reader.finish().unwrap(); diff --git a/parquet/src/file/page_index/index_reader.rs b/parquet/src/file/page_index/index_reader.rs index 395e9afe122c..fd3639ac3069 100644 --- a/parquet/src/file/page_index/index_reader.rs +++ b/parquet/src/file/page_index/index_reader.rs @@ -43,8 +43,7 @@ pub(crate) fn acc_range(a: Option>, b: Option>) -> Opt /// /// Returns a vector of `index[column_number]`. /// -/// Returns an empty vector if this row group does not contain a -/// [`ColumnIndex`]. +/// Returns `None` if this row group does not contain a [`ColumnIndex`]. /// /// See [Page Index Documentation] for more details. 
/// @@ -52,26 +51,29 @@ pub(crate) fn acc_range(a: Option>, b: Option>) -> Opt pub fn read_columns_indexes( reader: &R, chunks: &[ColumnChunkMetaData], -) -> Result, ParquetError> { +) -> Result>, ParquetError> { let fetch = chunks .iter() .fold(None, |range, c| acc_range(range, c.column_index_range())); let fetch = match fetch { Some(r) => r, - None => return Ok(vec![Index::NONE; chunks.len()]), + None => return Ok(None), }; let bytes = reader.get_bytes(fetch.start as _, fetch.end - fetch.start)?; let get = |r: Range| &bytes[(r.start - fetch.start)..(r.end - fetch.start)]; - chunks - .iter() - .map(|c| match c.column_index_range() { - Some(r) => decode_column_index(get(r), c.column_type()), - None => Ok(Index::NONE), - }) - .collect() + Some( + chunks + .iter() + .map(|c| match c.column_index_range() { + Some(r) => decode_column_index(get(r), c.column_type()), + None => Ok(Index::NONE), + }) + .collect(), + ) + .transpose() } /// Reads [`OffsetIndex`], per-page [`PageLocation`] for all columns of a row @@ -116,8 +118,7 @@ pub fn read_pages_locations( /// /// Returns a vector of `offset_index[column_number]`. /// -/// Returns an empty vector if this row group does not contain an -/// [`OffsetIndex`]. +/// Returns `None` if this row group does not contain an [`OffsetIndex`]. /// /// See [Page Index Documentation] for more details. 
/// @@ -125,26 +126,29 @@ pub fn read_pages_locations( pub fn read_offset_indexes( reader: &R, chunks: &[ColumnChunkMetaData], -) -> Result, ParquetError> { +) -> Result>, ParquetError> { let fetch = chunks .iter() .fold(None, |range, c| acc_range(range, c.offset_index_range())); let fetch = match fetch { Some(r) => r, - None => return Ok(vec![]), + None => return Ok(None), }; let bytes = reader.get_bytes(fetch.start as _, fetch.end - fetch.start)?; let get = |r: Range| &bytes[(r.start - fetch.start)..(r.end - fetch.start)]; - chunks - .iter() - .map(|c| match c.offset_index_range() { - Some(r) => decode_offset_index(get(r)), - None => Err(general_err!("missing offset index")), - }) - .collect() + Some( + chunks + .iter() + .map(|c| match c.offset_index_range() { + Some(r) => decode_offset_index(get(r)), + None => Err(general_err!("missing offset index")), + }) + .collect(), + ) + .transpose() } pub(crate) fn decode_offset_index(data: &[u8]) -> Result { diff --git a/parquet/src/file/properties.rs b/parquet/src/file/properties.rs index efcb63258f99..dc918f6b5634 100644 --- a/parquet/src/file/properties.rs +++ b/parquet/src/file/properties.rs @@ -16,14 +16,13 @@ // under the License. //! 
Configuration via [`WriterProperties`] and [`ReaderProperties`] -use std::str::FromStr; -use std::{collections::HashMap, sync::Arc}; - use crate::basic::{Compression, Encoding}; use crate::compression::{CodecOptions, CodecOptionsBuilder}; use crate::file::metadata::KeyValue; use crate::format::SortingColumn; use crate::schema::types::ColumnPath; +use std::str::FromStr; +use std::{collections::HashMap, sync::Arc}; /// Default value for [`WriterProperties::data_page_size_limit`] pub const DEFAULT_PAGE_SIZE: usize = 1024 * 1024; @@ -42,6 +41,7 @@ pub const DEFAULT_DATA_PAGE_ROW_COUNT_LIMIT: usize = 20_000; /// Default value for [`WriterProperties::statistics_enabled`] pub const DEFAULT_STATISTICS_ENABLED: EnabledStatistics = EnabledStatistics::Page; /// Default value for [`WriterProperties::max_statistics_size`] +#[deprecated(since = "54.0.0", note = "Unused; will be removed in 56.0.0")] pub const DEFAULT_MAX_STATISTICS_SIZE: usize = 4096; /// Default value for [`WriterProperties::max_row_group_size`] pub const DEFAULT_MAX_ROW_GROUP_SIZE: usize = 1024 * 1024; @@ -57,6 +57,10 @@ pub const DEFAULT_BLOOM_FILTER_FPP: f64 = 0.05; pub const DEFAULT_BLOOM_FILTER_NDV: u64 = 1_000_000_u64; /// Default values for [`WriterProperties::statistics_truncate_length`] pub const DEFAULT_STATISTICS_TRUNCATE_LENGTH: Option = None; +/// Default value for [`WriterProperties::offset_index_disabled`] +pub const DEFAULT_OFFSET_INDEX_DISABLED: bool = false; +/// Default values for [`WriterProperties::coerce_types`] +pub const DEFAULT_COERCE_TYPES: bool = false; /// Parquet writer version. 
/// @@ -157,12 +161,14 @@ pub struct WriterProperties { bloom_filter_position: BloomFilterPosition, writer_version: WriterVersion, created_by: String, + offset_index_disabled: bool, pub(crate) key_value_metadata: Option>, default_column_properties: ColumnProperties, column_properties: HashMap, sorting_columns: Option>, column_index_truncate_length: Option, statistics_truncate_length: Option, + coerce_types: bool, } impl Default for WriterProperties { @@ -185,14 +191,6 @@ impl WriterProperties { WriterPropertiesBuilder::with_defaults() } - /// Returns data page size limit. - /// - /// Note: this is a best effort limit based on the write batch size - #[deprecated(since = "41.0.0", note = "Use data_page_size_limit")] - pub fn data_pagesize_limit(&self) -> usize { - self.data_page_size_limit - } - /// Returns data page size limit. /// /// Note: this is a best effort limit based on the write batch size @@ -202,14 +200,6 @@ impl WriterProperties { self.data_page_size_limit } - /// Returns dictionary page size limit. - /// - /// Note: this is a best effort limit based on the write batch size - #[deprecated(since = "41.0.0", note = "Use dictionary_page_size_limit")] - pub fn dictionary_pagesize_limit(&self) -> usize { - self.dictionary_page_size_limit - } - /// Returns dictionary page size limit. /// /// Note: this is a best effort limit based on the write batch size @@ -257,6 +247,22 @@ impl WriterProperties { &self.created_by } + /// Returns `true` if offset index writing is disabled. + pub fn offset_index_disabled(&self) -> bool { + // If page statistics are to be collected, then do not disable the offset indexes. 
+ let default_page_stats_enabled = + self.default_column_properties.statistics_enabled() == Some(EnabledStatistics::Page); + let column_page_stats_enabled = self + .column_properties + .iter() + .any(|path_props| path_props.1.statistics_enabled() == Some(EnabledStatistics::Page)); + if default_page_stats_enabled || column_page_stats_enabled { + return false; + } + + self.offset_index_disabled + } + /// Returns `key_value_metadata` KeyValue pairs. pub fn key_value_metadata(&self) -> Option<&Vec> { self.key_value_metadata.as_ref() @@ -281,6 +287,11 @@ impl WriterProperties { self.statistics_truncate_length } + /// Returns `true` if type coercion is enabled. + pub fn coerce_types(&self) -> bool { + self.coerce_types + } + /// Returns encoding for a data page, when dictionary encoding is enabled. /// This is not configurable. #[inline] @@ -340,7 +351,9 @@ impl WriterProperties { /// Returns max size for statistics. /// Only applicable if statistics are enabled. + #[deprecated(since = "54.0.0", note = "Unused; will be removed in 56.0.0")] pub fn max_statistics_size(&self, col: &ColumnPath) -> usize { + #[allow(deprecated)] self.column_properties .get(col) .and_then(|c| c.max_statistics_size()) @@ -371,12 +384,14 @@ pub struct WriterPropertiesBuilder { bloom_filter_position: BloomFilterPosition, writer_version: WriterVersion, created_by: String, + offset_index_disabled: bool, key_value_metadata: Option>, default_column_properties: ColumnProperties, column_properties: HashMap, sorting_columns: Option>, column_index_truncate_length: Option, statistics_truncate_length: Option, + coerce_types: bool, } impl WriterPropertiesBuilder { @@ -391,12 +406,14 @@ impl WriterPropertiesBuilder { bloom_filter_position: DEFAULT_BLOOM_FILTER_POSITION, writer_version: DEFAULT_WRITER_VERSION, created_by: DEFAULT_CREATED_BY.to_string(), + offset_index_disabled: DEFAULT_OFFSET_INDEX_DISABLED, key_value_metadata: None, default_column_properties: Default::default(), column_properties: 
HashMap::new(), sorting_columns: None, column_index_truncate_length: DEFAULT_COLUMN_INDEX_TRUNCATE_LENGTH, statistics_truncate_length: DEFAULT_STATISTICS_TRUNCATE_LENGTH, + coerce_types: DEFAULT_COERCE_TYPES, } } @@ -411,12 +428,14 @@ impl WriterPropertiesBuilder { bloom_filter_position: self.bloom_filter_position, writer_version: self.writer_version, created_by: self.created_by, + offset_index_disabled: self.offset_index_disabled, key_value_metadata: self.key_value_metadata, default_column_properties: self.default_column_properties, column_properties: self.column_properties, sorting_columns: self.sorting_columns, column_index_truncate_length: self.column_index_truncate_length, statistics_truncate_length: self.statistics_truncate_length, + coerce_types: self.coerce_types, } } @@ -433,16 +452,6 @@ impl WriterPropertiesBuilder { self } - /// Sets best effort maximum size of a data page in bytes. - /// - /// Note: this is a best effort limit based on value of - /// [`set_write_batch_size`](Self::set_write_batch_size). - #[deprecated(since = "41.0.0", note = "Use set_data_page_size_limit")] - pub fn set_data_pagesize_limit(mut self, value: usize) -> Self { - self.data_page_size_limit = value; - self - } - /// Sets best effort maximum size of a data page in bytes (defaults to `1024 * 1024`). /// /// The parquet writer will attempt to limit the sizes of each @@ -471,16 +480,6 @@ impl WriterPropertiesBuilder { self } - /// Sets best effort maximum dictionary page size, in bytes. - /// - /// Note: this is a best effort limit based on value of - /// [`set_write_batch_size`](Self::set_write_batch_size). - #[deprecated(since = "41.0.0", note = "Use set_dictionary_page_size_limit")] - pub fn set_dictionary_pagesize_limit(mut self, value: usize) -> Self { - self.dictionary_page_size_limit = value; - self - } - /// Sets best effort maximum dictionary page size, in bytes (defaults to `1024 * 1024`). 
/// /// The parquet writer will attempt to limit the size of each @@ -532,6 +531,21 @@ impl WriterPropertiesBuilder { self } + /// Sets whether the writing of offset indexes is disabled (defaults to `false`). + /// + /// If statistics level is set to [`Page`] this setting will be overridden with `false`. + /// + /// Note: As the offset indexes are useful for accessing data by row number, + /// they are always written by default, regardless of whether other statistics + /// are enabled. Disabling this metadata may result in a degradation in read + /// performance, so use this option with care. + /// + /// [`Page`]: EnabledStatistics::Page + pub fn set_offset_index_disabled(mut self, value: bool) -> Self { + self.offset_index_disabled = value; + self + } + /// Sets "key_value_metadata" property (defaults to `None`). pub fn set_key_value_metadata(mut self, value: Option>) -> Self { self.key_value_metadata = value; @@ -590,7 +604,9 @@ impl WriterPropertiesBuilder { /// Sets default max statistics size for all columns (defaults to `4096`). /// /// Applicable only if statistics are enabled. + #[deprecated(since = "54.0.0", note = "Unused; will be removed in 56.0.0")] pub fn set_max_statistics_size(mut self, value: usize) -> Self { + #[allow(deprecated)] self.default_column_properties .set_max_statistics_size(value); self @@ -695,7 +711,9 @@ impl WriterPropertiesBuilder { /// Sets max size for statistics for a specific column. /// /// Takes precedence over [`Self::set_max_statistics_size`]. + #[deprecated(since = "54.0.0", note = "Unused; will be removed in 56.0.0")] pub fn set_column_max_statistics_size(mut self, col: ColumnPath, value: usize) -> Self { + #[allow(deprecated)] self.get_mut_props(col).set_max_statistics_size(value); self } @@ -767,6 +785,29 @@ impl WriterPropertiesBuilder { self.statistics_truncate_length = max_length; self } + + /// Should the writer coerce types to parquet native types (defaults to `false`). 
+ /// + /// Leaving this option the default `false` will ensure the exact same data + /// written to parquet using this library will be read. + /// + /// Setting this option to `true` will result in parquet files that can be + /// read by more readers, but potentially lose information in the process. + /// + /// * Types such as [`DataType::Date64`], which have no direct corresponding + /// Parquet type, may be stored with lower precision. + /// + /// * The internal field names of `List` and `Map` types will be renamed if + /// necessary to match what is required by the newest Parquet specification. + /// + /// See [`ArrowToParquetSchemaConverter::with_coerce_types`] for more details + /// + /// [`DataType::Date64`]: arrow_schema::DataType::Date64 + /// [`ArrowToParquetSchemaConverter::with_coerce_types`]: crate::arrow::ArrowSchemaConverter::with_coerce_types + pub fn set_coerce_types(mut self, coerce_types: bool) -> Self { + self.coerce_types = coerce_types; + self + } } /// Controls the level of statistics to be computed by the writer and stored in @@ -862,6 +903,7 @@ struct ColumnProperties { codec: Option, dictionary_enabled: Option, statistics_enabled: Option, + #[deprecated(since = "54.0.0", note = "Unused; will be removed in 56.0.0")] max_statistics_size: Option, /// bloom filter related properties bloom_filter_properties: Option, @@ -894,12 +936,14 @@ impl ColumnProperties { self.dictionary_enabled = Some(enabled); } - /// Sets whether or not statistics are enabled for this column. + /// Sets the statistics level for this column. fn set_statistics_enabled(&mut self, enabled: EnabledStatistics) { self.statistics_enabled = Some(enabled); } /// Sets max size for statistics for this column. 
+ #[deprecated(since = "54.0.0", note = "Unused; will be removed in 56.0.0")] + #[allow(deprecated)] fn set_max_statistics_size(&mut self, value: usize) { self.max_statistics_size = Some(value); } @@ -957,14 +1001,16 @@ impl ColumnProperties { self.dictionary_enabled } - /// Returns `Some(true)` if statistics are enabled for this column, if disabled then - /// returns `Some(false)`. If result is `None`, then no setting has been provided. + /// Returns optional statistics level requested for this column. If result is `None`, + /// then no setting has been provided. fn statistics_enabled(&self) -> Option { self.statistics_enabled } /// Returns optional max size in bytes for statistics. + #[deprecated(since = "54.0.0", note = "Unused; will be removed in 56.0.0")] fn max_statistics_size(&self) -> Option { + #[allow(deprecated)] self.max_statistics_size } @@ -1108,10 +1154,6 @@ mod tests { props.statistics_enabled(&ColumnPath::from("col")), DEFAULT_STATISTICS_ENABLED ); - assert_eq!( - props.max_statistics_size(&ColumnPath::from("col")), - DEFAULT_MAX_STATISTICS_SIZE - ); assert!(props .bloom_filter_properties(&ColumnPath::from("col")) .is_none()); @@ -1188,13 +1230,11 @@ mod tests { .set_compression(Compression::GZIP(Default::default())) .set_dictionary_enabled(false) .set_statistics_enabled(EnabledStatistics::None) - .set_max_statistics_size(50) // specific column settings .set_column_encoding(ColumnPath::from("col"), Encoding::RLE) .set_column_compression(ColumnPath::from("col"), Compression::SNAPPY) .set_column_dictionary_enabled(ColumnPath::from("col"), true) .set_column_statistics_enabled(ColumnPath::from("col"), EnabledStatistics::Chunk) - .set_column_max_statistics_size(ColumnPath::from("col"), 123) .set_column_bloom_filter_enabled(ColumnPath::from("col"), true) .set_column_bloom_filter_ndv(ColumnPath::from("col"), 100_u64) .set_column_bloom_filter_fpp(ColumnPath::from("col"), 0.1) @@ -1226,7 +1266,6 @@ mod tests { 
props.statistics_enabled(&ColumnPath::from("a")), EnabledStatistics::None ); - assert_eq!(props.max_statistics_size(&ColumnPath::from("a")), 50); assert_eq!( props.encoding(&ColumnPath::from("col")), @@ -1241,7 +1280,6 @@ mod tests { props.statistics_enabled(&ColumnPath::from("col")), EnabledStatistics::Chunk ); - assert_eq!(props.max_statistics_size(&ColumnPath::from("col")), 123); assert_eq!( props.bloom_filter_properties(&ColumnPath::from("col")), Some(&BloomFilterProperties { fpp: 0.1, ndv: 100 }) diff --git a/parquet/src/file/serialized_reader.rs b/parquet/src/file/serialized_reader.rs index 3262d1fba704..a942481f7e4d 100644 --- a/parquet/src/file/serialized_reader.rs +++ b/parquet/src/file/serialized_reader.rs @@ -435,7 +435,7 @@ pub(crate) fn decode_page( let is_sorted = dict_header.is_sorted.unwrap_or(false); Page::DictionaryPage { buf: buffer, - num_values: dict_header.num_values as u32, + num_values: dict_header.num_values.try_into()?, encoding: Encoding::try_from(dict_header.encoding)?, is_sorted, } @@ -446,7 +446,7 @@ pub(crate) fn decode_page( .ok_or_else(|| ParquetError::General("Missing V1 data page header".to_string()))?; Page::DataPage { buf: buffer, - num_values: header.num_values as u32, + num_values: header.num_values.try_into()?, encoding: Encoding::try_from(header.encoding)?, def_level_encoding: Encoding::try_from(header.definition_level_encoding)?, rep_level_encoding: Encoding::try_from(header.repetition_level_encoding)?, @@ -460,12 +460,12 @@ pub(crate) fn decode_page( let is_compressed = header.is_compressed.unwrap_or(true); Page::DataPageV2 { buf: buffer, - num_values: header.num_values as u32, + num_values: header.num_values.try_into()?, encoding: Encoding::try_from(header.encoding)?, - num_nulls: header.num_nulls as u32, - num_rows: header.num_rows as u32, - def_levels_byte_len: header.definition_levels_byte_length as u32, - rep_levels_byte_len: header.repetition_levels_byte_length as u32, + num_nulls: header.num_nulls.try_into()?, + 
num_rows: header.num_rows.try_into()?, + def_levels_byte_len: header.definition_levels_byte_length.try_into()?, + rep_levels_byte_len: header.repetition_levels_byte_length.try_into()?, is_compressed, statistics: statistics::from_thrift(physical_type, header.statistics)?, } @@ -578,6 +578,27 @@ impl Iterator for SerializedPageReader { } } +fn verify_page_header_len(header_len: usize, remaining_bytes: usize) -> Result<()> { + if header_len > remaining_bytes { + return Err(eof_err!("Invalid page header")); + } + Ok(()) +} + +fn verify_page_size( + compressed_size: i32, + uncompressed_size: i32, + remaining_bytes: usize, +) -> Result<()> { + // The page's compressed size should not exceed the remaining bytes that are + // available to read. The page's uncompressed size is the expected size + // after decompression, which can never be negative. + if compressed_size < 0 || compressed_size as usize > remaining_bytes || uncompressed_size < 0 { + return Err(eof_err!("Invalid page header")); + } + Ok(()) +} + impl PageReader for SerializedPageReader { fn get_next_page(&mut self) -> Result> { loop { @@ -596,10 +617,16 @@ impl PageReader for SerializedPageReader { *header } else { let (header_len, header) = read_page_header_len(&mut read)?; + verify_page_header_len(header_len, *remaining)?; *offset += header_len; *remaining -= header_len; header }; + verify_page_size( + header.compressed_page_size, + header.uncompressed_page_size, + *remaining, + )?; let data_len = header.compressed_page_size as usize; *offset += data_len; *remaining -= data_len; @@ -683,6 +710,7 @@ impl PageReader for SerializedPageReader { } else { let mut read = self.reader.get_read(*offset as u64)?; let (header_len, header) = read_page_header_len(&mut read)?; + verify_page_header_len(header_len, *remaining_bytes)?; *offset += header_len; *remaining_bytes -= header_len; let page_meta = if let Ok(page_meta) = (&header).try_into() { @@ -733,12 +761,23 @@ impl PageReader for SerializedPageReader { 
next_page_header, } => { if let Some(buffered_header) = next_page_header.take() { + verify_page_size( + buffered_header.compressed_page_size, + buffered_header.uncompressed_page_size, + *remaining_bytes, + )?; // The next page header has already been peeked, so just advance the offset *offset += buffered_header.compressed_page_size as usize; *remaining_bytes -= buffered_header.compressed_page_size as usize; } else { let mut read = self.reader.get_read(*offset as u64)?; let (header_len, header) = read_page_header_len(&mut read)?; + verify_page_header_len(header_len, *remaining_bytes)?; + verify_page_size( + header.compressed_page_size, + header.uncompressed_page_size, + *remaining_bytes, + )?; let data_page_size = header.compressed_page_size as usize; *offset += header_len + data_page_size; *remaining_bytes -= header_len + data_page_size; @@ -1223,8 +1262,8 @@ mod tests { let reader = SerializedFileReader::new_with_options(test_file, read_options)?; let metadata = reader.metadata(); assert_eq!(metadata.num_row_groups(), 0); - assert_eq!(metadata.column_index().unwrap().len(), 0); - assert_eq!(metadata.offset_index().unwrap().len(), 0); + assert!(metadata.column_index().is_none()); + assert!(metadata.offset_index().is_none()); // false, true predicate let test_file = get_test_file("alltypes_tiny_pages.parquet"); @@ -1236,8 +1275,8 @@ mod tests { let reader = SerializedFileReader::new_with_options(test_file, read_options)?; let metadata = reader.metadata(); assert_eq!(metadata.num_row_groups(), 0); - assert_eq!(metadata.column_index().unwrap().len(), 0); - assert_eq!(metadata.offset_index().unwrap().len(), 0); + assert!(metadata.column_index().is_none()); + assert!(metadata.offset_index().is_none()); // false, false predicate let test_file = get_test_file("alltypes_tiny_pages.parquet"); @@ -1249,8 +1288,8 @@ mod tests { let reader = SerializedFileReader::new_with_options(test_file, read_options)?; let metadata = reader.metadata(); assert_eq!(metadata.num_row_groups(), 
0); - assert_eq!(metadata.column_index().unwrap().len(), 0); - assert_eq!(metadata.offset_index().unwrap().len(), 0); + assert!(metadata.column_index().is_none()); + assert!(metadata.offset_index().is_none()); Ok(()) } @@ -1340,13 +1379,15 @@ mod tests { let columns = metadata.row_group(0).columns(); let reversed: Vec<_> = columns.iter().cloned().rev().collect(); - let a = read_columns_indexes(&test_file, columns).unwrap(); - let mut b = read_columns_indexes(&test_file, &reversed).unwrap(); + let a = read_columns_indexes(&test_file, columns).unwrap().unwrap(); + let mut b = read_columns_indexes(&test_file, &reversed) + .unwrap() + .unwrap(); b.reverse(); assert_eq!(a, b); - let a = read_offset_indexes(&test_file, columns).unwrap(); - let mut b = read_offset_indexes(&test_file, &reversed).unwrap(); + let a = read_offset_indexes(&test_file, columns).unwrap().unwrap(); + let mut b = read_offset_indexes(&test_file, &reversed).unwrap().unwrap(); b.reverse(); assert_eq!(a, b); } diff --git a/parquet/src/file/statistics.rs b/parquet/src/file/statistics.rs index 2e05b83369cf..b7522a76f0fc 100644 --- a/parquet/src/file/statistics.rs +++ b/parquet/src/file/statistics.rs @@ -157,6 +157,32 @@ pub fn from_thrift( stats.max_value }; + fn check_len(min: &Option>, max: &Option>, len: usize) -> Result<()> { + if let Some(min) = min { + if min.len() < len { + return Err(ParquetError::General( + "Insufficient bytes to parse min statistic".to_string(), + )); + } + } + if let Some(max) = max { + if max.len() < len { + return Err(ParquetError::General( + "Insufficient bytes to parse max statistic".to_string(), + )); + } + } + Ok(()) + } + + match physical_type { + Type::BOOLEAN => check_len(&min, &max, 1), + Type::INT32 | Type::FLOAT => check_len(&min, &max, 4), + Type::INT64 | Type::DOUBLE => check_len(&min, &max, 8), + Type::INT96 => check_len(&min, &max, 12), + _ => Ok(()), + }?; + // Values are encoded using PLAIN encoding definition, except that // variable-length byte arrays do 
not include a length prefix. // diff --git a/parquet/src/file/writer.rs b/parquet/src/file/writer.rs index b84c57a60e19..6b7707f03cd9 100644 --- a/parquet/src/file/writer.rs +++ b/parquet/src/file/writer.rs @@ -1742,6 +1742,7 @@ mod tests { let props = WriterProperties::builder() .set_statistics_enabled(EnabledStatistics::None) .set_column_statistics_enabled("a".into(), EnabledStatistics::Page) + .set_offset_index_disabled(true) // this should be ignored because of the line above .build(); let mut file = Vec::with_capacity(1024); let mut file_writer = diff --git a/parquet/src/record/api.rs b/parquet/src/record/api.rs index c95ce3f9223b..1b0d81c7d9ab 100644 --- a/parquet/src/record/api.rs +++ b/parquet/src/record/api.rs @@ -52,6 +52,11 @@ pub struct Row { #[allow(clippy::len_without_is_empty)] impl Row { + /// Constructs a `Row` from the list of `fields` and returns it. + pub fn new(fields: Vec<(String, Field)>) -> Row { + Row { fields } + } + /// Get the number of fields in this row. pub fn len(&self) -> usize { self.fields.len() @@ -283,12 +288,6 @@ impl RowAccessor for Row { row_complex_accessor!(get_map, MapInternal, Map); } -/// Constructs a `Row` from the list of `fields` and returns it. 
-#[inline] -pub fn make_row(fields: Vec<(String, Field)>) -> Row { - Row { fields } -} - impl fmt::Display for Row { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{{")?; @@ -1386,7 +1385,7 @@ mod tests { ("z".to_string(), Field::Float(3.1)), ("a".to_string(), Field::Str("abc".to_string())), ]; - let row = Field::Group(make_row(fields)); + let row = Field::Group(Row::new(fields)); assert_eq!(format!("{row}"), "{x: null, Y: 2, z: 3.1, a: \"abc\"}"); let row = Field::ListInternal(make_list(vec![ @@ -1431,7 +1430,7 @@ mod tests { assert!(Field::Decimal(Decimal::from_i32(4, 8, 2)).is_primitive()); // complex types - assert!(!Field::Group(make_row(vec![ + assert!(!Field::Group(Row::new(vec![ ("x".to_string(), Field::Null), ("Y".to_string(), Field::Int(2)), ("z".to_string(), Field::Float(3.1)), @@ -1458,7 +1457,7 @@ mod tests { #[test] fn test_row_primitive_field_fmt() { // Primitives types - let row = make_row(vec![ + let row = Row::new(vec![ ("00".to_string(), Field::Null), ("01".to_string(), Field::Bool(false)), ("02".to_string(), Field::Byte(3)), @@ -1513,10 +1512,10 @@ mod tests { #[test] fn test_row_complex_field_fmt() { // Complex types - let row = make_row(vec![ + let row = Row::new(vec![ ( "00".to_string(), - Field::Group(make_row(vec![ + Field::Group(Row::new(vec![ ("x".to_string(), Field::Null), ("Y".to_string(), Field::Int(2)), ])), @@ -1548,7 +1547,7 @@ mod tests { #[test] fn test_row_primitive_accessors() { // primitives - let row = make_row(vec![ + let row = Row::new(vec![ ("a".to_string(), Field::Null), ("b".to_string(), Field::Bool(false)), ("c".to_string(), Field::Byte(3)), @@ -1590,7 +1589,7 @@ mod tests { #[test] fn test_row_primitive_invalid_accessors() { // primitives - let row = make_row(vec![ + let row = Row::new(vec![ ("a".to_string(), Field::Null), ("b".to_string(), Field::Bool(false)), ("c".to_string(), Field::Byte(3)), @@ -1619,10 +1618,10 @@ mod tests { #[test] fn test_row_complex_accessors() { - let row = make_row(vec![ + 
let row = Row::new(vec![ ( "a".to_string(), - Field::Group(make_row(vec![ + Field::Group(Row::new(vec![ ("x".to_string(), Field::Null), ("Y".to_string(), Field::Int(2)), ])), @@ -1653,10 +1652,10 @@ mod tests { #[test] fn test_row_complex_invalid_accessors() { - let row = make_row(vec![ + let row = Row::new(vec![ ( "a".to_string(), - Field::Group(make_row(vec![ + Field::Group(Row::new(vec![ ("x".to_string(), Field::Null), ("Y".to_string(), Field::Int(2)), ])), @@ -1802,7 +1801,7 @@ mod tests { #[test] fn test_list_complex_accessors() { - let list = make_list(vec![Field::Group(make_row(vec![ + let list = make_list(vec![Field::Group(Row::new(vec![ ("x".to_string(), Field::Null), ("Y".to_string(), Field::Int(2)), ]))]); @@ -1826,7 +1825,7 @@ mod tests { #[test] fn test_list_complex_invalid_accessors() { - let list = make_list(vec![Field::Group(make_row(vec![ + let list = make_list(vec![Field::Group(Row::new(vec![ ("x".to_string(), Field::Null), ("Y".to_string(), Field::Int(2)), ]))]); @@ -1961,7 +1960,7 @@ mod tests { ("Y".to_string(), Field::Double(2.2)), ("Z".to_string(), Field::Str("abc".to_string())), ]; - let row = Field::Group(make_row(fields)); + let row = Field::Group(Row::new(fields)); assert_eq!( row.to_json_value(), serde_json::json!({"X": 1, "Y": 2.2, "Z": "abc"}) @@ -1990,14 +1989,14 @@ mod tests { #[cfg(test)] #[allow(clippy::many_single_char_names)] mod api_tests { - use super::{make_list, make_map, make_row}; + use super::{make_list, make_map, Row}; use crate::record::Field; #[test] fn test_field_visibility() { - let row = make_row(vec![( + let row = Row::new(vec![( "a".to_string(), - Field::Group(make_row(vec![ + Field::Group(Row::new(vec![ ("x".to_string(), Field::Null), ("Y".to_string(), Field::Int(2)), ])), @@ -2009,7 +2008,7 @@ mod api_tests { match column.1 { Field::Group(r) => { assert_eq!( - &make_row(vec![ + &Row::new(vec![ ("x".to_string(), Field::Null), ("Y".to_string(), Field::Int(2)), ]), @@ -2027,7 +2026,7 @@ mod api_tests { fn 
test_list_element_access() { let expected = vec![ Field::Int(1), - Field::Group(make_row(vec![ + Field::Group(Row::new(vec![ ("x".to_string(), Field::Null), ("Y".to_string(), Field::Int(2)), ])), diff --git a/parquet/src/record/reader.rs b/parquet/src/record/reader.rs index fd6ca7cdd57a..9e70f7a980db 100644 --- a/parquet/src/record/reader.rs +++ b/parquet/src/record/reader.rs @@ -24,7 +24,7 @@ use crate::basic::{ConvertedType, Repetition}; use crate::errors::{ParquetError, Result}; use crate::file::reader::{FileReader, RowGroupReader}; use crate::record::{ - api::{make_list, make_map, make_row, Field, Row}, + api::{make_list, make_map, Field, Row}, triplet::TripletIter, }; use crate::schema::types::{ColumnPath, SchemaDescPtr, SchemaDescriptor, Type, TypePtr}; @@ -217,11 +217,15 @@ impl TreeBuilder { Repetition::REPEATED, "Invalid map type: {field:?}" ); - assert_eq!( - key_value_type.get_fields().len(), - 2, - "Invalid map type: {field:?}" - ); + // Parquet spec allows no value. In that case treat as a list. 
#1642 + if key_value_type.get_fields().len() != 1 { + // If not a list, then there can only be 2 fields in the struct + assert_eq!( + key_value_type.get_fields().len(), + 2, + "Invalid map type: {field:?}" + ); + } path.push(String::from(key_value_type.name())); @@ -239,25 +243,35 @@ impl TreeBuilder { row_group_reader, )?; - let value_type = &key_value_type.get_fields()[1]; - let value_reader = self.reader_tree( - value_type.clone(), - path, - curr_def_level + 1, - curr_rep_level + 1, - paths, - row_group_reader, - )?; + if key_value_type.get_fields().len() == 1 { + path.pop(); + Reader::RepeatedReader( + field, + curr_def_level, + curr_rep_level, + Box::new(key_reader), + ) + } else { + let value_type = &key_value_type.get_fields()[1]; + let value_reader = self.reader_tree( + value_type.clone(), + path, + curr_def_level + 1, + curr_rep_level + 1, + paths, + row_group_reader, + )?; - path.pop(); + path.pop(); - Reader::KeyValueReader( - field, - curr_def_level, - curr_rep_level, - Box::new(key_reader), - Box::new(value_reader), - ) + Reader::KeyValueReader( + field, + curr_def_level, + curr_rep_level, + Box::new(key_reader), + Box::new(value_reader), + ) + } } // A repeated field that is neither contained by a `LIST`- or // `MAP`-annotated group nor annotated by `LIST` or `MAP` @@ -345,6 +359,19 @@ impl Reader { /// /// #backward-compatibility-rules fn is_element_type(repeated_type: &Type) -> bool { + // For legacy 2-level list types whose element type is a 2-level list + // + // // ARRAY> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group array (LIST) { + // repeated int32 array; + // }; + // } + // + if repeated_type.is_list() || repeated_type.has_single_repeated_child() { + return false; + } + // For legacy 2-level list types with primitive element type, e.g.: // // // ARRAY (nullable list, non-null elements) @@ -399,7 +426,7 @@ impl Reader { for reader in readers { fields.push((String::from(reader.field_name()), 
reader.read_field()?)); } - Ok(make_row(fields)) + Ok(Row::new(fields)) } _ => panic!("Cannot call read() on {self}"), } @@ -434,7 +461,7 @@ impl Reader { fields.push((String::from(reader.field_name()), Field::Null)); } } - let row = make_row(fields); + let row = Row::new(fields); Field::Group(row) } Reader::RepeatedReader(_, def_level, rep_level, ref mut reader) => { @@ -826,7 +853,7 @@ mod tests { macro_rules! row { ($($e:tt)*) => { { - make_row(vec![$($e)*]) + Row::new(vec![$($e)*]) } } } @@ -1459,8 +1486,7 @@ mod tests { } #[test] - #[should_panic(expected = "Invalid map type")] - fn test_file_reader_rows_invalid_map_type() { + fn test_file_reader_rows_nested_map_type() { let schema = " message spark_schema { OPTIONAL group a (MAP) { @@ -1823,6 +1849,36 @@ mod tests { assert_eq!(rows, expected_rows); } + #[test] + fn test_map_no_value() { + // File schema: + // message schema { + // required group my_map (MAP) { + // repeated group key_value { + // required int32 key; + // optional int32 value; + // } + // } + // required group my_map_no_v (MAP) { + // repeated group key_value { + // required int32 key; + // } + // } + // required group my_list (LIST) { + // repeated group list { + // required int32 element; + // } + // } + // } + let rows = test_file_reader_rows("map_no_value.parquet", None).unwrap(); + + // the my_map_no_v and my_list columns should be equivalent lists by this point + for row in rows { + let cols = row.into_columns(); + assert_eq!(cols[1].1, cols[2].1); + } + } + fn test_file_reader_rows(file_name: &str, schema: Option) -> Result> { let file = get_test_file(file_name); let file_reader: Box = Box::new(SerializedFileReader::new(file)?); @@ -1839,4 +1895,21 @@ mod tests { let iter = row_group_reader.get_row_iter(schema)?; Ok(iter.map(|row| row.unwrap()).collect()) } + + #[test] + fn test_read_old_nested_list() { + let rows = test_file_reader_rows("old_list_structure.parquet", None).unwrap(); + let expected_rows = vec![row![( + "a".to_string(), + 
Field::ListInternal(make_list( + [ + make_list([1, 2].map(Field::Int).to_vec()), + make_list([3, 4].map(Field::Int).to_vec()) + ] + .map(Field::ListInternal) + .to_vec() + )) + ),]]; + assert_eq!(rows, expected_rows); + } } diff --git a/parquet/src/schema/types.rs b/parquet/src/schema/types.rs index b7ba95eb56bb..d9e9b22e809f 100644 --- a/parquet/src/schema/types.rs +++ b/parquet/src/schema/types.rs @@ -202,6 +202,29 @@ impl Type { self.get_basic_info().has_repetition() && self.get_basic_info().repetition() != Repetition::REQUIRED } + + /// Returns `true` if this type is annotated as a list. + pub(crate) fn is_list(&self) -> bool { + if self.is_group() { + let basic_info = self.get_basic_info(); + if let Some(logical_type) = basic_info.logical_type() { + return logical_type == LogicalType::List; + } + return basic_info.converted_type() == ConvertedType::LIST; + } + false + } + + /// Returns `true` if this type is a group with a single child field that is `repeated`. + pub(crate) fn has_single_repeated_child(&self) -> bool { + if self.is_group() { + let children = self.get_fields(); + return children.len() == 1 + && children[0].get_basic_info().has_repetition() + && children[0].get_basic_info().repetition() == Repetition::REPEATED; + } + false + } } /// A builder for primitive types. All attributes are optional @@ -533,7 +556,11 @@ impl<'a> PrimitiveTypeBuilder<'a> { } } PhysicalType::FIXED_LEN_BYTE_ARRAY => { - let max_precision = (2f64.powi(8 * self.length - 1) - 1f64).log10().floor() as i32; + let length = self + .length + .checked_mul(8) + .ok_or(general_err!("Invalid length {} for Decimal", self.length))?; + let max_precision = (2f64.powi(length - 1) - 1f64).log10().floor() as i32; if self.precision > max_precision { return Err(general_err!( @@ -926,6 +953,32 @@ impl ColumnDescriptor { /// /// Encapsulates the file's schema ([`Type`]) and [`ColumnDescriptor`]s for /// each primitive (leaf) column. 
+/// +/// # Example +/// ``` +/// # use std::sync::Arc; +/// use parquet::schema::types::{SchemaDescriptor, Type}; +/// use parquet::basic; // note there are two `Type`s that are different +/// // Schema for a table with two columns: "a" (int64) and "b" (int32, stored as a date) +/// let descriptor = SchemaDescriptor::new( +/// Arc::new( +/// Type::group_type_builder("my_schema") +/// .with_fields(vec![ +/// Arc::new( +/// Type::primitive_type_builder("a", basic::Type::INT64) +/// .build().unwrap() +/// ), +/// Arc::new( +/// Type::primitive_type_builder("b", basic::Type::INT32) +/// .with_converted_type(basic::ConvertedType::DATE) +/// .with_logical_type(Some(basic::LogicalType::Date)) +/// .build().unwrap() +/// ), +/// ]) +/// .build().unwrap() +/// ) +/// ); +/// ``` #[derive(PartialEq)] pub struct SchemaDescriptor { /// The top-level logical schema (the "message" type). @@ -1122,9 +1175,25 @@ pub fn from_thrift(elements: &[SchemaElement]) -> Result { )); } + if !schema_nodes[0].is_group() { + return Err(general_err!("Expected root node to be a group type")); + } + Ok(schema_nodes.remove(0)) } +/// Checks if the logical type is valid. +fn check_logical_type(logical_type: &Option) -> Result<()> { + if let Some(LogicalType::Integer { bit_width, .. }) = *logical_type { + if bit_width != 8 && bit_width != 16 && bit_width != 32 && bit_width != 64 { + return Err(general_err!( + "Bit width must be 8, 16, 32, or 64 for Integer logical type" + )); + } + } + Ok(()) +} + /// Constructs a new Type from the `elements`, starting at index `index`. /// The first result is the starting index for the next Type after this one. If it is /// equal to `elements.len()`, then this Type is the last one. 
@@ -1149,6 +1218,9 @@ fn from_thrift_helper(elements: &[SchemaElement], index: usize) -> Result<(usize .logical_type .as_ref() .map(|value| LogicalType::from(value.clone())); + + check_logical_type(&logical_type)?; + let field_id = elements[index].field_id; match elements[index].num_children { // From parquet-format: diff --git a/parquet/src/thrift.rs b/parquet/src/thrift.rs index ceb6b1c29fe8..b216fec6f3e7 100644 --- a/parquet/src/thrift.rs +++ b/parquet/src/thrift.rs @@ -67,7 +67,7 @@ impl<'a> TCompactSliceInputProtocol<'a> { let mut shift = 0; loop { let byte = self.read_byte()?; - in_progress |= ((byte & 0x7F) as u64) << shift; + in_progress |= ((byte & 0x7F) as u64).wrapping_shl(shift); shift += 7; if byte & 0x80 == 0 { return Ok(in_progress); @@ -96,13 +96,22 @@ impl<'a> TCompactSliceInputProtocol<'a> { } } +macro_rules! thrift_unimplemented { + () => { + Err(thrift::Error::Protocol(thrift::ProtocolError { + kind: thrift::ProtocolErrorKind::NotImplemented, + message: "not implemented".to_string(), + })) + }; +} + impl TInputProtocol for TCompactSliceInputProtocol<'_> { fn read_message_begin(&mut self) -> thrift::Result { unimplemented!() } fn read_message_end(&mut self) -> thrift::Result<()> { - unimplemented!() + thrift_unimplemented!() } fn read_struct_begin(&mut self) -> thrift::Result> { @@ -147,7 +156,21 @@ impl TInputProtocol for TCompactSliceInputProtocol<'_> { ), _ => { if field_delta != 0 { - self.last_read_field_id += field_delta as i16; + self.last_read_field_id = self + .last_read_field_id + .checked_add(field_delta as i16) + .map_or_else( + || { + Err(thrift::Error::Protocol(thrift::ProtocolError { + kind: thrift::ProtocolErrorKind::InvalidData, + message: format!( + "cannot add {} to {}", + field_delta, self.last_read_field_id + ), + })) + }, + Ok, + )?; } else { self.last_read_field_id = self.read_i16()?; }; @@ -226,15 +249,15 @@ impl TInputProtocol for TCompactSliceInputProtocol<'_> { } fn read_set_begin(&mut self) -> thrift::Result { - 
unimplemented!() + thrift_unimplemented!() } fn read_set_end(&mut self) -> thrift::Result<()> { - unimplemented!() + thrift_unimplemented!() } fn read_map_begin(&mut self) -> thrift::Result { - unimplemented!() + thrift_unimplemented!() } fn read_map_end(&mut self) -> thrift::Result<()> { diff --git a/parquet/tests/arrow_reader/bad_data.rs b/parquet/tests/arrow_reader/bad_data.rs index 74342031432a..cfd61e82d32b 100644 --- a/parquet/tests/arrow_reader/bad_data.rs +++ b/parquet/tests/arrow_reader/bad_data.rs @@ -106,7 +106,7 @@ fn test_arrow_rs_gh_6229_dict_header() { let err = read_file("ARROW-RS-GH-6229-DICTHEADER.parquet").unwrap_err(); assert_eq!( err.to_string(), - "External: Parquet argument error: EOF: eof decoding byte array" + "External: Parquet argument error: Parquet error: Integer overflow: out of range integral type conversion attempted" ); } diff --git a/parquet_derive/LICENSE.txt b/parquet_derive/LICENSE.txt new file mode 120000 index 000000000000..4ab43736a839 --- /dev/null +++ b/parquet_derive/LICENSE.txt @@ -0,0 +1 @@ +../LICENSE.txt \ No newline at end of file diff --git a/parquet_derive/NOTICE.txt b/parquet_derive/NOTICE.txt new file mode 120000 index 000000000000..eb9f24e040b5 --- /dev/null +++ b/parquet_derive/NOTICE.txt @@ -0,0 +1 @@ +../NOTICE.txt \ No newline at end of file